Compare commits

...

4 Commits

25 changed files with 738 additions and 359 deletions

3
.gitignore vendored
View File

@ -1,4 +1 @@
*-bin
*admin.yml*
*bootstrap.yml*
result result

View File

@ -3,7 +3,7 @@
package bootstrap package bootstrap
import ( import (
"crypto/sha512" "cmp"
"encoding/json" "encoding/json"
"fmt" "fmt"
"isle/nebula" "isle/nebula"
@ -11,7 +11,7 @@ import (
"maps" "maps"
"net/netip" "net/netip"
"path/filepath" "path/filepath"
"sort" "slices"
"strings" "strings"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mctx"
@ -110,33 +110,26 @@ func New(
) ( ) (
Bootstrap, error, Bootstrap, error,
) { ) {
hostPubCreds, hostPrivCreds, err := nebula.NewHostCredentials( host, hostPrivCreds, err := NewHost(caCreds, name, ip)
caCreds, name, ip,
)
if err != nil { if err != nil {
return Bootstrap{}, fmt.Errorf("generating host credentials: %w", err) return Bootstrap{}, fmt.Errorf("creating host: %w", err)
} }
assigned := HostAssigned{ signedAssigned, err := nebula.Sign(
Name: name, host.HostAssigned, caCreds.SigningPrivateKey,
PublicCredentials: hostPubCreds, )
}
signedAssigned, err := nebula.Sign(assigned, caCreds.SigningPrivateKey)
if err != nil { if err != nil {
return Bootstrap{}, fmt.Errorf("signing assigned fields: %w", err) return Bootstrap{}, fmt.Errorf("signing assigned fields: %w", err)
} }
existingHosts = maps.Clone(existingHosts) existingHosts = maps.Clone(existingHosts)
existingHosts[name] = Host{ existingHosts[name] = host
HostAssigned: assigned,
}
return Bootstrap{ return Bootstrap{
NetworkCreationParams: adminCreationParams, NetworkCreationParams: adminCreationParams,
CAPublicCredentials: caCreds.Public, CAPublicCredentials: caCreds.Public,
PrivateCredentials: hostPrivCreds, PrivateCredentials: hostPrivCreds,
HostAssigned: assigned, HostAssigned: host.HostAssigned,
SignedHostAssigned: signedAssigned, SignedHostAssigned: signedAssigned,
Hosts: existingHosts, Hosts: existingHosts,
}, nil }, nil
@ -150,15 +143,19 @@ func (b *Bootstrap) UnmarshalJSON(data []byte) error {
err := json.Unmarshal(data, (*inner)(b)) err := json.Unmarshal(data, (*inner)(b))
if err != nil { if err != nil {
return err return fmt.Errorf("json unmarshaling: %w", err)
} }
// Generally this will be filled, but during unit tests we sometimes leave
// it empty for convenience.
if b.SignedHostAssigned != nil {
b.HostAssigned, err = b.SignedHostAssigned.Unwrap( b.HostAssigned, err = b.SignedHostAssigned.Unwrap(
b.CAPublicCredentials.SigningKey, b.CAPublicCredentials.SigningKey,
) )
if err != nil { if err != nil {
return fmt.Errorf("unwrapping HostAssigned: %w", err) return fmt.Errorf("unwrapping HostAssigned: %w", err)
} }
}
return nil return nil
} }
@ -174,22 +171,16 @@ func (b Bootstrap) ThisHost() Host {
return host return host
} }
// Hash returns a deterministic hash of the given hosts map. // HostsOrdered returns the Hosts as a slice in a deterministic order.
func HostsHash(hostsMap map[nebula.HostName]Host) ([]byte, error) { func (b Bootstrap) HostsOrdered() []Host {
hosts := make([]Host, 0, len(b.Hosts))
hosts := make([]Host, 0, len(hostsMap)) for _, host := range b.Hosts {
for _, host := range hostsMap {
hosts = append(hosts, host) hosts = append(hosts, host)
} }
sort.Slice(hosts, func(i, j int) bool { slices.SortFunc(hosts, func(a, b Host) int {
return hosts[i].Name < hosts[j].Name return cmp.Compare(a.Name, b.Name)
}) })
h := sha512.New() return hosts
if err := json.NewEncoder(h).Encode(hosts); err != nil {
return nil, err
}
return h.Sum(nil), nil
} }

View File

@ -73,6 +73,31 @@ type Host struct {
HostConfigured HostConfigured
} }
// NewHost creates a Host instance using the given assigned fields, along with
// the HostPrivateCredentials which its PublicCredentials field.
func NewHost(
caCreds nebula.CACredentials, name nebula.HostName, ip netip.Addr,
) (
host Host, hostPrivCreds nebula.HostPrivateCredentials, err error,
) {
hostPubCreds, hostPrivCreds, err := nebula.NewHostCredentials(
caCreds, name, ip,
)
if err != nil {
err = fmt.Errorf("generating host credentials: %w", err)
return
}
host = Host{
HostAssigned: HostAssigned{
Name: name,
PublicCredentials: hostPubCreds,
},
}
return
}
// IP returns the IP address encoded in the Host's nebula certificate, or panics // IP returns the IP address encoded in the Host's nebula certificate, or panics
// if there is an error. // if there is an error.
// //

View File

@ -1,9 +0,0 @@
package main
import (
"isle/bootstrap"
)
func (ctx subCmdCtx) getHosts() ([]bootstrap.Host, error) {
return ctx.getDaemonRPC().GetHosts(ctx)
}

View File

@ -4,10 +4,8 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"isle/bootstrap"
"isle/daemon/network" "isle/daemon/network"
"os" "os"
"sort"
) )
var subCmdHostCreate = subCmd{ var subCmdHostCreate = subCmd{
@ -65,35 +63,48 @@ var subCmdHostList = subCmd{
return nil, fmt.Errorf("parsing flags: %w", err) return nil, fmt.Errorf("parsing flags: %w", err)
} }
hostsRes, err := ctx.getHosts() currBoostrap, err := ctx.getDaemonRPC().GetBootstrap(ctx)
if err != nil { if err != nil {
return nil, fmt.Errorf("calling GetHosts: %w", err) return nil, fmt.Errorf("calling GetBootstrap: %w", err)
} }
type host struct { hosts := currBoostrap.HostsOrdered()
Name string
type storageInstanceView struct {
ID string `yaml:"id"`
RPCPort int `yaml:"rpc_port"`
S3APIPort int `yaml:"s3_api_port"`
}
type hostView struct {
Name string `yaml:"name"`
VPN struct { VPN struct {
IP string IP string `yaml:"ip"`
PublicAddr string `yaml:"public_addr,omitempty"`
} }
Storage bootstrap.GarageHost `json:",omitempty"` Storage struct {
Instances []storageInstanceView `yaml:"instances"`
} `yaml:",omitempty"`
} }
hosts := make([]host, 0, len(hostsRes)) hostViews := make([]hostView, len(hosts))
for _, h := range hostsRes { for i, host := range hosts {
storageInstanceViews := make([]storageInstanceView, len(host.Garage.Instances))
host := host{ for i := range host.Garage.Instances {
Name: string(h.Name), storageInstanceViews[i] = storageInstanceView(host.Garage.Instances[i])
Storage: h.Garage,
} }
host.VPN.IP = h.IP().String() hostView := hostView{
Name: string(host.Name),
hosts = append(hosts, host)
} }
sort.Slice(hosts, func(i, j int) bool { return hosts[i].Name < hosts[j].Name }) hostView.VPN.IP = host.IP().String()
hostView.VPN.PublicAddr = host.Nebula.PublicAddr
hostView.Storage.Instances = storageInstanceViews
hostViews[i] = hostView
}
return hosts, nil return hostViews, nil
}), }),
} }

