From 54cebcad53896e2b5438647c64e5ed1b2b176b1b Mon Sep 17 00:00:00 2001 From: Brian Picciano Date: Sat, 7 Dec 2024 20:36:29 +0100 Subject: [PATCH] Allow marshaling/unmarshaling zero value keys --- go/nebula/certificate.go | 11 ++++++++++- go/nebula/encrypting_key.go | 38 ++++++++++++++++++++++++++++++------- go/nebula/signed.go | 11 +++++++++++ go/nebula/signed_test.go | 2 +- go/nebula/signing_key.go | 20 +++++++++++++++++-- 5 files changed, 71 insertions(+), 11 deletions(-) diff --git a/go/nebula/certificate.go b/go/nebula/certificate.go index a5da459..45ef75d 100644 --- a/go/nebula/certificate.go +++ b/go/nebula/certificate.go @@ -2,6 +2,7 @@ package nebula import ( "fmt" + "reflect" "github.com/slackhq/nebula/cert" ) @@ -37,14 +38,22 @@ func (c Certificate) Unwrap() *cert.NebulaCertificate { // MarshalText implements the encoding.TextMarshaler interface. func (c Certificate) MarshalText() ([]byte, error) { + if reflect.DeepEqual(c, Certificate{}) { + return []byte(""), nil + } return c.inner.MarshalToPEM() } // UnmarshalText implements the encoding.TextUnmarshaler interface. func (c *Certificate) UnmarshalText(b []byte) error { + if len(b) == 0 { + *c = Certificate{} + return nil + } + nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b) if err != nil { - return err + return fmt.Errorf("unmarshaling nebula certificate from PEM: %w", err) } c.inner = *nebCrt return nil diff --git a/go/nebula/encrypting_key.go b/go/nebula/encrypting_key.go index 7e74912..c3da79c 100644 --- a/go/nebula/encrypting_key.go +++ b/go/nebula/encrypting_key.go @@ -21,19 +21,31 @@ type EncryptingPublicKey struct{ inner *ecdh.PublicKey } // MarshalText implements the encoding.TextMarshaler interface. func (pk EncryptingPublicKey) MarshalText() ([]byte, error) { - return encodeWithPrefix(encPubKeyPrefix, pk.inner.Bytes()), nil + if pk == (EncryptingPublicKey{}) { + return []byte(""), nil + } + return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil } -// Bytes returns the raw bytes of the EncryptingPublicKey. +// Bytes returns the raw bytes of the EncryptingPublicKey, or nil if it is the +// zero value. func (k EncryptingPublicKey) Bytes() []byte { + if k == (EncryptingPublicKey{}) { + return nil + } return k.inner.Bytes() } // UnmarshalText implements the encoding.TextUnmarshaler interface. func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error { + if len(b) == 0 { + *pk = EncryptingPublicKey{} + return nil + } + b, err := decodeWithPrefix(encPubKeyPrefix, b) if err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling encrypting public key: %w", err) } if pk.inner, err = x25519.NewPublicKey(b); err != nil { @@ -48,7 +60,7 @@ func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error { func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error { b, _, err := cert.UnmarshalX25519PublicKey(b) if err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err) } if pk.inner, err = x25519.NewPublicKey(b); err != nil { @@ -86,19 +98,31 @@ func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey { // MarshalText implements the encoding.TextMarshaler interface. func (k EncryptingPrivateKey) MarshalText() ([]byte, error) { - return encodeWithPrefix(encPrivKeyPrefix, k.inner.Bytes()), nil + if k == (EncryptingPrivateKey{}) { + return []byte(""), nil + } + return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil } -// Bytes returns the raw bytes of the EncryptingPrivateKey. +// Bytes returns the raw bytes of the EncryptingPrivateKey, or nil if it is the +// zero value. func (k EncryptingPrivateKey) Bytes() []byte { + if k == (EncryptingPrivateKey{}) { + return nil + } return k.inner.Bytes() } // UnmarshalText implements the encoding.TextUnmarshaler interface. func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error { + if len(b) == 0 { + *k = EncryptingPrivateKey{} + return nil + } + b, err := decodeWithPrefix(encPrivKeyPrefix, b) if err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling encrypting private key: %w", err) } if k.inner, err = x25519.NewPrivateKey(b); err != nil { diff --git a/go/nebula/signed.go b/go/nebula/signed.go index 0170740..044551b 100644 --- a/go/nebula/signed.go +++ b/go/nebula/signed.go @@ -1,6 +1,7 @@ package nebula import ( + "bytes" "crypto" "crypto/ed25519" "crypto/rand" @@ -47,13 +48,23 @@ func Sign[T any](v T, k SigningPrivateKey) (Signed[T], error) { return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)}) } +var jsonNull = []byte("null") + // MarshalJSON implements the json.Marshaler interface. func (s Signed[T]) MarshalJSON() ([]byte, error) { + if s == nil { + return jsonNull, nil + } return []byte(s), nil } // UnmarshalJSON implements the json.Unmarshaler interface. func (s *Signed[T]) UnmarshalJSON(b []byte) error { + if bytes.Equal(b, jsonNull) { + *s = nil + return nil + } + *s = b return nil } diff --git a/go/nebula/signed_test.go b/go/nebula/signed_test.go index afb6ec2..6692b56 100644 --- a/go/nebula/signed_test.go +++ b/go/nebula/signed_test.go @@ -34,7 +34,7 @@ func TestSigned(t *testing.T) { _, err = signedB.Unwrap(hostPubCredsB.SigningKey) if !errors.Is(err, ErrInvalidSignature) { - t.Fatalf("expected ErrInvalidSignature but got %v", err) + t.Fatalf("expected ErrInvalidSignature but got: %v", err) } b, err := signedB.Unwrap(hostPubCredsA.SigningKey) diff --git a/go/nebula/signing_key.go b/go/nebula/signing_key.go index b1eed8d..b334434 100644 --- a/go/nebula/signing_key.go +++ b/go/nebula/signing_key.go @@ -17,14 +17,22 @@ type SigningPrivateKey ed25519.PrivateKey // MarshalText implements the encoding.TextMarshaler interface. func (k SigningPrivateKey) MarshalText() ([]byte, error) { + if k == nil { + return []byte(""), nil + } return encodeWithPrefix(sigPrivKeyPrefix, k), nil } // UnmarshalText implements the encoding.TextUnmarshaler interface. func (k *SigningPrivateKey) UnmarshalText(b []byte) error { + if len(b) == 0 { + *k = SigningPrivateKey{} + return nil + } + b, err := decodeWithPrefix(sigPrivKeyPrefix, b) if err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling signing private key: %w", err) } *k = SigningPrivateKey(b) @@ -45,14 +53,22 @@ type SigningPublicKey ed25519.PublicKey // MarshalText implements the encoding.TextMarshaler interface. func (pk SigningPublicKey) MarshalText() ([]byte, error) { + if pk == nil { + return []byte(""), nil + } return encodeWithPrefix(sigPubKeyPrefix, pk), nil } // UnmarshalText implements the encoding.TextUnmarshaler interface. func (pk *SigningPublicKey) UnmarshalText(b []byte) error { + if len(b) == 0 { + *pk = SigningPublicKey{} + return nil + } + b, err := decodeWithPrefix(sigPubKeyPrefix, b) if err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling signing public key: %w", err) } *pk = SigningPublicKey(b)