Allow marshaling/unmarshaling zero value keys

This commit is contained in:
Brian Picciano 2024-12-07 20:36:29 +01:00
parent 2e92081e07
commit 54cebcad53
5 changed files with 71 additions and 11 deletions

View File

@ -2,6 +2,7 @@ package nebula
import (
"fmt"
"reflect"
"github.com/slackhq/nebula/cert"
)
@ -37,14 +38,22 @@ func (c Certificate) Unwrap() *cert.NebulaCertificate {
// MarshalText implements the encoding.TextMarshaler interface.
func (c Certificate) MarshalText() ([]byte, error) {
if reflect.DeepEqual(c, Certificate{}) {
return []byte(""), nil
}
return c.inner.MarshalToPEM()
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (c *Certificate) UnmarshalText(b []byte) error {
if len(b) == 0 {
*c = Certificate{}
return nil
}
nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b)
if err != nil {
return err
return fmt.Errorf("unmarshaling nebula certificate from PEM: %w", err)
}
c.inner = *nebCrt
return nil

View File

@ -21,19 +21,31 @@ type EncryptingPublicKey struct{ inner *ecdh.PublicKey }
// MarshalText implements the encoding.TextMarshaler interface.
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPubKeyPrefix, pk.inner.Bytes()), nil
if pk == (EncryptingPublicKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil
}
// Bytes returns the raw bytes of the EncryptingPublicKey.
// Bytes returns the raw bytes of the EncryptingPublicKey, or nil if it is the
// zero value.
func (k EncryptingPublicKey) Bytes() []byte {
if k == (EncryptingPublicKey{}) {
return nil
}
return k.inner.Bytes()
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = EncryptingPublicKey{}
return nil
}
b, err := decodeWithPrefix(encPubKeyPrefix, b)
if err != nil {
return fmt.Errorf("unmarshaling: %w", err)
return fmt.Errorf("unmarshaling encrypting public key: %w", err)
}
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -48,7 +60,7 @@ func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
b, _, err := cert.UnmarshalX25519PublicKey(b)
if err != nil {
return fmt.Errorf("unmarshaling: %w", err)
return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err)
}
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -86,19 +98,31 @@ func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey {
// MarshalText implements the encoding.TextMarshaler interface.
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPrivKeyPrefix, k.inner.Bytes()), nil
if k == (EncryptingPrivateKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil
}
// Bytes returns the raw bytes of the EncryptingPrivateKey.
// Bytes returns the raw bytes of the EncryptingPrivateKey, or nil if it is the
// zero value.
func (k EncryptingPrivateKey) Bytes() []byte {
if k == (EncryptingPrivateKey{}) {
return nil
}
return k.inner.Bytes()
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = EncryptingPrivateKey{}
return nil
}
b, err := decodeWithPrefix(encPrivKeyPrefix, b)
if err != nil {
return fmt.Errorf("unmarshaling: %w", err)
return fmt.Errorf("unmarshaling encrypting private key: %w", err)
}
if k.inner, err = x25519.NewPrivateKey(b); err != nil {

View File

@ -1,6 +1,7 @@
package nebula
import (
"bytes"
"crypto"
"crypto/ed25519"
"crypto/rand"
@ -47,13 +48,23 @@ func Sign[T any](v T, k SigningPrivateKey) (Signed[T], error) {
return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)})
}
var jsonNull = []byte("null")
// MarshalJSON implements the json.Marshaler interface.
func (s Signed[T]) MarshalJSON() ([]byte, error) {
if s == nil {
return jsonNull, nil
}
return []byte(s), nil
}
// UnmarshalJSON implements the json.Unmarshaler interface.
func (s *Signed[T]) UnmarshalJSON(b []byte) error {
if bytes.Equal(b, jsonNull) {
*s = nil
return nil
}
*s = b
return nil
}

View File

@ -34,7 +34,7 @@ func TestSigned(t *testing.T) {
_, err = signedB.Unwrap(hostPubCredsB.SigningKey)
if !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("expected ErrInvalidSignature but got %v", err)
t.Fatalf("expected ErrInvalidSignature but got: %v", err)
}
b, err := signedB.Unwrap(hostPubCredsA.SigningKey)

View File

@ -17,14 +17,22 @@ type SigningPrivateKey ed25519.PrivateKey
// MarshalText implements the encoding.TextMarshaler interface.
func (k SigningPrivateKey) MarshalText() ([]byte, error) {
if k == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPrivKeyPrefix, k), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *SigningPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = SigningPrivateKey{}
return nil
}
b, err := decodeWithPrefix(sigPrivKeyPrefix, b)
if err != nil {
return fmt.Errorf("unmarshaling: %w", err)
return fmt.Errorf("unmarshaling signing private key: %w", err)
}
*k = SigningPrivateKey(b)
@ -45,14 +53,22 @@ type SigningPublicKey ed25519.PublicKey
// MarshalText implements the encoding.TextMarshaler interface.
func (pk SigningPublicKey) MarshalText() ([]byte, error) {
if pk == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPubKeyPrefix, pk), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *SigningPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = SigningPublicKey{}
return nil
}
b, err := decodeWithPrefix(sigPubKeyPrefix, b)
if err != nil {
return fmt.Errorf("unmarshaling: %w", err)
return fmt.Errorf("unmarshaling signing public key: %w", err)
}
*pk = SigningPublicKey(b)