View File

@ -0,0 +1,134 @@
package main
import (
"context"
"isle/bootstrap"
"isle/nebula"
"isle/toolkit"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
func TestHostList(t *testing.T) {
t.Parallel()
var ipNet nebula.IPNet
require.NoError(t, ipNet.UnmarshalText([]byte("172.16.0.0/16")))
caCreds, err := nebula.NewCACredentials("test.com", ipNet)
require.NoError(t, err)
type host struct {
name string
ip string
publicAddr string
storageInstances []bootstrap.GarageHostInstance
}
tests := []struct {
name string
hosts []host
want any
}{
{
name: "no hosts",
want: []any{},
},
{
name: "single",
hosts: []host{
{
name: "a",
ip: "172.16.0.1",
},
},
want: []map[string]any{
{
"name": "a",
"vpn": map[string]any{
"ip": "172.16.0.1",
},
},
},
},
{
name: "multiple",
hosts: []host{
{
name: "a",
ip: "172.16.0.1",
},
{
name: "b",
ip: "172.16.0.2",
publicAddr: "1.1.1.1:80",
storageInstances: []bootstrap.GarageHostInstance{{
ID: "storageInstanceID",
RPCPort: 9000,
S3APIPort: 9001,
}},
},
},
want: []map[string]any{
{
"name": "a",
"vpn": map[string]any{
"ip": "172.16.0.1",
},
},
{
"name": "b",
"vpn": map[string]any{
"ip": "172.16.0.2",
"public_addr": "1.1.1.1:80",
},
"storage": map[string]any{
"instances": []any{
map[string]any{
"id": "storageInstanceID",
"rpc_port": 9000,
"s3_api_port": 9001,
},
},
},
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
hosts = map[nebula.HostName]bootstrap.Host{}
)
for _, testHost := range test.hosts {
var (
hostName nebula.HostName
ip = netip.MustParseAddr(testHost.ip)
)
require.NoError(
t, hostName.UnmarshalText([]byte(testHost.name)),
)
host, _, err := bootstrap.NewHost(caCreds, hostName, ip)
require.NoError(t, err)
host.Nebula.PublicAddr = testHost.publicAddr
host.Garage.Instances = testHost.storageInstances
hosts[hostName] = host
}
h.daemonRPC.
On("GetBootstrap", toolkit.MockArg[context.Context]()).
Return(bootstrap.Bootstrap{Hosts: hosts}, nil).
Once()
h.runAssertStdout(t, test.want, "host", "list")
})
}
}

View File

@ -4,7 +4,9 @@ import (
"bytes" "bytes"
"context" "context"
"isle/daemon" "isle/daemon"
"isle/daemon/jsonrpc2"
"isle/toolkit" "isle/toolkit"
"net/http/httptest"
"reflect" "reflect"
"testing" "testing"
@ -16,8 +18,9 @@ import (
type runHarness struct { type runHarness struct {
ctx context.Context ctx context.Context
logger *mlog.Logger logger *mlog.Logger
daemonRPC *daemon.MockRPC
stdout *bytes.Buffer stdout *bytes.Buffer
daemonRPC *daemon.MockRPC
daemonRPCServer *httptest.Server
} }
func newRunHarness(t *testing.T) *runHarness { func newRunHarness(t *testing.T) *runHarness {
@ -26,17 +29,33 @@ func newRunHarness(t *testing.T) *runHarness {
var ( var (
ctx = context.Background() ctx = context.Background()
logger = toolkit.NewTestLogger(t) logger = toolkit.NewTestLogger(t)
daemonRPC = daemon.NewMockRPC(t)
stdout = new(bytes.Buffer) stdout = new(bytes.Buffer)
daemonRPC = daemon.NewMockRPC(t)
daemonRPCHandler = jsonrpc2.NewHTTPHandler(daemon.NewRPCHandler(
logger.WithNamespace("rpc"), daemonRPC,
))
daemonRPCServer = httptest.NewServer(daemonRPCHandler)
) )
return &runHarness{ctx, logger, daemonRPC, stdout} t.Cleanup(daemonRPCServer.Close)
return &runHarness{ctx, logger, stdout, daemonRPC, daemonRPCServer}
} }
func (h *runHarness) run(_ *testing.T, args ...string) error { func (h *runHarness) run(t *testing.T, args ...string) error {
httpClient := toolkit.NewHTTPClient(h.logger.WithNamespace("http"))
t.Cleanup(func() {
assert.NoError(t, httpClient.Close())
})
daemonRPCClient := daemon.RPCFromClient(
jsonrpc2.NewHTTPClient(httpClient, h.daemonRPCServer.URL),
)
return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{ return doRootCmd(h.ctx, h.logger, &subCmdCtxOpts{
args: args, args: args,
daemonRPC: h.daemonRPC, daemonRPC: daemonRPCClient,
stdout: h.stdout, stdout: h.stdout,
}) })
} }

View File

@ -63,73 +63,12 @@ var subCmdNebulaCreateCert = subCmd{
}, },
} }
var subCmdNebulaShow = subCmd{
name: "show",
descr: "Writes nebula network information to stdout in JSON format",
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
ctx, err := ctx.withParsedFlags()
if err != nil {
return nil, fmt.Errorf("parsing flags: %w", err)
}
hosts, err := ctx.getHosts()
if err != nil {
return nil, fmt.Errorf("getting hosts: %w", err)
}
caPublicCreds, err := ctx.getDaemonRPC().GetNebulaCAPublicCredentials(ctx)
if err != nil {
return nil, fmt.Errorf("calling GetNebulaCAPublicCredentials: %w", err)
}
caCert := caPublicCreds.Cert
caCertDetails := caCert.Unwrap().Details
if len(caCertDetails.Subnets) != 1 {
return nil, fmt.Errorf(
"malformed ca.crt, contains unexpected subnets %#v",
caCertDetails.Subnets,
)
}
subnet := caCertDetails.Subnets[0]
type outLighthouse struct {
PublicAddr string
IP string
}
out := struct {
CACert nebula.Certificate
SubnetCIDR string
Lighthouses []outLighthouse
}{
CACert: caCert,
SubnetCIDR: subnet.String(),
}
for _, h := range hosts {
if h.Nebula.PublicAddr == "" {
continue
}
out.Lighthouses = append(out.Lighthouses, outLighthouse{
PublicAddr: h.Nebula.PublicAddr,
IP: h.IP().String(),
})
}
return out, nil
}),
}
var subCmdNebula = subCmd{ var subCmdNebula = subCmd{
name: "nebula", name: "nebula",
descr: "Sub-commands related to the nebula VPN", descr: "Sub-commands related to the nebula VPN",
do: func(ctx subCmdCtx) error { do: func(ctx subCmdCtx) error {
return ctx.doSubCmd( return ctx.doSubCmd(
subCmdNebulaCreateCert, subCmdNebulaCreateCert,
subCmdNebulaShow,
) )
}, },
} }

