// Package typeobj implements a set of utility functions intended to be used on
// union structs whose fields are tagged with the "type" tag and which expect
// only one of the fields to be set. For example:
//
//	type OuterType struct {
//		A *InnerTypeA `type:"a"`
//		B *InnerTypeB `type:"b"`
//		C *InnerTypeC `type:"c"`
//	}
//
package typeobj

import (
	"errors"
	"fmt"
	"reflect"
)

// UnmarshalYAML is intended to be used within the UnmarshalYAML method of a
// union struct. It will use the given input data's "type" field and match that
// to the struct field tagged with that value. It will then unmarshal the input
// data into that inner field. Fields of the union struct which are not tagged
// with "type" are unmarshaled from the input data as well.
//
// i must be a non-nil pointer to a struct. unmarshal is the function handed to
// the caller's own UnmarshalYAML method; it is invoked once to decode the
// "type" field plus all non-type fields, and once more to decode into the
// selected inner field. An error is returned if the input's "type" value is
// missing, empty, or does not match any field tagged with "type".
func UnmarshalYAML(i interface{}, unmarshal func(interface{}) error) error {
	val := reflect.Indirect(reflect.ValueOf(i))
	if !val.CanSet() || val.Kind() != reflect.Struct {
		return fmt.Errorf("cannot unmarshal into value of type %T", i)
	}

	// Unmarshal in all non-typeobj fields. Construct a type which wraps the
	// given one, hiding its UnmarshalYAML method (if it has one), and unmarshal
	// onto that directly. The "type" field is also unmarshaled at this stage.
	valWrap := reflect.New(reflect.StructOf([]reflect.StructField{
		{
			Name: "Type",
			Type: reflect.TypeOf(""),
			Tag:  `yaml:"type"`,
		},
		{
			Name: "Val",
			Type: val.Type(),
			Tag:  `yaml:",inline"`,
		},
	}))
	if err := unmarshal(valWrap.Interface()); err != nil {
		return err
	}
	typeVal := valWrap.Elem().Field(0).String()
	val.Set(valWrap.Elem().Field(1))

	typ := val.Type()
	for idx := 0; idx < val.NumField(); idx++ {
		fieldVal, fieldTyp := val.Field(idx), typ.Field(idx)
		tag := fieldTyp.Tag.Get("type")
		// Untagged fields are never union members. Without the explicit empty
		// check, a missing/empty "type" value would match the first untagged
		// field and silently unmarshal the input into it.
		if tag == "" || tag != typeVal {
			continue
		}

		var valInto interface{}
		if fieldVal.Kind() == reflect.Ptr {
			// Allocate a fresh inner value so the decoder has somewhere to write.
			newFieldVal := reflect.New(fieldTyp.Type.Elem())
			fieldVal.Set(newFieldVal)
			valInto = newFieldVal.Interface()
		} else {
			valInto = fieldVal.Addr().Interface()
		}
		return unmarshal(valInto)
	}

	return fmt.Errorf("invalid type value %q", typeVal)
}

// element scans the fields of val, which must be of kind struct. It returns
// the first non-zero field tagged with "type", the value of that field's
// "type" tag, and the indices of every field without a "type" tag. An error
// is returned if no field tagged with "type" is non-zero.
func element(val reflect.Value) (reflect.Value, string, []int, error) {
	typ := val.Type()
	numFields := val.NumField()

	var fieldVal reflect.Value
	var typeTag string
	nonTypeFields := make([]int, 0, numFields)
	for i := 0; i < numFields; i++ {
		innerTypeTag := typ.Field(i).Tag.Get("type")
		if innerTypeTag == "" {
			nonTypeFields = append(nonTypeFields, i)
			continue
		}
		// Keep only the first non-zero typed field. The loop does not stop
		// there because the remaining non-type fields must still be collected.
		if !fieldVal.IsValid() && !val.Field(i).IsZero() {
			fieldVal, typeTag = val.Field(i), innerTypeTag
		}
	}

	// When nothing was selected fieldVal is the invalid zero Value, on which
	// IsZero would panic; test validity instead.
	if !fieldVal.IsValid() {
		return reflect.Value{}, "", nil, errors.New(`no non-zero fields tagged with "type"`)
	}
	return fieldVal, typeTag, nonTypeFields, nil
}

// Element returns the value of the first non-zero field tagged with "type", as
// well as the value of the "type" tag. i may be a struct or a pointer to one.
// On error the returned value is nil and the tag is empty.
func Element(i interface{}) (interface{}, string, error) {
	val := reflect.Indirect(reflect.ValueOf(i))
	if val.Kind() != reflect.Struct {
		return nil, "", fmt.Errorf("cannot get element of non-struct type %T", i)
	}
	fieldVal, tag, _, err := element(val)
	if err != nil {
		return nil, "", err
	}
	return fieldVal.Interface(), tag, nil
}

// typeOfString is the reflect.Type of the built-in string type; it is used
// for the synthesized "type" field of the wrapper structs built via
// reflect.StructOf.
var typeOfString = reflect.TypeOf("")

// MarshalYAML is intended to be used within the MarshalYAML method of a union
// struct. It will find the first field of the given struct which has a "type"
// tag and is non-zero. It will then marshal that field's value, inlining an
// extra YAML field "type" whose value is the value of the "type" tag on the
// struct field, and return that. Exported fields of the union struct which
// are not tagged with "type" are carried over alongside it; unexported ones
// are skipped.
//
// i must be a struct or a non-nil pointer to one, and the selected field must
// hold a struct or a pointer to one. The returned value is a pointer to a
// dynamically-built struct suitable for handing to a YAML encoder.
func MarshalYAML(i interface{}) (interface{}, error) {
	val := reflect.Indirect(reflect.ValueOf(i))
	if val.Kind() != reflect.Struct {
		return nil, fmt.Errorf("cannot marshal non-struct type %T", i)
	}
	typ := val.Type()
	elemVal, typeTag, nonTypeFields, err := element(val)
	if err != nil {
		return nil, err
	}

	elemVal = reflect.Indirect(elemVal)
	if elemVal.Kind() != reflect.Struct {
		return nil, fmt.Errorf("cannot marshal non-struct type %T", elemVal.Interface())
	}

	// Layout of the synthesized struct: the "type" key, then the selected
	// element's fields inlined, then the union's own non-type fields.
	// NOTE(review): a non-type field named "Type" or "Val" collides with the
	// wrapper's field names and makes reflect.StructOf panic.
	structFields := make([]reflect.StructField, 0, len(nonTypeFields)+2)
	structFields = append(structFields,
		reflect.StructField{
			Name: "Type",
			Type: typeOfString,
			Tag:  `yaml:"type"`,
		},
		reflect.StructField{
			Name: "Val",
			Type: elemVal.Type(),
			Tag:  `yaml:",inline"`,
		},
	)

	extraVals := make([]reflect.Value, 0, len(nonTypeFields))
	for _, fieldIndex := range nonTypeFields {
		field := typ.Field(fieldIndex)
		// Values obtained through unexported fields cannot be Set via
		// reflection (it panics), and go-yaml only marshals exported fields,
		// so such fields are left out.
		if field.PkgPath != "" {
			continue
		}
		structFields = append(structFields, field)
		extraVals = append(extraVals, val.Field(fieldIndex))
	}

	out := reflect.New(reflect.StructOf(structFields)).Elem()
	out.Field(0).SetString(typeTag)
	out.Field(1).Set(elemVal)
	for j, extra := range extraVals {
		out.Field(2 + j).Set(extra)
	}

	return out.Addr().Interface(), nil
}