// Package typeobj implements a set of utility functions intended to be used on // union structs whose fields are tagged with the "type" tag and which expect // only one of the fields to be set. For example: // // type OuterType struct { // A *InnerTypeA `type:"a"` // B *InnerTypeB `type:"b"` // C *InnerTypeC `type:"c"` // } // package typeobj import ( "errors" "fmt" "reflect" ) // UnmarshalYAML is intended to be used within the UnmarshalYAML method of a // union struct. It will use the given input data's "type" field and match that // to the struct field tagged with that value. it will then unmarshal the input // data into that inner field. func UnmarshalYAML(i interface{}, unmarshal func(interface{}) error) error { val := reflect.Indirect(reflect.ValueOf(i)) if !val.CanSet() { return fmt.Errorf("cannot unmarshal into value of type %T", i) } // unmarshal in all non-typeobj fields. construct a type which wraps the // given one, hiding its UnmarshalYAML method (if it has one), and unmarshal // onto that directly. The "type" field is also unmarshaled at this stage. valWrap := reflect.New(reflect.StructOf([]reflect.StructField{ reflect.StructField{ Name: "Type", Type: typeOfString, Tag: `yaml:"type"`, }, { Name: "Val", Type: val.Type(), Tag: `yaml:",inline"`, }, })) if err := unmarshal(valWrap.Interface()); err != nil { return err } typeVal := valWrap.Elem().Field(0).String() val.Set(valWrap.Elem().Field(1)) typ := val.Type() for i := 0; i < val.NumField(); i++ { fieldVal, fieldTyp := val.Field(i), typ.Field(i) if fieldTyp.Tag.Get("type") != typeVal { continue } var valInto interface{} if fieldVal.Kind() == reflect.Ptr { newFieldVal := reflect.New(fieldTyp.Type.Elem()) fieldVal.Set(newFieldVal) valInto = newFieldVal.Interface() } else { valInto = fieldVal.Addr().Interface() } return unmarshal(valInto) } return fmt.Errorf("invalid type value %q", typeVal) } // val should be of kind struct func element(val reflect.Value) (reflect.Value, string, []int, error) { typ := val.Type() numFields := val.NumField() var fieldVal reflect.Value var typeTag string nonTypeFields := make([]int, 0, numFields) for i := 0; i < numFields; i++ { innerFieldVal := val.Field(i) innerTypeTag := typ.Field(i).Tag.Get("type") if innerTypeTag == "" { nonTypeFields = append(nonTypeFields, i) } else if innerFieldVal.IsZero() { continue } else { fieldVal = innerFieldVal typeTag = innerTypeTag } } if fieldVal.IsZero() { return reflect.Value{}, "", nil, errors.New(`no non-zero fields tagged with "type"`) } return fieldVal, typeTag, nonTypeFields, nil } // Element returns the value of the first non-zero field tagged with "type", as // well as the value of the "type" tag. func Element(i interface{}) (interface{}, string, error) { val := reflect.Indirect(reflect.ValueOf(i)) fieldVal, tag, _, err := element(val) if err != nil { return fieldVal, tag, err } return fieldVal.Interface(), tag, nil } var typeOfString = reflect.TypeOf("string") // MarshalYAML is intended to be used within the MarshalYAML method of a union // struct. It will find the first field of the given struct which has a "type" // tag and is non-zero. It will then marshal that field's value, inlining an // extra YAML field "type" whose value is the value of the "type" tag on the // struct field, and return that. func MarshalYAML(i interface{}) (interface{}, error) { val := reflect.Indirect(reflect.ValueOf(i)) typ := val.Type() fieldVal, typeTag, nonTypeFields, err := element(val) if err != nil { return nil, err } fieldVal = reflect.Indirect(fieldVal) if fieldVal.Kind() != reflect.Struct { return nil, fmt.Errorf("cannot marshal non-struct type %T", fieldVal.Interface()) } structFields := make([]reflect.StructField, 0, len(nonTypeFields)+2) structFields = append(structFields, reflect.StructField{ Name: "Type", Type: typeOfString, Tag: `yaml:"type"`, }, reflect.StructField{ Name: "Val", Type: fieldVal.Type(), Tag: `yaml:",inline"`, }, ) nonTypeFieldVals := make([]reflect.Value, len(nonTypeFields)) for i, fieldIndex := range nonTypeFields { fieldVal, fieldType := val.Field(fieldIndex), typ.Field(fieldIndex) structFields = append(structFields, fieldType) nonTypeFieldVals[i] = fieldVal } outVal := reflect.New(reflect.StructOf(structFields)) outVal.Elem().Field(0).Set(reflect.ValueOf(typeTag)) outVal.Elem().Field(1).Set(fieldVal) for i, fieldVal := range nonTypeFieldVals { outVal.Elem().Field(2 + i).Set(fieldVal) } return outVal.Interface(), nil }