package nebula

import (
	"crypto/ecdh"
	"crypto/rand"
	"fmt"

	"github.com/slackhq/nebula/cert"
)

// Prefixes that tag the text encoding of each key kind. The MarshalText and
// UnmarshalText methods pass them to encodeWithPrefix / decodeWithPrefix.
// Lowercase "x0" marks private keys; uppercase "X0" marks public keys.
var (
	encPrivKeyPrefix = []byte("x0")
	encPubKeyPrefix  = []byte("X0")
)

// x25519 is the ECDH curve used to generate and parse every encrypting key.
var x25519 = ecdh.X25519()

// EncryptingPublicKey wraps an X25519 ECDH public key and adds text
// (un)marshaling helpers. The zero value holds no key and marshals to empty
// text.
type EncryptingPublicKey struct {
	inner *ecdh.PublicKey
}

// MarshalText implements the encoding.TextMarshaler interface. The zero value
// marshals to empty text; otherwise the raw key bytes are encoded together
// with encPubKeyPrefix via encodeWithPrefix. The returned error is always nil.
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
	if pk.inner != nil {
		return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil
	}
	return []byte(""), nil
}

// Bytes returns the raw bytes of the EncryptingPublicKey, as reported by
// (*ecdh.PublicKey).Bytes, or nil if it is the zero value.
func (pk EncryptingPublicKey) Bytes() []byte {
	if pk == (EncryptingPublicKey{}) {
		return nil
	}
	return pk.inner.Bytes()
}

// UnmarshalText implements the encoding.TextUnmarshaler interface.
//
// Empty input resets pk to the zero value. Any other input is decoded by
// decodeWithPrefix using encPubKeyPrefix, and the result must be a valid
// X25519 public key.
//
// On error pk is left unchanged.
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
	if len(b) == 0 {
		*pk = EncryptingPublicKey{}
		return nil
	}

	raw, err := decodeWithPrefix(encPubKeyPrefix, b)
	if err != nil {
		return fmt.Errorf("unmarshaling encrypting public key: %w", err)
	}

	// Parse into a local first. NewPublicKey returns nil on failure, and
	// assigning that directly would wipe out the receiver's existing key.
	inner, err := x25519.NewPublicKey(raw)
	if err != nil {
		return fmt.Errorf("converting bytes to public key: %w", err)
	}

	pk.inner = inner
	return nil
}

// UnmarshalNebulaPEM unmarshals the EncryptingPublicKey as a nebula host public
// key PEM, parsed with cert.UnmarshalX25519PublicKey. Unlike UnmarshalText,
// empty input is not special-cased.
//
// On error pk is left unchanged.
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
	// NOTE(review): the second return value is presumably the unparsed
	// remainder after the PEM block (confirm against the cert package). It is
	// deliberately discarded, so any trailing data is ignored.
	raw, _, err := cert.UnmarshalX25519PublicKey(b)
	if err != nil {
		return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err)
	}

	// Parse into a local first so a failure does not replace the receiver's
	// existing key with nil.
	inner, err := x25519.NewPublicKey(raw)
	if err != nil {
		return fmt.Errorf("converting bytes to public key: %w", err)
	}

	pk.inner = inner
	return nil
}

// String returns the MarshalText form of the key as a string. It panics if
// MarshalText reports an error.
func (pk EncryptingPublicKey) String() string {
	text, err := pk.MarshalText()
	if err == nil {
		return string(text)
	}
	panic(err)
}

// EncryptingPrivateKey wraps an X25519 ECDH private key and adds text
// (un)marshaling helpers. The zero value holds no key and marshals to empty
// text.
type EncryptingPrivateKey struct {
	inner *ecdh.PrivateKey
}

// NewEncryptingPrivateKey generates and returns a fresh EncryptingPrivateKey,
// drawing randomness from crypto/rand. It panics if key generation fails.
func NewEncryptingPrivateKey() EncryptingPrivateKey {
	inner, err := x25519.GenerateKey(rand.Reader)
	if err != nil {
		// The only input is crypto/rand.Reader, so a failure here means the
		// system randomness source is broken. There is nothing sensible to
		// return to the caller.
		panic(err)
	}
	return EncryptingPrivateKey{inner: inner}
}

// PublicKey returns the public key which corresponds with this private key.
//
// A zero-value EncryptingPrivateKey holds no key. In that case this returns
// the zero-value EncryptingPublicKey instead of dereferencing the nil inner
// key, consistent with how Bytes and MarshalText treat the zero value.
func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey {
	if k == (EncryptingPrivateKey{}) {
		return EncryptingPublicKey{}
	}
	return EncryptingPublicKey{k.inner.PublicKey()}
}

// MarshalText implements the encoding.TextMarshaler interface. The zero value
// marshals to empty text; otherwise the raw private key bytes are encoded
// together with encPrivKeyPrefix via encodeWithPrefix. The returned error is
// always nil.
//
// NOTE: the output can be decoded back into the private key by UnmarshalText,
// so treat it as secret material.
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
	if k.inner != nil {
		return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil
	}
	return []byte(""), nil
}

// Bytes returns the raw private key bytes, as reported by
// (*ecdh.PrivateKey).Bytes, or nil for the zero value.
func (k EncryptingPrivateKey) Bytes() []byte {
	if k.inner == nil {
		return nil
	}
	return k.inner.Bytes()
}

// UnmarshalText implements the encoding.TextUnmarshaler interface.
//
// Empty input resets k to the zero value. Any other input is decoded by
// decodeWithPrefix using encPrivKeyPrefix, and the result must be a valid
// X25519 private key.
//
// On error k is left unchanged.
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
	if len(b) == 0 {
		*k = EncryptingPrivateKey{}
		return nil
	}

	raw, err := decodeWithPrefix(encPrivKeyPrefix, b)
	if err != nil {
		return fmt.Errorf("unmarshaling encrypting private key: %w", err)
	}

	// Parse into a local first. NewPrivateKey returns nil on failure, and
	// assigning that directly would wipe out the receiver's existing key.
	inner, err := x25519.NewPrivateKey(raw)
	if err != nil {
		return fmt.Errorf("converting bytes to private key: %w", err)
	}

	k.inner = inner
	return nil
}

// String returns the MarshalText form of the key as a string. It panics if
// MarshalText reports an error.
//
// NOTE(review): this renders the private key itself. Formatting an
// EncryptingPrivateKey with fmt (%v, %s) or passing it to a logger therefore
// exposes the secret.
func (k EncryptingPrivateKey) String() string {
	text, err := k.MarshalText()
	if err == nil {
		return string(text)
	}
	panic(err)
}