package nebula

import (
	"fmt"
	"net"
	"net/netip"
)

// IPNet is the CIDR of a nebula network.
//
// It shares its underlying structure with net.IPNet, so values convert
// directly to and from net.IPNet / *net.IPNet. When decoded through
// UnmarshalText the subnet is required to be in canonical CIDR form and in a
// private IP range.
type IPNet net.IPNet

// UnmarshalText parses and validates an IPNet from a text string, implementing
// encoding.TextUnmarshaler.
//
// The text must be a CIDR in the canonical form produced by net.IPNet.String
// (e.g. "10.0.0.0/8", not "10.0.0.1/8"), and every address in the subnet must
// be private according to net.IP.IsPrivate. On any error n is left unmodified.
func (n *IPNet) UnmarshalText(b []byte) error {
	str := string(b)
	_, subnet, err := net.ParseCIDR(str)
	if err != nil {
		return err
	}

	if cstr := subnet.String(); cstr != str {
		return fmt.Errorf("IPNet is not given in its canonical form of %q", cstr)
	}

	// Compute the last (broadcast) address of the subnet. net.ParseCIDR always
	// returns an IP and Mask of equal length (4 for IPv4, 16 for IPv6), so
	// indexing Mask by the IP's indices is safe.
	last := make(net.IP, len(subnet.IP))
	for i := range last {
		last[i] = subnet.IP[i] | ^subnet.Mask[i]
	}

	// Checking only the network address is not enough: a short prefix such as
	// 10.0.0.0/7 or 192.168.0.0/15 starts inside a private range but runs past
	// its end. The ranges IsPrivate recognizes (10/8, 172.16/12, 192.168/16,
	// fc00::/7) are disjoint CIDR blocks, so a subnet whose first and last
	// addresses are both private lies entirely within one of them.
	if !subnet.IP.IsPrivate() || !last.IsPrivate() {
		return fmt.Errorf("IPNet is not in a private IP range")
	}

	*n = IPNet(*subnet)
	return nil
}

// String returns the subnet in CIDR notation (for example "10.0.0.0/8"),
// formatted exactly as net.IPNet.String formats it.
func (n IPNet) String() string {
	ipn := net.IPNet(n)
	return ipn.String()
}

// MarshalText implements encoding.TextMarshaler by encoding n in the CIDR
// notation returned by String. For any IPNet produced by UnmarshalText the
// output is the same canonical text that was decoded. It never returns an
// error.
func (n IPNet) MarshalText() ([]byte, error) {
	text := n.String()
	return []byte(text), nil
}

// FirstAddr returns the address immediately following the subnet's network
// address, i.e. the network address plus one.
//
// An IPv4 subnet yields an IPv4 netip.Addr even when n.IP is held in its
// 16-byte form. If n.IP is not a valid 4- or 16-byte address (for example the
// zero IPNet), the invalid zero netip.Addr is returned instead of panicking;
// callers can detect this with IsValid.
//
// No check is made that the result lies inside the subnet: for a /32 (or
// IPv6 /128) subnet the returned address is outside it.
func (n IPNet) FirstAddr() netip.Addr {
	addr, ok := netip.AddrFromSlice(n.IP)
	if !ok {
		return netip.Addr{}
	}
	// Unmap so a 16-byte IPv4-mapped form is treated as plain IPv4, matching
	// how net.IP.String renders such addresses.
	return addr.Unmap().Next()
}