diff --git a/commit.go b/commit.go index fbf418c..61ca0b3 100644 --- a/commit.go +++ b/commit.go @@ -31,7 +31,7 @@ func (proj *Project) GetCommit(h plumbing.Hash) (c Commit, err error) { } else if c.TreeObject, err = proj.GitRepo.TreeObject(c.Object.TreeHash); err != nil { return c, fmt.Errorf("getting git tree object %q: %w", c.Object.TreeHash, err) - } else if c.Payload.UnmarshalText([]byte(c.Object.Message)); err != nil { + } else if err = c.Payload.UnmarshalText([]byte(c.Object.Message)); err != nil { return c, fmt.Errorf("decoding commit message: %w", err) } c.Hash = c.Object.Hash diff --git a/payload.go b/payload.go index b04fd27..03a6758 100644 --- a/payload.go +++ b/payload.go @@ -4,7 +4,6 @@ import ( "bytes" "errors" "fmt" - "reflect" "sort" "strings" "time" @@ -155,12 +154,7 @@ func (p *PayloadUnion) UnmarshalText(msg []byte) error { if err := yaml.Unmarshal(msgBody, p); err != nil { return fmt.Errorf("unmarshaling commit payload from yaml: %w", err) - - } else if reflect.DeepEqual(*p, PayloadUnion{}) { - // a basic check, but worthwhile - return errors.New("commit message is malformed, could not unmarshal yaml object") } - return nil } diff --git a/typeobj/typeobj.go b/typeobj/typeobj.go index debe79a..af12c2e 100644 --- a/typeobj/typeobj.go +++ b/typeobj/typeobj.go @@ -30,6 +30,26 @@ func parseTag(tag string) tagInfo { } } +// structTypeWithYAMLTags takes a type of kind struct and returns that same +// type, except all fields with a "type" tag will also have a `yaml:"-"` tag +// attached. +func structTypeWithYAMLTags(typ reflect.Type) (reflect.Type, error) { + n := typ.NumField() + outFields := make([]reflect.StructField, n) + for i := 0; i < n; i++ { + field := typ.Field(i) + hasTypeTag := field.Tag.Get("type") != "" + if hasTypeTag && field.Tag.Get("yaml") != "" { + return nil, fmt.Errorf("field %s has yaml tag and type tag", field.Name) + } else if hasTypeTag { + field.Tag += ` yaml:"-"` + } + outFields[i] = field + } + + return reflect.StructOf(outFields), nil +} + func findTypeField(val reflect.Value, targetTypeTag string) (reflect.Value, reflect.StructField, error) { typ := val.Type() @@ -60,31 +80,39 @@ func findTypeField(val reflect.Value, targetTypeTag string) (reflect.Value, refl // data into that inner field. func UnmarshalYAML(i interface{}, unmarshal func(interface{}) error) error { val := reflect.Indirect(reflect.ValueOf(i)) - if !val.CanSet() { - return fmt.Errorf("cannot unmarshal into value of type %T", i) + if !val.CanSet() || val.Kind() != reflect.Struct { + return fmt.Errorf("cannot unmarshal into value of type %T: must be a struct pointer", i) + } + + // create a copy of the struct type, with `yaml:"-"` tags added to all + // fields with `type:"..."` tags. If we didn't do this then there would be + // conflicts in the next step if a type field's name was the same as one of + // its inner field names. + valTypeCP, err := structTypeWithYAMLTags(val.Type()) + if err != nil { + return fmt.Errorf("cannot unmarshal into value of type %T: %w", i, err) } // unmarshal in all non-typeobj fields. construct a type which wraps the // given one, hiding its UnmarshalYAML method (if it has one), and unmarshal // onto that directly. The "type" field is also unmarshaled at this stage. valWrap := reflect.New(reflect.StructOf([]reflect.StructField{ - reflect.StructField{ - Name: "Type", - Type: typeOfString, - Tag: `yaml:"type"`, - }, - { - Name: "Val", - Type: val.Type(), - Tag: `yaml:",inline"`, - }, + {Name: "Type", Type: typeOfString, Tag: `yaml:"type"`}, + {Name: "Val", Type: valTypeCP, Tag: `yaml:",inline"`}, })) if err := unmarshal(valWrap.Interface()); err != nil { return err } - // necessary to set non-type fields into the original value - val.Set(valWrap.Elem().Field(1)) + // set non-type fields into the original value + valWrapInnerVal := valWrap.Elem().Field(1) + for i := 0; i < valWrapInnerVal.NumField(); i++ { + fieldVal, fieldTyp := valWrapInnerVal.Field(i), valTypeCP.Field(i) + if fieldTyp.Tag.Get("type") != "" { + continue + } + val.Field(i).Set(fieldVal) + } typeVal := valWrap.Elem().Field(0).String() fieldVal, fieldTyp, err := findTypeField(val, typeVal) diff --git a/typeobj/typeobj_test.go b/typeobj/typeobj_test.go index 9b9e794..7761f52 100644 --- a/typeobj/typeobj_test.go +++ b/typeobj/typeobj_test.go @@ -17,9 +17,15 @@ type bar struct { B int `yaml:"b"` } +// baz has a field of the same name as the type, which is tricky +type baz struct { + Baz int `yaml:"baz"` +} + type outer struct { Foo foo `type:"foo"` Bar *bar `type:"bar"` + Baz baz `type:"baz"` Other string `yaml:"other_field,omitempty"` } @@ -81,7 +87,7 @@ func TestTypeObj(t *testing.T) { }, { descr: "invalid type value", - str: "type: baz", + str: "type: INVALID", expErr: "invalid type value", expObj: outer{}, }, @@ -106,6 +112,13 @@ func TestTypeObj(t *testing.T) { expTypeTag: "foo", expElem: foo{A: 1}, }, + { + descr: "type is same as field name", + str: "type: baz\nbaz: 3", + expObj: outer{Baz: baz{Baz: 3}}, + expTypeTag: "baz", + expElem: baz{Baz: 3}, + }, } for _, test := range tests {