// Package typeobj implements a set of utility functions intended to be used on // union structs whose fields are tagged with the "type" tag and which expect // only one of the fields to be set. For example: // // type OuterType struct { // A *InnerTypeA `type:"a"` // B *InnerTypeB `type:"b"` // C *InnerTypeC `type:"c"` // } // package typeobj import ( "errors" "fmt" "reflect" "strings" ) type tagInfo struct { val string isDefault bool } func parseTag(tag string) tagInfo { parts := strings.Split(tag, ",") return tagInfo{ val: parts[0], isDefault: len(parts) > 1 && parts[1] == "default", } } // structTypeWithYAMLTags takes a type of kind struct and returns that same // type, except all fields with a "type" tag will also have a `yaml:"-"` tag // attached. func structTypeWithYAMLTags(typ reflect.Type) (reflect.Type, error) { n := typ.NumField() outFields := make([]reflect.StructField, n) for i := 0; i < n; i++ { field := typ.Field(i) hasTypeTag := field.Tag.Get("type") != "" if hasTypeTag && field.Tag.Get("yaml") != "" { return nil, fmt.Errorf("field %s has yaml tag and type tag", field.Name) } else if hasTypeTag { field.Tag += ` yaml:"-"` } outFields[i] = field } return reflect.StructOf(outFields), nil } func findTypeField(val reflect.Value, targetTypeTag string) (reflect.Value, reflect.StructField, error) { typ := val.Type() var defVal reflect.Value var defTyp reflect.StructField var defOk bool for i := 0; i < val.NumField(); i++ { fieldVal, fieldTyp := val.Field(i), typ.Field(i) tagInfo := parseTag(fieldTyp.Tag.Get("type")) if targetTypeTag != "" && tagInfo.val == targetTypeTag { return fieldVal, fieldTyp, nil } else if targetTypeTag == "" && tagInfo.isDefault { defVal, defTyp, defOk = fieldVal, fieldTyp, true } } if targetTypeTag == "" && defOk { return defVal, defTyp, nil } else if targetTypeTag == "" { return reflect.Value{}, reflect.StructField{}, errors.New("type field not set") } return reflect.Value{}, reflect.StructField{}, fmt.Errorf("invalid type value %q", targetTypeTag) } // UnmarshalYAML is intended to be used within the UnmarshalYAML method of a // union struct. It will use the given input data's "type" field and match that // to the struct field tagged with that value. it will then unmarshal the input // data into that inner field. func UnmarshalYAML(i interface{}, unmarshal func(interface{}) error) error { val := reflect.Indirect(reflect.ValueOf(i)) if !val.CanSet() || val.Kind() != reflect.Struct { return fmt.Errorf("cannot unmarshal into value of type %T: must be a struct pointer", i) } // create a copy of the struct type, with `yaml:"-"` tags added to all // fields with `type:"..."` tags. If we didn't do this then there would be // conflicts in the next step if a type field's name was the same as one of // its inner field names. valTypeCP, err := structTypeWithYAMLTags(val.Type()) if err != nil { return fmt.Errorf("cannot unmarshal into value of type %T: %w", i, err) } // unmarshal in all non-typeobj fields. construct a type which wraps the // given one, hiding its UnmarshalYAML method (if it has one), and unmarshal // onto that directly. The "type" field is also unmarshaled at this stage. valWrap := reflect.New(reflect.StructOf([]reflect.StructField{ {Name: "Type", Type: typeOfString, Tag: `yaml:"type"`}, {Name: "Val", Type: valTypeCP, Tag: `yaml:",inline"`}, })) if err := unmarshal(valWrap.Interface()); err != nil { return err } // set non-type fields into the original value valWrapInnerVal := valWrap.Elem().Field(1) for i := 0; i < valWrapInnerVal.NumField(); i++ { fieldVal, fieldTyp := valWrapInnerVal.Field(i), valTypeCP.Field(i) if fieldTyp.Tag.Get("type") != "" { continue } val.Field(i).Set(fieldVal) } typeVal := valWrap.Elem().Field(0).String() fieldVal, fieldTyp, err := findTypeField(val, typeVal) if err != nil { return err } var valInto interface{} if fieldVal.Kind() == reflect.Ptr { newFieldVal := reflect.New(fieldTyp.Type.Elem()) fieldVal.Set(newFieldVal) valInto = newFieldVal.Interface() } else { valInto = fieldVal.Addr().Interface() } return unmarshal(valInto) } // val should be of kind struct func element(val reflect.Value) (reflect.Value, string, []int, error) { typ := val.Type() numFields := val.NumField() var fieldVal reflect.Value var typeTag string nonTypeFields := make([]int, 0, numFields) for i := 0; i < numFields; i++ { innerFieldVal := val.Field(i) innerTagInfo := parseTag(typ.Field(i).Tag.Get("type")) if innerTagInfo.val == "" { nonTypeFields = append(nonTypeFields, i) } else if innerFieldVal.IsZero() { continue } else { fieldVal = innerFieldVal typeTag = innerTagInfo.val } } if !fieldVal.IsValid() { return reflect.Value{}, "", nil, errors.New(`no non-zero fields tagged with "type"`) } return fieldVal, typeTag, nonTypeFields, nil } // Element returns the value of the first non-zero field tagged with "type", as // well as the value of the "type" tag. func Element(i interface{}) (interface{}, string, error) { val := reflect.Indirect(reflect.ValueOf(i)) fieldVal, tag, _, err := element(val) if err != nil { return fieldVal, tag, err } return fieldVal.Interface(), tag, nil } var typeOfString = reflect.TypeOf("string") // MarshalYAML is intended to be used within the MarshalYAML method of a union // struct. It will find the first field of the given struct which has a "type" // tag and is non-zero. It will then marshal that field's value, inlining an // extra YAML field "type" whose value is the value of the "type" tag on the // struct field, and return that. func MarshalYAML(i interface{}) (interface{}, error) { val := reflect.Indirect(reflect.ValueOf(i)) typ := val.Type() fieldVal, typeTag, nonTypeFields, err := element(val) if err != nil { return nil, err } fieldVal = reflect.Indirect(fieldVal) if fieldVal.Kind() != reflect.Struct { return nil, fmt.Errorf("cannot marshal non-struct type %T", fieldVal.Interface()) } structFields := make([]reflect.StructField, 0, len(nonTypeFields)+2) structFields = append(structFields, reflect.StructField{ Name: "Type", Type: typeOfString, Tag: `yaml:"type"`, }, reflect.StructField{ Name: "Val", Type: fieldVal.Type(), Tag: `yaml:",inline"`, }, ) nonTypeFieldVals := make([]reflect.Value, len(nonTypeFields)) for i, fieldIndex := range nonTypeFields { fieldVal, fieldType := val.Field(fieldIndex), typ.Field(fieldIndex) structFields = append(structFields, fieldType) nonTypeFieldVals[i] = fieldVal } outVal := reflect.New(reflect.StructOf(structFields)) outVal.Elem().Field(0).Set(reflect.ValueOf(typeTag)) outVal.Elem().Field(1).Set(fieldVal) for i, fieldVal := range nonTypeFieldVals { outVal.Elem().Field(2 + i).Set(fieldVal) } return outVal.Interface(), nil }