View File

@ -1,10 +1,15 @@
package main package main
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"isle/bootstrap"
"isle/daemon"
"isle/daemon/network" "isle/daemon/network"
"isle/jsonutil" "isle/jsonutil"
"isle/nebula"
"slices"
) )
var subCmdNetworkCreate = subCmd{ var subCmdNetworkCreate = subCmd{
@ -102,7 +107,75 @@ var subCmdNetworkList = subCmd{
return nil, fmt.Errorf("parsing flags: %w", err) return nil, fmt.Errorf("parsing flags: %w", err)
} }
return ctx.getDaemonRPC().GetNetworks(ctx) networkCreationParams, err := ctx.getDaemonRPC().GetNetworks(ctx)
if err != nil {
return nil, fmt.Errorf("calling GetNetworks: %w", err)
}
type lighthouseView struct {
PublicAddr string `yaml:"public_addr,omitempty"`
IP string `yaml:"ip"`
}
type networkView struct {
bootstrap.CreationParams `yaml:",inline"`
CACert nebula.Certificate `yaml:"ca_cert"`
SubnetCIDR string `yaml:"subnet_cidr"`
Lighthouses []lighthouseView `yaml:"lighthouses"`
}
var (
daemonRPC = ctx.getDaemonRPC()
networkViews = make([]networkView, len(networkCreationParams))
)
for i, creationParams := range networkCreationParams {
ctx := daemon.WithNetwork(ctx, creationParams.ID)
networkBootstrap, err := daemonRPC.GetBootstrap(ctx)
if err != nil {
return nil, fmt.Errorf(
"calling GetBootstrap with network:%+v: %w",
networkCreationParams,
err,
)
}
var (
caCert = networkBootstrap.CAPublicCredentials.Cert
caCertDetails = caCert.Unwrap().Details
subnet = caCertDetails.Subnets[0]
lighthouseViews []lighthouseView
)
for _, h := range networkBootstrap.HostsOrdered() {
if h.Nebula.PublicAddr == "" {
continue
}
lighthouseViews = append(lighthouseViews, lighthouseView{
PublicAddr: h.Nebula.PublicAddr,
IP: h.IP().String(),
})
}
networkViews[i] = networkView{
CreationParams: creationParams,
CACert: caCert,
SubnetCIDR: subnet.String(),
Lighthouses: lighthouseViews,
}
}
slices.SortFunc(networkViews, func(a, b networkView) int {
return cmp.Or(
cmp.Compare(a.Name, b.Name),
cmp.Compare(a.ID, b.ID),
)
})
return networkViews, nil
}), }),
} }

View File

