Allow marshaling/unmarshaling zero value keys
This commit is contained in:
parent
2e92081e07
commit
54cebcad53
@ -2,6 +2,7 @@ package nebula
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
|
||||||
"github.com/slackhq/nebula/cert"
|
"github.com/slackhq/nebula/cert"
|
||||||
)
|
)
|
||||||
@ -37,14 +38,22 @@ func (c Certificate) Unwrap() *cert.NebulaCertificate {
|
|||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
// MarshalText implements the encoding.TextMarshaler interface.
|
||||||
func (c Certificate) MarshalText() ([]byte, error) {
|
func (c Certificate) MarshalText() ([]byte, error) {
|
||||||
|
if reflect.DeepEqual(c, Certificate{}) {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
return c.inner.MarshalToPEM()
|
return c.inner.MarshalToPEM()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||||
func (c *Certificate) UnmarshalText(b []byte) error {
|
func (c *Certificate) UnmarshalText(b []byte) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
*c = Certificate{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b)
|
nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("unmarshaling nebula certificate from PEM: %w", err)
|
||||||
}
|
}
|
||||||
c.inner = *nebCrt
|
c.inner = *nebCrt
|
||||||
return nil
|
return nil
|
||||||
|
@ -21,19 +21,31 @@ type EncryptingPublicKey struct{ inner *ecdh.PublicKey }
|
|||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
// MarshalText implements the encoding.TextMarshaler interface.
|
||||||
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
|
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
|
||||||
return encodeWithPrefix(encPubKeyPrefix, pk.inner.Bytes()), nil
|
if pk == (EncryptingPublicKey{}) {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
|
return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bytes returns the raw bytes of the EncryptingPublicKey.
|
// Bytes returns the raw bytes of the EncryptingPublicKey, or nil if it is the
|
||||||
|
// zero value.
|
||||||
func (k EncryptingPublicKey) Bytes() []byte {
|
func (k EncryptingPublicKey) Bytes() []byte {
|
||||||
|
if k == (EncryptingPublicKey{}) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return k.inner.Bytes()
|
return k.inner.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||||
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
|
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
*pk = EncryptingPublicKey{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
b, err := decodeWithPrefix(encPubKeyPrefix, b)
|
b, err := decodeWithPrefix(encPubKeyPrefix, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshaling: %w", err)
|
return fmt.Errorf("unmarshaling encrypting public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
|
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
|
||||||
@ -48,7 +60,7 @@ func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
|
|||||||
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
|
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
|
||||||
b, _, err := cert.UnmarshalX25519PublicKey(b)
|
b, _, err := cert.UnmarshalX25519PublicKey(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshaling: %w", err)
|
return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
|
if pk.inner, err = x25519.NewPublicKey(b); err != nil {
|
||||||
@ -86,19 +98,31 @@ func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey {
|
|||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
// MarshalText implements the encoding.TextMarshaler interface.
|
||||||
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
|
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
|
||||||
return encodeWithPrefix(encPrivKeyPrefix, k.inner.Bytes()), nil
|
if k == (EncryptingPrivateKey{}) {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
|
return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bytes returns the raw bytes of the EncryptingPrivateKey.
|
// Bytes returns the raw bytes of the EncryptingPrivateKey, or nil if it is the
|
||||||
|
// zero value.
|
||||||
func (k EncryptingPrivateKey) Bytes() []byte {
|
func (k EncryptingPrivateKey) Bytes() []byte {
|
||||||
|
if k == (EncryptingPrivateKey{}) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return k.inner.Bytes()
|
return k.inner.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||||
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
|
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
*k = EncryptingPrivateKey{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
b, err := decodeWithPrefix(encPrivKeyPrefix, b)
|
b, err := decodeWithPrefix(encPrivKeyPrefix, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshaling: %w", err)
|
return fmt.Errorf("unmarshaling encrypting private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if k.inner, err = x25519.NewPrivateKey(b); err != nil {
|
if k.inner, err = x25519.NewPrivateKey(b); err != nil {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package nebula
|
package nebula
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
@ -47,13 +48,23 @@ func Sign[T any](v T, k SigningPrivateKey) (Signed[T], error) {
|
|||||||
return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)})
|
return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var jsonNull = []byte("null")
|
||||||
|
|
||||||
// MarshalJSON implements the json.Marshaler interface.
|
// MarshalJSON implements the json.Marshaler interface.
|
||||||
func (s Signed[T]) MarshalJSON() ([]byte, error) {
|
func (s Signed[T]) MarshalJSON() ([]byte, error) {
|
||||||
|
if s == nil {
|
||||||
|
return jsonNull, nil
|
||||||
|
}
|
||||||
return []byte(s), nil
|
return []byte(s), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalJSON implements the json.Unmarshaler interface.
|
// UnmarshalJSON implements the json.Unmarshaler interface.
|
||||||
func (s *Signed[T]) UnmarshalJSON(b []byte) error {
|
func (s *Signed[T]) UnmarshalJSON(b []byte) error {
|
||||||
|
if bytes.Equal(b, jsonNull) {
|
||||||
|
*s = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
*s = b
|
*s = b
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -34,7 +34,7 @@ func TestSigned(t *testing.T) {
|
|||||||
|
|
||||||
_, err = signedB.Unwrap(hostPubCredsB.SigningKey)
|
_, err = signedB.Unwrap(hostPubCredsB.SigningKey)
|
||||||
if !errors.Is(err, ErrInvalidSignature) {
|
if !errors.Is(err, ErrInvalidSignature) {
|
||||||
t.Fatalf("expected ErrInvalidSignature but got %v", err)
|
t.Fatalf("expected ErrInvalidSignature but got: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
b, err := signedB.Unwrap(hostPubCredsA.SigningKey)
|
b, err := signedB.Unwrap(hostPubCredsA.SigningKey)
|
||||||
|
@ -17,14 +17,22 @@ type SigningPrivateKey ed25519.PrivateKey
|
|||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
// MarshalText implements the encoding.TextMarshaler interface.
|
||||||
func (k SigningPrivateKey) MarshalText() ([]byte, error) {
|
func (k SigningPrivateKey) MarshalText() ([]byte, error) {
|
||||||
|
if k == nil {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
return encodeWithPrefix(sigPrivKeyPrefix, k), nil
|
return encodeWithPrefix(sigPrivKeyPrefix, k), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||||
func (k *SigningPrivateKey) UnmarshalText(b []byte) error {
|
func (k *SigningPrivateKey) UnmarshalText(b []byte) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
*k = SigningPrivateKey{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
b, err := decodeWithPrefix(sigPrivKeyPrefix, b)
|
b, err := decodeWithPrefix(sigPrivKeyPrefix, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshaling: %w", err)
|
return fmt.Errorf("unmarshaling signing private key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
*k = SigningPrivateKey(b)
|
*k = SigningPrivateKey(b)
|
||||||
@ -45,14 +53,22 @@ type SigningPublicKey ed25519.PublicKey
|
|||||||
|
|
||||||
// MarshalText implements the encoding.TextMarshaler interface.
|
// MarshalText implements the encoding.TextMarshaler interface.
|
||||||
func (pk SigningPublicKey) MarshalText() ([]byte, error) {
|
func (pk SigningPublicKey) MarshalText() ([]byte, error) {
|
||||||
|
if pk == nil {
|
||||||
|
return []byte(""), nil
|
||||||
|
}
|
||||||
return encodeWithPrefix(sigPubKeyPrefix, pk), nil
|
return encodeWithPrefix(sigPubKeyPrefix, pk), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
// UnmarshalText implements the encoding.TextUnmarshaler interface.
|
||||||
func (pk *SigningPublicKey) UnmarshalText(b []byte) error {
|
func (pk *SigningPublicKey) UnmarshalText(b []byte) error {
|
||||||
|
if len(b) == 0 {
|
||||||
|
*pk = SigningPublicKey{}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
b, err := decodeWithPrefix(sigPubKeyPrefix, b)
|
b, err := decodeWithPrefix(sigPubKeyPrefix, b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unmarshaling: %w", err)
|
return fmt.Errorf("unmarshaling signing public key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
*pk = SigningPublicKey(b)
|
*pk = SigningPublicKey(b)
|
||||||
|
Loading…
Reference in New Issue
Block a user