diff --git a/typeobj/typeobj.go b/typeobj/typeobj.go index 6b6044d..fed9225 100644 --- a/typeobj/typeobj.go +++ b/typeobj/typeobj.go @@ -14,8 +14,46 @@ import ( "errors" "fmt" "reflect" + "strings" ) +type tagInfo struct { + val string + isDefault bool +} + +func parseTag(tag string) tagInfo { + parts := strings.Split(tag, ",") + return tagInfo{ + val: parts[0], + isDefault: len(parts) > 1 && parts[1] == "default", + } +} + +func findTypeField(val reflect.Value, targetTypeTag string) (reflect.Value, reflect.StructField, error) { + typ := val.Type() + + var defVal reflect.Value + var defTyp reflect.StructField + var defOk bool + for i := 0; i < val.NumField(); i++ { + fieldVal, fieldTyp := val.Field(i), typ.Field(i) + tagInfo := parseTag(fieldTyp.Tag.Get("type")) + if targetTypeTag != "" && tagInfo.val == targetTypeTag { + return fieldVal, fieldTyp, nil + } else if targetTypeTag == "" && tagInfo.isDefault { + defVal, defTyp, defOk = fieldVal, fieldTyp, true + } + } + + if targetTypeTag == "" && defOk { + return defVal, defTyp, nil + } else if targetTypeTag == "" { + return reflect.Value{}, reflect.StructField{}, errors.New("type field not set") + } + return reflect.Value{}, reflect.StructField{}, fmt.Errorf("invalid type value %q", targetTypeTag) +} + // UnmarshalYAML is intended to be used within the UnmarshalYAML method of a // union struct. It will use the given input data's "type" field and match that // to the struct field tagged with that value. it will then unmarshal the input @@ -44,28 +82,25 @@ func UnmarshalYAML(i interface{}, unmarshal func(interface{}) error) error { if err := unmarshal(valWrap.Interface()); err != nil { return err } - typeVal := valWrap.Elem().Field(0).String() - val.Set(valWrap.Elem().Field(1)) - typ := val.Type() - for i := 0; i < val.NumField(); i++ { - fieldVal, fieldTyp := val.Field(i), typ.Field(i) - if fieldTyp.Tag.Get("type") != typeVal { - continue - } + // necessary to set non-type fields into the original value + val.Set(valWrap.Elem().Field(1)) - var valInto interface{} - if fieldVal.Kind() == reflect.Ptr { - newFieldVal := reflect.New(fieldTyp.Type.Elem()) - fieldVal.Set(newFieldVal) - valInto = newFieldVal.Interface() - } else { - valInto = fieldVal.Addr().Interface() - } - return unmarshal(valInto) + typeVal := valWrap.Elem().Field(0).String() + fieldVal, fieldTyp, err := findTypeField(val, typeVal) + if err != nil { + return err } - return fmt.Errorf("invalid type value %q", typeVal) + var valInto interface{} + if fieldVal.Kind() == reflect.Ptr { + newFieldVal := reflect.New(fieldTyp.Type.Elem()) + fieldVal.Set(newFieldVal) + valInto = newFieldVal.Interface() + } else { + valInto = fieldVal.Addr().Interface() + } + return unmarshal(valInto) } // val should be of kind struct @@ -78,14 +113,14 @@ func element(val reflect.Value) (reflect.Value, string, []int, error) { nonTypeFields := make([]int, 0, numFields) for i := 0; i < numFields; i++ { innerFieldVal := val.Field(i) - innerTypeTag := typ.Field(i).Tag.Get("type") - if innerTypeTag == "" { + innerTagInfo := parseTag(typ.Field(i).Tag.Get("type")) + if innerTagInfo.val == "" { nonTypeFields = append(nonTypeFields, i) } else if innerFieldVal.IsZero() { continue } else { fieldVal = innerFieldVal - typeTag = innerTypeTag + typeTag = innerTagInfo.val } } diff --git a/typeobj/typeobj_test.go b/typeobj/typeobj_test.go index 0939578..9b9e794 100644 --- a/typeobj/typeobj_test.go +++ b/typeobj/typeobj_test.go @@ -2,6 +2,7 @@ package typeobj import ( "reflect" + "strings" "testing" "github.com/davecgh/go-spew/spew" @@ -31,83 +32,124 @@ func (o *outer) UnmarshalYAML(unmarshal func(interface{}) error) error { return UnmarshalYAML(o, unmarshal) } +type outerWDefault struct { + Foo foo `type:"foo,default"` + Bar *bar `type:"bar"` +} + +func (o outerWDefault) MarshalYAML() (interface{}, error) { + return MarshalYAML(o) +} + +func (o *outerWDefault) UnmarshalYAML(unmarshal func(interface{}) error) error { + return UnmarshalYAML(o, unmarshal) +} + func TestTypeObj(t *testing.T) { type test struct { descr string str string - err bool - other string - obj outer - typeTag string - elem interface{} + expErr string + expObj interface{} + expTypeTag string + expElem interface{} + expMarshalOut string // defaults to str } tests := []test{ { - descr: "no type set", - str: `{}`, - err: true, + descr: "no type set", + str: `{}`, + expErr: "type field not set", + expObj: outer{}, }, { - descr: "unknown type set", - str: "type: baz", - err: true, + descr: "no type set with nontype field", + str: `other_field: aaa`, + expErr: "type field not set", + expObj: outer{}, }, { - descr: "foo set", - str: "type: foo\na: 1\n", - obj: outer{Foo: foo{A: 1}}, - typeTag: "foo", - elem: foo{A: 1}, + descr: "no type set with default", + str: `a: 1`, + expObj: outerWDefault{Foo: foo{A: 1}}, + expTypeTag: "foo", + expElem: foo{A: 1}, + expMarshalOut: "type: foo\na: 1", }, { - descr: "bar set", - str: "type: bar\nb: 1\n", - obj: outer{Bar: &bar{B: 1}}, - typeTag: "bar", - elem: &bar{B: 1}, + descr: "invalid type value", + str: "type: baz", + expErr: "invalid type value", + expObj: outer{}, }, { - descr: "foo and other_field set", - str: "type: foo\na: 1\nother_field: aaa\n", - obj: outer{Foo: foo{A: 1}, Other: "aaa"}, - typeTag: "foo", - elem: foo{A: 1}, + descr: "foo set", + str: "type: foo\na: 1", + expObj: outer{Foo: foo{A: 1}}, + expTypeTag: "foo", + expElem: foo{A: 1}, + }, + { + descr: "bar set", + str: "type: bar\nb: 1", + expObj: outer{Bar: &bar{B: 1}}, + expTypeTag: "bar", + expElem: &bar{B: 1}, + }, + { + descr: "foo and other_field set", + str: "type: foo\na: 1\nother_field: aaa", + expObj: outer{Foo: foo{A: 1}, Other: "aaa"}, + expTypeTag: "foo", + expElem: foo{A: 1}, }, } for _, test := range tests { t.Run(test.descr, func(t *testing.T) { - var o outer - err := yaml.Unmarshal([]byte(test.str), &o) - if test.err && err != nil { + + intoV := reflect.New(reflect.TypeOf(test.expObj)) + + err := yaml.Unmarshal([]byte(test.str), intoV.Interface()) + if test.expErr != "" { + if err == nil || !strings.HasPrefix(err.Error(), test.expErr) { + t.Fatalf("expected error %q when unmarshaling but got: %v", test.expErr, err) + } return - } else if test.err && err == nil { - t.Fatal("expected error when unmarshaling but there was none") - } else if !test.err && err != nil { + } else if test.expErr == "" && err != nil { t.Fatalf("unmarshaling %q returned unexpected error: %v", test.str, err) } - if !reflect.DeepEqual(o, test.obj) { - t.Fatalf("test expected value:\n%s\nbut got value:\n%s", spew.Sprint(test.obj), spew.Sprint(o)) + into := intoV.Elem().Interface() + if !reflect.DeepEqual(into, test.expObj) { + t.Fatalf("test expected value:\n%s\nbut got value:\n%s", spew.Sprint(test.expObj), spew.Sprint(into)) } - elem, typeTag, err := Element(o) + elem, typeTag, err := Element(into) if err != nil { - t.Fatalf("error when calling Element(%s): %v", spew.Sprint(o), err) - } else if !reflect.DeepEqual(elem, test.elem) { - t.Fatalf("test expected elem value:\n%s\nbut got value:\n%s", spew.Sprint(test.elem), spew.Sprint(elem)) - } else if typeTag != test.typeTag { - t.Fatalf("test expected typeTag value %q but got %q", test.typeTag, typeTag) + t.Fatalf("error when calling Element(%s): %v", spew.Sprint(into), err) + } else if !reflect.DeepEqual(elem, test.expElem) { + t.Fatalf("test expected elem value:\n%s\nbut got value:\n%s", spew.Sprint(test.expElem), spew.Sprint(elem)) + } else if typeTag != test.expTypeTag { + t.Fatalf("test expected typeTag value %q but got %q", test.expTypeTag, typeTag) + } + + expMarshalOut := test.expMarshalOut + if expMarshalOut == "" { + expMarshalOut = test.str } + expMarshalOut = strings.TrimSpace(expMarshalOut) - b, err := yaml.Marshal(o) + b, err := yaml.Marshal(into) if err != nil { - t.Fatalf("error marshaling %s: %v", spew.Sprint(o), err) - } else if test.str != string(b) { - t.Fatalf("test expected to marshal to %q, but instead marshaled to %q", test.str, b) + t.Fatalf("error marshaling %s: %v", spew.Sprint(into), err) + } + marshalOut := strings.TrimSpace(string(b)) + if marshalOut != expMarshalOut { + t.Fatalf("test expected to marshal to %q, but instead marshaled to %q", expMarshalOut, marshalOut) } }) }