@ -0,0 +1,217 @@
package main
import (
"context"
"fmt"
"isle/bootstrap"
"isle/daemon"
"isle/nebula"
"isle/toolkit"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
)
func TestNetworkList(t *testing.T) {
t.Parallel()
type networkBase struct {
bootstrap.CreationParams
ipNet nebula.IPNet
caCreds nebula.CACredentials
caCertPEM string
}
newNetworkBase := func(id, name, domain, ipNetStr string) networkBase {
var ipNet nebula.IPNet
require.NoError(t, ipNet.UnmarshalText([]byte(ipNetStr)))
caCreds, err := nebula.NewCACredentials(domain, ipNet)
require.NoError(t, err)
caCertPEM, err := caCreds.Public.Cert.MarshalText()
require.NoError(t, err)
return networkBase{
CreationParams: bootstrap.CreationParams{
ID: id,
Name: name,
Domain: domain,
},
ipNet: ipNet,
caCreds: caCreds,
caCertPEM: string(caCertPEM),
}
}
var (
networkBaseA = newNetworkBase("idA", "nameA", "a.com", "172.16.1.0/24")
networkBaseB = newNetworkBase("idB", "nameB", "b.com", "172.16.2.0/24")
)
type host struct {
ip string
publicAddr string
}
type network struct {
networkBase
hosts []host
}
tests := []struct {
name string
networks []network
want []map[string]any
}{
{
name: "no networks",
want: []map[string]any{},
},
{
name: "single",
networks: []network{
{
networkBase: networkBaseA,
hosts: []host{
{
ip: "172.16.1.1",
publicAddr: "1.1.1.1:80",
},
},
},
},
want: []map[string]any{
{
"id": "idA",
"name": "nameA",
"domain": "a.com",
"ca_cert": networkBaseA.caCertPEM,
"subnet_cidr": "172.16.1.0/24",
"lighthouses": []any{
map[string]any{
"ip": "172.16.1.1",
"public_addr": "1.1.1.1:80",
},
},
},
},
},
{
name: "multiple",
networks: []network{
{
networkBase: networkBaseB,
hosts: []host{
{
ip: "172.16.2.1",
publicAddr: "2.2.2.2:80",
},
{
ip: "172.16.2.2",
publicAddr: "3.3.3.3:80",
},
{
ip: "172.16.2.3",
},
},
},
{
networkBase: networkBaseA,
hosts: []host{
{
ip: "172.16.1.1",
publicAddr: "1.1.1.1:80",
},
},
},
},
want: []map[string]any{
{
"id": "idA",
"name": "nameA",
"domain": "a.com",
"ca_cert": networkBaseA.caCertPEM,
"subnet_cidr": "172.16.1.0/24",
"lighthouses": []any{
map[string]any{
"ip": "172.16.1.1",
"public_addr": "1.1.1.1:80",
},
},
},
{
"id": "idB",
"name": "nameB",
"domain": "b.com",
"ca_cert": networkBaseB.caCertPEM,
"subnet_cidr": "172.16.2.0/24",
"lighthouses": []any{
map[string]any{
"ip": "172.16.2.1",
"public_addr": "2.2.2.2:80",
},
map[string]any{
"ip": "172.16.2.2",
"public_addr": "3.3.3.3:80",
},
},
},
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var (
h = newRunHarness(t)
creationParams = make([]bootstrap.CreationParams, len(test.networks))
)
for i, testNetwork := range test.networks {
creationParams[i] = testNetwork.CreationParams
hosts := map[nebula.HostName]bootstrap.Host{}
for _, testHost := range testNetwork.hosts {
var (
hostName nebula.HostName
ip = netip.MustParseAddr(testHost.ip)
hostNameStr = fmt.Sprintf("host%d", len(hosts))
)
require.NoError(
t, hostName.UnmarshalText([]byte(hostNameStr)),
)
host, _, err := bootstrap.NewHost(
testNetwork.caCreds, hostName, ip,
)
require.NoError(t, err)
host.Nebula.PublicAddr = testHost.publicAddr
hosts[hostName] = host
}
h.daemonRPC.
On(
"GetBootstrap",
daemon.MockContextWithNetwork(testNetwork.ID),
).
Return(bootstrap.Bootstrap{
NetworkCreationParams: creationParams[i],
CAPublicCredentials: testNetwork.caCreds.Public,
Hosts: hosts,
}, nil).
Once()
}
h.daemonRPC.
On("GetNetworks", toolkit.MockArg[context.Context]()).
Return(creationParams, nil).
Once()
h.runAssertStdout(t, test.want, "network", "list")
})
}
}

View File

