package typeobj

import (
	"reflect"
	"testing"

	"github.com/davecgh/go-spew/spew"
	"gopkg.in/yaml.v2"
)

// foo is a test element selected by the type tag "foo"; outer holds it by value.
type foo struct {
	A int `yaml:"a"`
}

// bar is a test element selected by the type tag "bar"; outer holds it by
// pointer, so the tests also cover pointer-typed elements.
type bar struct {
	B int `yaml:"b"`
}

// outer is the typed-object container exercised by TestTypeObj. Each field
// carrying a `type` struct tag is a candidate element. The test cases expect
// the YAML "type" key to choose which of those fields is populated, with the
// remaining element keys decoded into that field. A field with only a `yaml`
// tag (Other) is expected to round-trip alongside the chosen element.
type outer struct {
	Foo   foo    `type:"foo"`
	Bar   *bar   `type:"bar"`
	Other string `yaml:"other_field,omitempty"`
}

// MarshalYAML implements yaml.Marshaler by delegating to the package-level
// MarshalYAML. The value receiver keeps the method in outer's own method set,
// so it is also used when a plain outer value is marshaled, as TestTypeObj
// does with yaml.Marshal(o).
func (o outer) MarshalYAML() (interface{}, error) {
	return MarshalYAML(o)
}

// UnmarshalYAML implements yaml.Unmarshaler by delegating to the
// package-level UnmarshalYAML, passing the receiver so it can be populated.
func (o *outer) UnmarshalYAML(unmarshal func(interface{}) error) error {
	return UnmarshalYAML(o, unmarshal)
}

// TestTypeObj checks a YAML round trip through outer. For each case:
//   - unmarshaling str either fails (when err is set) or yields obj;
//   - Element reports the populated element and its type tag;
//   - marshaling the decoded value reproduces str exactly.
func TestTypeObj(t *testing.T) {
	type test struct {
		descr   string      // subtest name
		str     string      // YAML input, and also the expected marshal output
		err     bool        // whether unmarshaling str is expected to fail
		obj     outer       // expected result of unmarshaling str
		typeTag string      // expected type tag reported by Element
		elem    interface{} // expected element reported by Element
	}

	tests := []test{
		{
			descr: "no type set",
			str:   `{}`,
			err:   true,
		},
		{
			descr: "unknown type set",
			str:   "type: baz",
			err:   true,
		},
		{
			descr:   "foo set",
			str:     "type: foo\na: 1\n",
			obj:     outer{Foo: foo{A: 1}},
			typeTag: "foo",
			elem:    foo{A: 1},
		},
		{
			descr:   "bar set",
			str:     "type: bar\nb: 1\n",
			obj:     outer{Bar: &bar{B: 1}},
			typeTag: "bar",
			elem:    &bar{B: 1},
		},
		{
			descr:   "foo and other_field set",
			str:     "type: foo\na: 1\nother_field: aaa\n",
			obj:     outer{Foo: foo{A: 1}, Other: "aaa"},
			typeTag: "foo",
			elem:    foo{A: 1},
		},
	}

	for _, test := range tests {
		t.Run(test.descr, func(t *testing.T) {
			var o outer
			err := yaml.Unmarshal([]byte(test.str), &o)
			switch {
			case test.err && err != nil:
				// Expected failure: there is no decoded value to inspect further.
				return
			case test.err:
				t.Fatal("expected error when unmarshaling but there was none")
			case err != nil:
				t.Fatalf("unmarshaling %q returned unexpected error: %v", test.str, err)
			}

			if !reflect.DeepEqual(o, test.obj) {
				t.Fatalf("test expected value:\n%s\nbut got value:\n%s", spew.Sprint(test.obj), spew.Sprint(o))
			}

			elem, typeTag, err := Element(o)
			switch {
			case err != nil:
				t.Fatalf("error when calling Element(%s): %v", spew.Sprint(o), err)
			case !reflect.DeepEqual(elem, test.elem):
				t.Fatalf("test expected elem value:\n%s\nbut got value:\n%s", spew.Sprint(test.elem), spew.Sprint(elem))
			case typeTag != test.typeTag:
				t.Fatalf("test expected typeTag value %q but got %q", test.typeTag, typeTag)
			}

			// Marshaling must reproduce the input byte-for-byte, which also pins
			// the key order and the omission of an empty other_field.
			b, err := yaml.Marshal(o)
			if err != nil {
				t.Fatalf("error marshaling %s: %v", spew.Sprint(o), err)
			}
			if got := string(b); got != test.str {
				t.Fatalf("test expected to marshal to %q, but instead marshaled to %q", test.str, got)
			}
		})
	}
}