Allow marshaling/unmarshaling zero value keys

This commit is contained in:
Brian Picciano 2024-12-07 20:36:29 +01:00
parent 2e92081e07
commit 54cebcad53
5 changed files with 71 additions and 11 deletions

View File

@ -2,6 +2,7 @@ package nebula
import ( import (
"fmt" "fmt"
"reflect"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@ -37,14 +38,22 @@ func (c Certificate) Unwrap() *cert.NebulaCertificate {
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (c Certificate) MarshalText() ([]byte, error) { func (c Certificate) MarshalText() ([]byte, error) {
if reflect.DeepEqual(c, Certificate{}) {
return []byte(""), nil
}
return c.inner.MarshalToPEM() return c.inner.MarshalToPEM()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (c *Certificate) UnmarshalText(b []byte) error { func (c *Certificate) UnmarshalText(b []byte) error {
if len(b) == 0 {
*c = Certificate{}
return nil
}
nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b) nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b)
if err != nil { if err != nil {
return err return fmt.Errorf("unmarshaling nebula certificate from PEM: %w", err)
} }
c.inner = *nebCrt c.inner = *nebCrt
return nil return nil

View File

@ -21,19 +21,31 @@ type EncryptingPublicKey struct{ inner *ecdh.PublicKey }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) { func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPubKeyPrefix, pk.inner.Bytes()), nil if pk == (EncryptingPublicKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil
} }
// Bytes returns the raw bytes of the EncryptingPublicKey. // Bytes returns the raw bytes of the EncryptingPublicKey, or nil if it is the
// zero value.
func (k EncryptingPublicKey) Bytes() []byte { func (k EncryptingPublicKey) Bytes() []byte {
if k == (EncryptingPublicKey{}) {
return nil
}
return k.inner.Bytes() return k.inner.Bytes()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error { func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = EncryptingPublicKey{}
return nil
}
b, err := decodeWithPrefix(encPubKeyPrefix, b) b, err := decodeWithPrefix(encPubKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling encrypting public key: %w", err)
} }
if pk.inner, err = x25519.NewPublicKey(b); err != nil { if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -48,7 +60,7 @@ func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error { func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
b, _, err := cert.UnmarshalX25519PublicKey(b) b, _, err := cert.UnmarshalX25519PublicKey(b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err)
} }
if pk.inner, err = x25519.NewPublicKey(b); err != nil { if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -86,19 +98,31 @@ func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey {
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) { func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPrivKeyPrefix, k.inner.Bytes()), nil if k == (EncryptingPrivateKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil
} }
// Bytes returns the raw bytes of the EncryptingPrivateKey. // Bytes returns the raw bytes of the EncryptingPrivateKey, or nil if it is the
// zero value.
func (k EncryptingPrivateKey) Bytes() []byte { func (k EncryptingPrivateKey) Bytes() []byte {
if k == (EncryptingPrivateKey{}) {
return nil
}
return k.inner.Bytes() return k.inner.Bytes()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error { func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = EncryptingPrivateKey{}
return nil
}
b, err := decodeWithPrefix(encPrivKeyPrefix, b) b, err := decodeWithPrefix(encPrivKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling encrypting private key: %w", err)
} }
if k.inner, err = x25519.NewPrivateKey(b); err != nil { if k.inner, err = x25519.NewPrivateKey(b); err != nil {

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"bytes"
"crypto" "crypto"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
@ -47,13 +48,23 @@ func Sign[T any](v T, k SigningPrivateKey) (Signed[T], error) {
return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)}) return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)})
} }
var jsonNull = []byte("null")
// MarshalJSON implements the json.Marshaler interface. // MarshalJSON implements the json.Marshaler interface.
func (s Signed[T]) MarshalJSON() ([]byte, error) { func (s Signed[T]) MarshalJSON() ([]byte, error) {
if s == nil {
return jsonNull, nil
}
return []byte(s), nil return []byte(s), nil
} }
// UnmarshalJSON implements the json.Unmarshaler interface. // UnmarshalJSON implements the json.Unmarshaler interface.
func (s *Signed[T]) UnmarshalJSON(b []byte) error { func (s *Signed[T]) UnmarshalJSON(b []byte) error {
if bytes.Equal(b, jsonNull) {
*s = nil
return nil
}
*s = b *s = b
return nil return nil
} }

View File

@ -34,7 +34,7 @@ func TestSigned(t *testing.T) {
_, err = signedB.Unwrap(hostPubCredsB.SigningKey) _, err = signedB.Unwrap(hostPubCredsB.SigningKey)
if !errors.Is(err, ErrInvalidSignature) { if !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("expected ErrInvalidSignature but got %v", err) t.Fatalf("expected ErrInvalidSignature but got: %v", err)
} }
b, err := signedB.Unwrap(hostPubCredsA.SigningKey) b, err := signedB.Unwrap(hostPubCredsA.SigningKey)

View File

@ -17,14 +17,22 @@ type SigningPrivateKey ed25519.PrivateKey
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (k SigningPrivateKey) MarshalText() ([]byte, error) { func (k SigningPrivateKey) MarshalText() ([]byte, error) {
if k == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPrivKeyPrefix, k), nil return encodeWithPrefix(sigPrivKeyPrefix, k), nil
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *SigningPrivateKey) UnmarshalText(b []byte) error { func (k *SigningPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = SigningPrivateKey{}
return nil
}
b, err := decodeWithPrefix(sigPrivKeyPrefix, b) b, err := decodeWithPrefix(sigPrivKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling signing private key: %w", err)
} }
*k = SigningPrivateKey(b) *k = SigningPrivateKey(b)
@ -45,14 +53,22 @@ type SigningPublicKey ed25519.PublicKey
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (pk SigningPublicKey) MarshalText() ([]byte, error) { func (pk SigningPublicKey) MarshalText() ([]byte, error) {
if pk == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPubKeyPrefix, pk), nil return encodeWithPrefix(sigPubKeyPrefix, pk), nil
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *SigningPublicKey) UnmarshalText(b []byte) error { func (pk *SigningPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = SigningPublicKey{}
return nil
}
b, err := decodeWithPrefix(sigPubKeyPrefix, b) b, err := decodeWithPrefix(sigPubKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling signing public key: %w", err)
} }
*pk = SigningPublicKey(b) *pk = SigningPublicKey(b)