@ -60,6 +60,15 @@ func (c *rpcClient) CreateNetwork(ctx context.Context, name string, domain strin
return return
} }
func (c *rpcClient) GetBootstrap(ctx context.Context) (b1 bootstrap.Bootstrap, err error) {
err = c.client.Call(
ctx,
&b1,
"GetBootstrap",
)
return
}
func (c *rpcClient) GetConfig(ctx context.Context) (n1 daecommon.NetworkConfig, err error) { func (c *rpcClient) GetConfig(ctx context.Context) (n1 daecommon.NetworkConfig, err error) {
err = c.client.Call( err = c.client.Call(
ctx, ctx,
@ -78,24 +87,6 @@ func (c *rpcClient) GetGarageClientParams(ctx context.Context) (g1 network.Garag
return return
} }
func (c *rpcClient) GetHosts(ctx context.Context) (ha1 []bootstrap.Host, err error) {
err = c.client.Call(
ctx,
&ha1,
"GetHosts",
)
return
}
func (c *rpcClient) GetNebulaCAPublicCredentials(ctx context.Context) (c2 nebula.CAPublicCredentials, err error) {
err = c.client.Call(
ctx,
&c2,
"GetNebulaCAPublicCredentials",
)
return
}
func (c *rpcClient) GetNetworks(ctx context.Context) (ca1 []bootstrap.CreationParams, err error) { func (c *rpcClient) GetNetworks(ctx context.Context) (ca1 []bootstrap.CreationParams, err error) {
err = c.client.Call( err = c.client.Call(
ctx, ctx,

View File

@ -3,6 +3,8 @@ package daemon
import ( import (
"context" "context"
"isle/daemon/jsonrpc2" "isle/daemon/jsonrpc2"
"github.com/stretchr/testify/mock"
) )
const metaKeyNetworkSearchStr = "daemon.networkSearchStr" const metaKeyNetworkSearchStr = "daemon.networkSearchStr"
@ -18,3 +20,12 @@ func getNetworkSearchStr(ctx context.Context) string {
v, _ := jsonrpc2.GetMeta(ctx)[metaKeyNetworkSearchStr].(string) v, _ := jsonrpc2.GetMeta(ctx)[metaKeyNetworkSearchStr].(string)
return v return v
} }
// MockContextWithNetwork returns a value which can be used with the
// tesstify/mock package to match a context which has a search string added to
// it by WithNetwork.
func MockContextWithNetwork(searchStr string) any {
return mock.MatchedBy(func(ctx context.Context) bool {
return getNetworkSearchStr(ctx) == searchStr
})
}

View File

@ -242,13 +242,21 @@ func (d *Daemon) GetNetworks(
return res, nil return res, nil
} }
// GetHost implements the method for the network.RPC interface. // GetBootstrap implements the method for the network.RPC interface.
func (d *Daemon) GetHosts(ctx context.Context) ([]bootstrap.Host, error) { func (d *Daemon) GetBootstrap(
ctx context.Context,
) (
bootstrap.Bootstrap, error,
) {
return withNetwork( return withNetwork(
ctx, ctx,
d, d,
func(ctx context.Context, n joinedNetwork) ([]bootstrap.Host, error) { func(
return n.GetHosts(ctx) ctx context.Context, n joinedNetwork,
) (
bootstrap.Bootstrap, error,
) {
return n.GetBootstrap(ctx)
}, },
) )
} }
@ -272,26 +280,6 @@ func (d *Daemon) GetGarageClientParams(
) )
} }
// GetNebulaCAPublicCredentials implements the method for the network.RPC
// interface.
func (d *Daemon) GetNebulaCAPublicCredentials(
ctx context.Context,
) (
nebula.CAPublicCredentials, error,
) {
return withNetwork(
ctx,
d,
func(
ctx context.Context, n joinedNetwork,
) (
nebula.CAPublicCredentials, error,
) {
return n.GetNebulaCAPublicCredentials(ctx)
},
)
}
// RemoveHost implements the method for the network.RPC interface. // RemoveHost implements the method for the network.RPC interface.
func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error { func (d *Daemon) RemoveHost(ctx context.Context, hostName nebula.HostName) error {
_, err := withNetwork( _, err := withNetwork(

View File

@ -3,6 +3,7 @@ package jsonrpc2
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
"net/http" "net/http"
) )
@ -50,7 +51,9 @@ func (h httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
switch rpcErr.Code { switch rpcErr.Code {
case 0: // no error case 0: // no error
_ = encodeResponse(enc, req.ID, res) if err := encodeResponse(enc, req.ID, res); err != nil {
panic(fmt.Errorf("encoding response %#v: %w", res, err))
}
return return
case errCodeMethodNotFound: case errCodeMethodNotFound:
@ -63,5 +66,7 @@ func (h httpHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusBadRequest) rw.WriteHeader(http.StatusBadRequest)
} }
_ = encodeErrorResponse(enc, req.ID, rpcErr) if err := encodeErrorResponse(enc, req.ID, rpcErr); err != nil {
panic(fmt.Errorf("encoding error %+v: %w", rpcErr, err))
}
} }

View File

@ -6,7 +6,6 @@ package network
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/json" "encoding/json"
@ -21,13 +20,11 @@ import (
"isle/secrets" "isle/secrets"
"isle/toolkit" "isle/toolkit"
"net/netip" "net/netip"
"slices"
"sync" "sync"
"time" "time"
"dev.mediocregopher.com/mediocre-go-lib.git/mctx" "dev.mediocregopher.com/mediocre-go-lib.git/mctx"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog" "dev.mediocregopher.com/mediocre-go-lib.git/mlog"
"golang.org/x/exp/maps"
) )
// GarageClientParams contains all the data needed to instantiate garage // GarageClientParams contains all the data needed to instantiate garage
@ -75,21 +72,14 @@ type JoiningBootstrap struct {
// RPC defines the methods related to a single network which are available over // RPC defines the methods related to a single network which are available over
// the daemon's RPC interface. // the daemon's RPC interface.
type RPC interface { type RPC interface {
// GetHosts returns all hosts known to the network, sorted by their name. // GetBootstrap returns the currently active Bootstrap for the Network. The
GetHosts(context.Context) ([]bootstrap.Host, error) // PrivateCredentials field will be zero'd out before being returned.
GetBootstrap(context.Context) (bootstrap.Bootstrap, error)
// GetGarageClientParams returns a GarageClientParams for the current // GetGarageClientParams returns a GarageClientParams for the current
// network state. // network state.
GetGarageClientParams(context.Context) (GarageClientParams, error) GetGarageClientParams(context.Context) (GarageClientParams, error)
// GetNebulaCAPublicCredentials returns the CAPublicCredentials for the
// network.
GetNebulaCAPublicCredentials(
context.Context,
) (
nebula.CAPublicCredentials, error,
)
// RemoveHost removes the host of the given name from the network. // RemoveHost removes the host of the given name from the network.
RemoveHost(ctx context.Context, hostName nebula.HostName) error RemoveHost(ctx context.Context, hostName nebula.HostName) error
@ -750,17 +740,18 @@ func (n *network) getBootstrap() (
}) })
} }
func (n *network) GetHosts(ctx context.Context) ([]bootstrap.Host, error) { func (n *network) GetBootstrap(
ctx context.Context,
) (
bootstrap.Bootstrap, error,
) {
return withCurrBootstrap(n, func( return withCurrBootstrap(n, func(
currBootstrap bootstrap.Bootstrap, currBootstrap bootstrap.Bootstrap,
) ( ) (
[]bootstrap.Host, error, bootstrap.Bootstrap, error,
) { ) {
hosts := maps.Values(currBootstrap.Hosts) currBootstrap.PrivateCredentials = nebula.HostPrivateCredentials{}
slices.SortFunc(hosts, func(a, b bootstrap.Host) int { return currBootstrap, nil
return cmp.Compare(a.Name, b.Name)
})
return hosts, nil
}) })
} }
@ -778,21 +769,6 @@ func (n *network) GetGarageClientParams(
}) })
} }
func (n *network) GetNebulaCAPublicCredentials(
ctx context.Context,
) (
nebula.CAPublicCredentials, error,
) {
b, err := n.getBootstrap()
if err != nil {
return nebula.CAPublicCredentials{}, fmt.Errorf(
"retrieving bootstrap: %w", err,
)
}
return b.CAPublicCredentials, nil
}
func (n *network) RemoveHost(ctx context.Context, hostName nebula.HostName) error { func (n *network) RemoveHost(ctx context.Context, hostName nebula.HostName) error {
// TODO RemoveHost should publish a certificate revocation for the host // TODO RemoveHost should publish a certificate revocation for the host
// being removed. // being removed.

View File

@ -85,6 +85,19 @@ func TestJoin(t *testing.T) {
}) })
} }
func TestNetwork_GetBootstrap(t *testing.T) {
var (
h = newIntegrationHarness(t)
network = h.createNetwork(t, "primus", nil)
)
currBootstrap, err := network.GetBootstrap(h.ctx)
assert.NoError(t, err)
assert.Equal(
t, nebula.HostPrivateCredentials{}, currBootstrap.PrivateCredentials,
)
}
func TestNetwork_SetConfig(t *testing.T) { func TestNetwork_SetConfig(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -364,13 +364,7 @@ func (nh *integrationHarnessNetwork) garageAdminClient(
func (nh *integrationHarnessNetwork) getHostsByName( func (nh *integrationHarnessNetwork) getHostsByName(
t *testing.T, t *testing.T,
) map[nebula.HostName]bootstrap.Host { ) map[nebula.HostName]bootstrap.Host {
hosts, err := nh.Network.GetHosts(nh.ctx) currBootstrap, err := nh.Network.GetBootstrap(nh.ctx)
require.NoError(t, err) require.NoError(t, err)
return currBootstrap.Hosts
hostsByName := map[nebula.HostName]bootstrap.Host{}
for _, h := range hosts {
hostsByName[h.Name] = h
}
return hostsByName
} }

View File

@ -74,6 +74,34 @@ func (_m *MockNetwork) CreateNebulaCertificate(_a0 context.Context, _a1 nebula.H
return r0, r1 return r0, r1
} }
// GetBootstrap provides a mock function with given fields: _a0
func (_m *MockNetwork) GetBootstrap(_a0 context.Context) (bootstrap.Bootstrap, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetBootstrap")
}
var r0 bootstrap.Bootstrap
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (bootstrap.Bootstrap, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) bootstrap.Bootstrap); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(bootstrap.Bootstrap)
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetConfig provides a mock function with given fields: _a0 // GetConfig provides a mock function with given fields: _a0
func (_m *MockNetwork) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) { func (_m *MockNetwork) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) {
ret := _m.Called(_a0) ret := _m.Called(_a0)
@ -130,64 +158,6 @@ func (_m *MockNetwork) GetGarageClientParams(_a0 context.Context) (GarageClientP
return r0, r1 return r0, r1
} }
// GetHosts provides a mock function with given fields: _a0
func (_m *MockNetwork) GetHosts(_a0 context.Context) ([]bootstrap.Host, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetHosts")
}
var r0 []bootstrap.Host
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) ([]bootstrap.Host, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) []bootstrap.Host); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]bootstrap.Host)
}
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetNebulaCAPublicCredentials provides a mock function with given fields: _a0
func (_m *MockNetwork) GetNebulaCAPublicCredentials(_a0 context.Context) (nebula.CAPublicCredentials, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetNebulaCAPublicCredentials")
}
var r0 nebula.CAPublicCredentials
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (nebula.CAPublicCredentials, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) nebula.CAPublicCredentials); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(nebula.CAPublicCredentials)
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetNetworkCreationParams provides a mock function with given fields: _a0 // GetNetworkCreationParams provides a mock function with given fields: _a0
func (_m *MockNetwork) GetNetworkCreationParams(_a0 context.Context) (bootstrap.CreationParams, error) { func (_m *MockNetwork) GetNetworkCreationParams(_a0 context.Context) (bootstrap.CreationParams, error) {
ret := _m.Called(_a0) ret := _m.Called(_a0)

View File

@ -10,6 +10,8 @@ import (
"isle/daemon/network" "isle/daemon/network"
"isle/nebula" "isle/nebula"
"net/http" "net/http"
"dev.mediocregopher.com/mediocre-go-lib.git/mlog"
) )
// RPC defines the methods which the Daemon exposes over RPC (via the RPCHandler // RPC defines the methods which the Daemon exposes over RPC (via the RPCHandler
@ -53,12 +55,13 @@ type RPC interface {
network.RPC network.RPC
} }
// RPCHandler returns a jsonrpc2.Handler which will use the Daemon to serve all // NewRPCHandler returns a jsonrpc2.Handler which will use the given RPC to
// methods defined on the RPC interface. // serve all methods defined on the interface.
func (d *Daemon) RPCHandler() jsonrpc2.Handler { func NewRPCHandler(
rpc := RPC(d) logger *mlog.Logger, rpc RPC,
) jsonrpc2.Handler {
return jsonrpc2.Chain( return jsonrpc2.Chain(
jsonrpc2.NewMLogMiddleware(d.logger.WithNamespace("rpc")), jsonrpc2.NewMLogMiddleware(logger),
jsonrpc2.ExposeServerSideErrorsMiddleware, jsonrpc2.ExposeServerSideErrorsMiddleware,
)( )(
jsonrpc2.NewDispatchHandler(&rpc), jsonrpc2.NewDispatchHandler(&rpc),
@ -68,5 +71,6 @@ func (d *Daemon) RPCHandler() jsonrpc2.Handler {
// HTTPRPCHandler returns an http.Handler which will use the Daemon to serve all // HTTPRPCHandler returns an http.Handler which will use the Daemon to serve all
// methods defined on the RPC interface via the JSONRPC2 protocol. // methods defined on the RPC interface via the JSONRPC2 protocol.
func (d *Daemon) HTTPRPCHandler() http.Handler { func (d *Daemon) HTTPRPCHandler() http.Handler {
return jsonrpc2.NewHTTPHandler(d.RPCHandler()) handler := NewRPCHandler(d.logger.WithNamespace("rpc"), d)
return jsonrpc2.NewHTTPHandler(handler)
} }

View File

@ -94,6 +94,34 @@ func (_m *MockRPC) CreateNetwork(ctx context.Context, name string, domain string
return r0 return r0
} }
// GetBootstrap provides a mock function with given fields: _a0
func (_m *MockRPC) GetBootstrap(_a0 context.Context) (bootstrap.Bootstrap, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetBootstrap")
}
var r0 bootstrap.Bootstrap
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (bootstrap.Bootstrap, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) bootstrap.Bootstrap); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(bootstrap.Bootstrap)
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetConfig provides a mock function with given fields: _a0 // GetConfig provides a mock function with given fields: _a0
func (_m *MockRPC) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) { func (_m *MockRPC) GetConfig(_a0 context.Context) (daecommon.NetworkConfig, error) {
ret := _m.Called(_a0) ret := _m.Called(_a0)
@ -150,64 +178,6 @@ func (_m *MockRPC) GetGarageClientParams(_a0 context.Context) (network.GarageCli
return r0, r1 return r0, r1
} }
// GetHosts provides a mock function with given fields: _a0
func (_m *MockRPC) GetHosts(_a0 context.Context) ([]bootstrap.Host, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetHosts")
}
var r0 []bootstrap.Host
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) ([]bootstrap.Host, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) []bootstrap.Host); ok {
r0 = rf(_a0)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]bootstrap.Host)
}
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetNebulaCAPublicCredentials provides a mock function with given fields: _a0
func (_m *MockRPC) GetNebulaCAPublicCredentials(_a0 context.Context) (nebula.CAPublicCredentials, error) {
ret := _m.Called(_a0)
if len(ret) == 0 {
panic("no return value specified for GetNebulaCAPublicCredentials")
}
var r0 nebula.CAPublicCredentials
var r1 error
if rf, ok := ret.Get(0).(func(context.Context) (nebula.CAPublicCredentials, error)); ok {
return rf(_a0)
}
if rf, ok := ret.Get(0).(func(context.Context) nebula.CAPublicCredentials); ok {
r0 = rf(_a0)
} else {
r0 = ret.Get(0).(nebula.CAPublicCredentials)
}
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(_a0)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetNetworks provides a mock function with given fields: _a0 // GetNetworks provides a mock function with given fields: _a0
func (_m *MockRPC) GetNetworks(_a0 context.Context) ([]bootstrap.CreationParams, error) { func (_m *MockRPC) GetNetworks(_a0 context.Context) ([]bootstrap.CreationParams, error) {
ret := _m.Called(_a0) ret := _m.Called(_a0)

View File

@ -2,6 +2,7 @@ package nebula
import ( import (
"fmt" "fmt"
"reflect"
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
) )
@ -37,14 +38,22 @@ func (c Certificate) Unwrap() *cert.NebulaCertificate {
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (c Certificate) MarshalText() ([]byte, error) { func (c Certificate) MarshalText() ([]byte, error) {
if reflect.DeepEqual(c, Certificate{}) {
return []byte(""), nil
}
return c.inner.MarshalToPEM() return c.inner.MarshalToPEM()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (c *Certificate) UnmarshalText(b []byte) error { func (c *Certificate) UnmarshalText(b []byte) error {
if len(b) == 0 {
*c = Certificate{}
return nil
}
nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b) nebCrt, _, err := cert.UnmarshalNebulaCertificateFromPEM(b)
if err != nil { if err != nil {
return err return fmt.Errorf("unmarshaling nebula certificate from PEM: %w", err)
} }
c.inner = *nebCrt c.inner = *nebCrt
return nil return nil

View File

@ -21,19 +21,31 @@ type EncryptingPublicKey struct{ inner *ecdh.PublicKey }
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (pk EncryptingPublicKey) MarshalText() ([]byte, error) { func (pk EncryptingPublicKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPubKeyPrefix, pk.inner.Bytes()), nil if pk == (EncryptingPublicKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPubKeyPrefix, pk.Bytes()), nil
} }
// Bytes returns the raw bytes of the EncryptingPublicKey. // Bytes returns the raw bytes of the EncryptingPublicKey, or nil if it is the
// zero value.
func (k EncryptingPublicKey) Bytes() []byte { func (k EncryptingPublicKey) Bytes() []byte {
if k == (EncryptingPublicKey{}) {
return nil
}
return k.inner.Bytes() return k.inner.Bytes()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error { func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = EncryptingPublicKey{}
return nil
}
b, err := decodeWithPrefix(encPubKeyPrefix, b) b, err := decodeWithPrefix(encPubKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling encrypting public key: %w", err)
} }
if pk.inner, err = x25519.NewPublicKey(b); err != nil { if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -48,7 +60,7 @@ func (pk *EncryptingPublicKey) UnmarshalText(b []byte) error {
func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error { func (pk *EncryptingPublicKey) UnmarshalNebulaPEM(b []byte) error {
b, _, err := cert.UnmarshalX25519PublicKey(b) b, _, err := cert.UnmarshalX25519PublicKey(b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling nebula PEM as encrypting public key: %w", err)
} }
if pk.inner, err = x25519.NewPublicKey(b); err != nil { if pk.inner, err = x25519.NewPublicKey(b); err != nil {
@ -86,19 +98,31 @@ func (k EncryptingPrivateKey) PublicKey() EncryptingPublicKey {
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (k EncryptingPrivateKey) MarshalText() ([]byte, error) { func (k EncryptingPrivateKey) MarshalText() ([]byte, error) {
return encodeWithPrefix(encPrivKeyPrefix, k.inner.Bytes()), nil if k == (EncryptingPrivateKey{}) {
return []byte(""), nil
}
return encodeWithPrefix(encPrivKeyPrefix, k.Bytes()), nil
} }
// Bytes returns the raw bytes of the EncryptingPrivateKey. // Bytes returns the raw bytes of the EncryptingPrivateKey, or nil if it is the
// zero value.
func (k EncryptingPrivateKey) Bytes() []byte { func (k EncryptingPrivateKey) Bytes() []byte {
if k == (EncryptingPrivateKey{}) {
return nil
}
return k.inner.Bytes() return k.inner.Bytes()
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error { func (k *EncryptingPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = EncryptingPrivateKey{}
return nil
}
b, err := decodeWithPrefix(encPrivKeyPrefix, b) b, err := decodeWithPrefix(encPrivKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling encrypting private key: %w", err)
} }
if k.inner, err = x25519.NewPrivateKey(b); err != nil { if k.inner, err = x25519.NewPrivateKey(b); err != nil {

View File

@ -1,6 +1,7 @@
package nebula package nebula
import ( import (
"bytes"
"crypto" "crypto"
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
@ -47,13 +48,23 @@ func Sign[T any](v T, k SigningPrivateKey) (Signed[T], error) {
return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)}) return json.Marshal(signed[T]{Signature: sig, Body: json.RawMessage(b)})
} }
var jsonNull = []byte("null")
// MarshalJSON implements the json.Marshaler interface. // MarshalJSON implements the json.Marshaler interface.
func (s Signed[T]) MarshalJSON() ([]byte, error) { func (s Signed[T]) MarshalJSON() ([]byte, error) {
if s == nil {
return jsonNull, nil
}
return []byte(s), nil return []byte(s), nil
} }
// UnmarshalJSON implements the json.Unmarshaler interface. // UnmarshalJSON implements the json.Unmarshaler interface.
func (s *Signed[T]) UnmarshalJSON(b []byte) error { func (s *Signed[T]) UnmarshalJSON(b []byte) error {
if bytes.Equal(b, jsonNull) {
*s = nil
return nil
}
*s = b *s = b
return nil return nil
} }

View File

@ -34,7 +34,7 @@ func TestSigned(t *testing.T) {
_, err = signedB.Unwrap(hostPubCredsB.SigningKey) _, err = signedB.Unwrap(hostPubCredsB.SigningKey)
if !errors.Is(err, ErrInvalidSignature) { if !errors.Is(err, ErrInvalidSignature) {
t.Fatalf("expected ErrInvalidSignature but got %v", err) t.Fatalf("expected ErrInvalidSignature but got: %v", err)
} }
b, err := signedB.Unwrap(hostPubCredsA.SigningKey) b, err := signedB.Unwrap(hostPubCredsA.SigningKey)

View File

@ -17,14 +17,22 @@ type SigningPrivateKey ed25519.PrivateKey
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (k SigningPrivateKey) MarshalText() ([]byte, error) { func (k SigningPrivateKey) MarshalText() ([]byte, error) {
if k == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPrivKeyPrefix, k), nil return encodeWithPrefix(sigPrivKeyPrefix, k), nil
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (k *SigningPrivateKey) UnmarshalText(b []byte) error { func (k *SigningPrivateKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*k = SigningPrivateKey{}
return nil
}
b, err := decodeWithPrefix(sigPrivKeyPrefix, b) b, err := decodeWithPrefix(sigPrivKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling signing private key: %w", err)
} }
*k = SigningPrivateKey(b) *k = SigningPrivateKey(b)
@ -45,14 +53,22 @@ type SigningPublicKey ed25519.PublicKey
// MarshalText implements the encoding.TextMarshaler interface. // MarshalText implements the encoding.TextMarshaler interface.
func (pk SigningPublicKey) MarshalText() ([]byte, error) { func (pk SigningPublicKey) MarshalText() ([]byte, error) {
if pk == nil {
return []byte(""), nil
}
return encodeWithPrefix(sigPubKeyPrefix, pk), nil return encodeWithPrefix(sigPubKeyPrefix, pk), nil
} }
// UnmarshalText implements the encoding.TextUnmarshaler interface. // UnmarshalText implements the encoding.TextUnmarshaler interface.
func (pk *SigningPublicKey) UnmarshalText(b []byte) error { func (pk *SigningPublicKey) UnmarshalText(b []byte) error {
if len(b) == 0 {
*pk = SigningPublicKey{}
return nil
}
b, err := decodeWithPrefix(sigPubKeyPrefix, b) b, err := decodeWithPrefix(sigPubKeyPrefix, b)
if err != nil { if err != nil {
return fmt.Errorf("unmarshaling: %w", err) return fmt.Errorf("unmarshaling signing public key: %w", err)
} }
*pk = SigningPublicKey(b) *pk = SigningPublicKey(b)