diff --git a/go/bootstrap/bootstrap.go b/go/bootstrap/bootstrap.go index 9b61aba..6160928 100644 --- a/go/bootstrap/bootstrap.go +++ b/go/bootstrap/bootstrap.go @@ -12,6 +12,7 @@ import ( "net/netip" "path/filepath" "sort" + "strings" "dev.mediocregopher.com/mediocre-go-lib.git/mctx" ) @@ -52,6 +53,42 @@ func (p CreationParams) Annotate(aa mctx.Annotations) { aa["networkDomain"] = p.Domain } +// Matches returns true if the given string matches some aspect of the +// CreationParams. +func (p CreationParams) Matches(str string) bool { + if strings.HasPrefix(p.ID, str) { + return true + } + + if strings.EqualFold(p.Name, str) { + return true + } + + if strings.EqualFold(p.Domain, str) { + return true + } + + return false +} + +// Conflicts returns true if either CreationParams has some parameter which +// overlaps with that of the other. +func (p CreationParams) Conflicts(p2 CreationParams) bool { + if p.ID == p2.ID { + return true + } + + if strings.EqualFold(p.Name, p2.Name) { + return true + } + + if strings.EqualFold(p.Domain, p2.Domain) { + return true + } + + return false +} + // Bootstrap contains all information which is needed by a host daemon to join a // network on boot. type Bootstrap struct { diff --git a/go/daemon/config.go b/go/daemon/config.go index b23ff74..62c3641 100644 --- a/go/daemon/config.go +++ b/go/daemon/config.go @@ -47,25 +47,18 @@ func pickNetworkConfig( ) ( daecommon.NetworkConfig, bool, ) { - if c, ok := daemonConfig.Networks[creationParams.ID]; ok { - return c, true - } - - if c, ok := daemonConfig.Networks[creationParams.Name]; ok { - return c, true - } - - if c, ok := daemonConfig.Networks[creationParams.Domain]; ok { - return c, true - } - - { // DEPRECATED - c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID] - if len(daemonConfig.Networks) == 1 && ok { + if len(daemonConfig.Networks) == 1 { // DEPRECATED + if c, ok := daemonConfig.Networks[daecommon.DeprecatedNetworkID]; ok { return c, true } } + for searchStr, networkConfig := range daemonConfig.Networks { + if creationParams.Matches(searchStr) { + return networkConfig, true + } + } + return daecommon.NetworkConfig{}, false } diff --git a/go/daemon/ctx.go b/go/daemon/ctx.go new file mode 100644 index 0000000..5f001e6 --- /dev/null +++ b/go/daemon/ctx.go @@ -0,0 +1,20 @@ +package daemon + +import ( + "context" + "isle/daemon/jsonrpc2" +) + +const metaKeyNetworkSearchStr = "daemon.networkSearchStr" + +// WithNetwork returns the Context so that, when used against a daemon RPC +// endpoint, the endpoint knows which network is being targetted for the call. +// The network can be identified by its ID, name, or domain. +func WithNetwork(ctx context.Context, searchStr string) context.Context { + return jsonrpc2.WithMeta(ctx, metaKeyNetworkSearchStr, searchStr) +} + +func getNetworkSearchStr(ctx context.Context) string { + v, _ := jsonrpc2.GetMeta(ctx)[metaKeyNetworkSearchStr].(string) + return v +} diff --git a/go/daemon/daemon.go b/go/daemon/daemon.go index baf6a4a..d4d71cb 100644 --- a/go/daemon/daemon.go +++ b/go/daemon/daemon.go @@ -65,8 +65,8 @@ type Daemon struct { networksStateDir toolkit.Dir networksRuntimeDir toolkit.Dir - l sync.RWMutex - network network.Network + l sync.RWMutex + networks map[string]network.Network } // New initializes and returns a Daemon. @@ -96,6 +96,7 @@ func New( daemonConfig: daemonConfig, envBinDirPath: envBinDirPath, opts: opts, + networks: map[string]network.Network{}, } { @@ -113,18 +114,14 @@ func New( } } - loadableNetworks, err := LoadableNetworks(d.networksStateDir) + loadableNetworks, err := loadableNetworks(d.networksStateDir) if err != nil { return nil, fmt.Errorf("listing loadable networks: %w", err) } - if len(loadableNetworks) > 1 { - return nil, fmt.Errorf( - "more then one loadable Network found: %+v", loadableNetworks, - ) - } else if len(loadableNetworks) == 1 { - id := loadableNetworks[0].ID - ctx = mctx.WithAnnotator(ctx, loadableNetworks[0]) + for _, creationParams := range loadableNetworks { + id := creationParams.ID + ctx = mctx.WithAnnotator(ctx, creationParams) networkStateDir, networkRuntimeDir, err := networkDirs( d.networksStateDir, d.networksRuntimeDir, id, true, @@ -135,9 +132,9 @@ func New( ) } - networkConfig, _ := pickNetworkConfig(daemonConfig, loadableNetworks[0]) + networkConfig, _ := pickNetworkConfig(daemonConfig, creationParams) - d.network, err = network.Load( + d.networks[id], err = network.Load( ctx, logger.WithNamespace("network"), id, @@ -187,7 +184,9 @@ func (d *Daemon) CreateNetwork( d.l.Lock() defer d.l.Unlock() - if d.network != nil { + if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { + return fmt.Errorf("checking if already joined to network: %w", err) + } else if joined { return ErrAlreadyJoined } @@ -222,7 +221,7 @@ func (d *Daemon) CreateNetwork( } d.logger.Info(ctx, "Network created successfully") - d.network = n + d.networks[creationParams.ID] = n return nil } @@ -245,7 +244,9 @@ func (d *Daemon) JoinNetwork( d.l.Lock() defer d.l.Unlock() - if d.network != nil { + if joined, err := alreadyJoined(ctx, d.networks, creationParams); err != nil { + return fmt.Errorf("checking if already joined to network: %w", err) + } else if joined { return ErrAlreadyJoined } @@ -278,7 +279,7 @@ func (d *Daemon) JoinNetwork( } d.logger.Info(ctx, "Network joined successfully") - d.network = n + d.networks[networkID] = n return nil } @@ -292,12 +293,13 @@ func withNetwork[Res any]( d.l.RLock() defer d.l.RUnlock() - if d.network == nil { + network, err := pickNetwork(ctx, d.networks, d.networksStateDir) + if err != nil { var zero Res - return zero, ErrNoNetwork + return zero, nil } - return fn(ctx, d.network) + return fn(ctx, network) } // GetHost implements the method for the network.RPC interface. @@ -417,32 +419,28 @@ func (d *Daemon) Shutdown() error { d.l.Lock() defer d.l.Unlock() - if d.network != nil { - return d.network.Shutdown() + var ( + errCh = make(chan error, len(d.networks)) + errs []error + ) + + for id := range d.networks { + var ( + id = id + n = d.networks[id] + ) + + go func() { + if err := n.Shutdown(); err != nil { + errCh <- fmt.Errorf("shutting down network %q: %w", id, err) + } + errCh <- nil + }() } - return nil - //var ( - // errCh = make(chan error, len(d.networks)) - // errs []error - //) + for range cap(errCh) { + errs = append(errs, <-errCh) + } - //for id := range d.networks { - // id := id - // n := d.networks[id] - // go func() { - // if err := n.Shutdown(); err != nil { - // errCh <- fmt.Errorf("shutting down network %q: %w", id, err) - // } - // errCh <- nil - // }() - //} - - //for range cap(errCh) { - // if err := <-errCh; err != nil { - // errs = append(errs, err) - // } - //} - - //return errors.Join(errs...) + return errors.Join(errs...) } diff --git a/go/daemon/errors.go b/go/daemon/errors.go index aa13083..1ddd242 100644 --- a/go/daemon/errors.go +++ b/go/daemon/errors.go @@ -8,6 +8,8 @@ import ( const ( errCodeNoNetwork = daecommon.ErrorCodeRangeDaemon + iota errCodeAlreadyJoined + errCodeNoMatchingNetworks + errCodeMultipleMatchingNetworks ) var ( @@ -16,6 +18,19 @@ var ( ErrNoNetwork = jsonrpc2.NewError(errCodeNoNetwork, "No network configured") // ErrAlreadyJoined is returned when the daemon is instructed to create or - // join a new network, but it is already joined to a network. + // join a new network, but it is already joined to that network. ErrAlreadyJoined = jsonrpc2.NewError(errCodeAlreadyJoined, "Already joined to a network") + + // ErrNoMatchingNetworks is returned if the search string didn't match any + // networks. + ErrNoMatchingNetworks = jsonrpc2.NewError( + errCodeNoMatchingNetworks, "No networks matched the search string", + ) + + // ErrMultipleMatchingNetworks is returned if the search string matched + // multiple networks. + ErrMultipleMatchingNetworks = jsonrpc2.NewError( + errCodeMultipleMatchingNetworks, + "Multiple networks matched the search string", + ) ) diff --git a/go/daemon/jsonrpc2/client.go b/go/daemon/jsonrpc2/client.go index fba9b42..fb930fa 100644 --- a/go/daemon/jsonrpc2/client.go +++ b/go/daemon/jsonrpc2/client.go @@ -13,7 +13,11 @@ type Client interface { // receiver pointer, unless it is nil in which case the result will be // discarded. // + // If the Context was produced using WithMeta then that metadata will be + // carried with the request to the server via the Meta field of the + // RequestParams. + // // If an error result is returned from the server that will be returned as // an Error struct. - Call(ctx context.Context, rcv any, method string, params ...any) error + Call(ctx context.Context, rcv any, method string, args ...any) error } diff --git a/go/daemon/jsonrpc2/client_http.go b/go/daemon/jsonrpc2/client_http.go index 6b2845c..74fea17 100644 --- a/go/daemon/jsonrpc2/client_http.go +++ b/go/daemon/jsonrpc2/client_http.go @@ -49,14 +49,14 @@ func NewUnixHTTPClient(unixSocketPath, reqPath string) Client { } func (c *httpClient) Call( - ctx context.Context, rcv any, method string, params ...any, + ctx context.Context, rcv any, method string, args ...any, ) error { var ( body = new(bytes.Buffer) enc = json.NewEncoder(body) ) - id, err := encodeRequest(enc, method, params) + id, err := encodeRequest(ctx, enc, method, args) if err != nil { return fmt.Errorf("encoding request: %w", err) } diff --git a/go/daemon/jsonrpc2/client_rw.go b/go/daemon/jsonrpc2/client_rw.go index 71f7e8c..25c6454 100644 --- a/go/daemon/jsonrpc2/client_rw.go +++ b/go/daemon/jsonrpc2/client_rw.go @@ -19,9 +19,9 @@ func NewReadWriterClient(rw io.ReadWriter) Client { } func (c rwClient) Call( - ctx context.Context, rcv any, method string, params ...any, + ctx context.Context, rcv any, method string, args ...any, ) error { - id, err := encodeRequest(c.enc, method, params) + id, err := encodeRequest(ctx, c.enc, method, args) if err != nil { return fmt.Errorf("encoding request: %w", err) } diff --git a/go/daemon/jsonrpc2/dispatcher.go b/go/daemon/jsonrpc2/dispatcher.go index a06fbdb..7d98f1e 100644 --- a/go/daemon/jsonrpc2/dispatcher.go +++ b/go/daemon/jsonrpc2/dispatcher.go @@ -18,19 +18,21 @@ type methodDispatchFunc func(context.Context, Request) (any, error) func newMethodDispatchFunc( method reflect.Value, ) methodDispatchFunc { - paramTs := make([]reflect.Type, method.Type().NumIn()-1) - for i := range paramTs { - paramTs[i] = method.Type().In(i + 1) + argTs := make([]reflect.Type, method.Type().NumIn()-1) + for i := range argTs { + argTs[i] = method.Type().In(i + 1) } return func(ctx context.Context, req Request) (any, error) { - callVals := make([]reflect.Value, 0, len(paramTs)+1) + ctx = context.WithValue(ctx, ctxKeyMeta(0), req.Params.Meta) + + callVals := make([]reflect.Value, 0, len(argTs)+1) callVals = append(callVals, reflect.ValueOf(ctx)) - for i, paramT := range paramTs { - paramPtrV := reflect.New(paramT) + for i, argT := range argTs { + argPtrV := reflect.New(argT) - err := json.Unmarshal(req.Params[i], paramPtrV.Interface()) + err := json.Unmarshal(req.Params.Args[i], argPtrV.Interface()) if err != nil { // The JSON has already been validated, so this is not an // errCodeParse situation. We assume it's an invalid param then, @@ -38,13 +40,13 @@ func newMethodDispatchFunc( // returning an Error of its own. if !errors.As(err, new(Error)) { err = NewInvalidParamsError( - "JSON unmarshaling param %d into %T: %v", i, paramT, err, + "JSON unmarshaling arg %d into %T: %v", i, argT, err, ) } return nil, err } - callVals = append(callVals, paramPtrV.Elem()) + callVals = append(callVals, argPtrV.Elem()) } var ( @@ -82,7 +84,8 @@ type dispatcher struct { // MethodName(context.Context, ...ParamType) (ResponseType, error) // MethodName(context.Context, ...ParamType) error // -// will be available via RPC calls. +// will be available via RPC calls. Any Meta data in the request can be obtained +// within the method handler by calling GetMeta on the method's Context. func NewDispatchHandler(i any) Handler { v := reflect.ValueOf(i) if v.Kind() != reflect.Pointer { diff --git a/go/daemon/jsonrpc2/handler.go b/go/daemon/jsonrpc2/handler.go index ded4b6b..c452036 100644 --- a/go/daemon/jsonrpc2/handler.go +++ b/go/daemon/jsonrpc2/handler.go @@ -57,15 +57,16 @@ func NewMLogMiddleware(logger *mlog.Logger) Middleware { ctx, "rpcRequestID", req.ID, "rpcRequestMethod", req.Method, + "rpcRequestMeta", req.Params.Meta, ) if logger.MaxLevel() >= mlog.LevelDebug.Int() { ctx := ctx - for i := range req.Params { + for i := range req.Params.Args { ctx = mctx.Annotate( ctx, - fmt.Sprintf("rpcRequestParam%d", i), - string(req.Params[i]), + fmt.Sprintf("rpcRequestArgs%d", i), + string(req.Params.Args[i]), ) } logger.Debug(ctx, "Handling RPC request") diff --git a/go/daemon/jsonrpc2/jsonrpc2.go b/go/daemon/jsonrpc2/jsonrpc2.go index 83a47f5..648180c 100644 --- a/go/daemon/jsonrpc2/jsonrpc2.go +++ b/go/daemon/jsonrpc2/jsonrpc2.go @@ -3,6 +3,7 @@ package jsonrpc2 import ( + "context" "crypto/rand" "encoding/json" "fmt" @@ -11,12 +12,20 @@ import ( const version = "2.0" +// RequestParams are the parameters passed in a Request. Meta contains +// information that is not directly related to what is being requested, while +// Args are the request's actual arguments. +type RequestParams struct { + Meta map[string]any `json:"meta,omitempty"` + Args []json.RawMessage `json:"args,omitempty"` +} + // Request encodes an RPC request according to the spec. type Request struct { - Version string `json:"jsonrpc"` // must be "2.0" - Method string `json:"method"` - Params []json.RawMessage `json:"params,omitempty"` - ID string `json:"id"` + Version string `json:"jsonrpc"` // must be "2.0" + Method string `json:"method"` + Params RequestParams `json:"params,omitempty"` + ID string `json:"id"` } type response[Result any] struct { @@ -37,18 +46,18 @@ func newID() string { // encodeRequest writes a request to an io.Writer, returning the ID of the // request. func encodeRequest( - enc *json.Encoder, method string, params []any, + ctx context.Context, enc *json.Encoder, method string, args []any, ) ( string, error, ) { var ( - paramsBs = make([]json.RawMessage, len(params)) - err error + argsBs = make([]json.RawMessage, len(args)) + err error ) - for i := range params { - paramsBs[i], err = json.Marshal(params[i]) + for i := range args { + argsBs[i], err = json.Marshal(args[i]) if err != nil { - return "", fmt.Errorf("encoding param %d as JSON: %w", i, err) + return "", fmt.Errorf("encoding arg %d as JSON: %w", i, err) } } @@ -57,8 +66,11 @@ func encodeRequest( reqEnvelope = Request{ Version: version, Method: method, - Params: paramsBs, - ID: id, + Params: RequestParams{ + Meta: GetMeta(ctx), + Args: argsBs, + }, + ID: id, } ) diff --git a/go/daemon/jsonrpc2/jsonrpc2_test.go b/go/daemon/jsonrpc2/jsonrpc2_test.go index fbcc245..2658e64 100644 --- a/go/daemon/jsonrpc2/jsonrpc2_test.go +++ b/go/daemon/jsonrpc2/jsonrpc2_test.go @@ -48,6 +48,16 @@ func (i dividerImpl) Noop(ctx context.Context) error { return nil } +func (i dividerImpl) Divide2FromMeta(ctx context.Context) (int, error) { + var ( + meta = GetMeta(ctx) + top = int(meta["top"].(float64)) + bottom = int(meta["bottom"].(float64)) + ) + + return i.Divide2(ctx, top, bottom) +} + func (dividerImpl) Hidden(ctx context.Context, p struct{}) (int, error) { return 0, errors.New("Shouldn't be possible to call this!") } @@ -57,6 +67,7 @@ type divider interface { Divide(ctx context.Context, p DivideParams) (int, error) One(ctx context.Context) (int, error) Noop(ctx context.Context) error + Divide2FromMeta(ctx context.Context) (int, error) } var testHandler = func() Handler { @@ -124,6 +135,19 @@ func testClient(t *testing.T, client Client) { } }) + t.Run("success/meta", func(t *testing.T) { + ctx = WithMeta(ctx, "top", 6) + ctx = WithMeta(ctx, "bottom", 2) + + var res int + err := client.Call(ctx, &res, "Divide2FromMeta") + if err != nil { + t.Fatal(err) + } else if res != 3 { + t.Fatalf("expected 2, got %d", res) + } + }) + t.Run("err/application", func(t *testing.T) { err := client.Call(ctx, nil, "Divide", DivideParams{}) if !errors.Is(err, ErrDivideByZero) { diff --git a/go/daemon/jsonrpc2/meta.go b/go/daemon/jsonrpc2/meta.go new file mode 100644 index 0000000..61006fc --- /dev/null +++ b/go/daemon/jsonrpc2/meta.go @@ -0,0 +1,30 @@ +package jsonrpc2 + +import ( + "context" + "maps" +) + +type ctxKeyMeta int + +// WithMeta returns a Context where the given key will be set to the given value +// in the Meta field of all JSONRPC2 requests made using Clients from this +// package. +func WithMeta(ctx context.Context, key string, value any) context.Context { + m, _ := ctx.Value(ctxKeyMeta(0)).(map[string]any) + if m == nil { + m = map[string]any{} + } else { + m = maps.Clone(m) + } + + m[key] = value + return context.WithValue(ctx, ctxKeyMeta(0), m) +} + +// GetMeta returns all key/values which have been set on the Context using +// WithMeta. This may return nil if WithMeta was never called. +func GetMeta(ctx context.Context) map[string]any { + m, _ := ctx.Value(ctxKeyMeta(0)).(map[string]any) + return m +} diff --git a/go/daemon/network.go b/go/daemon/network.go index 648f73b..72f2653 100644 --- a/go/daemon/network.go +++ b/go/daemon/network.go @@ -1,6 +1,7 @@ package daemon import ( + "context" "fmt" "isle/bootstrap" "isle/daemon/network" @@ -41,9 +42,9 @@ func networkDirs( return } -// LoadableNetworks returns the CreationParams for each Network which is able to +// loadableNetworks returns the CreationParams for each Network which is able to // be loaded. -func LoadableNetworks( +func loadableNetworks( networksStateDir toolkit.Dir, ) ( []bootstrap.CreationParams, error, @@ -71,3 +72,60 @@ func LoadableNetworks( return creationParams, nil } + +func pickNetwork( + ctx context.Context, + networks map[string]network.Network, + networksStateDir toolkit.Dir, +) ( + network.Network, error, +) { + if len(networks) == 0 { + return nil, ErrNoNetwork + } + + creationParams, err := loadableNetworks(networksStateDir) + if err != nil { + return nil, fmt.Errorf("getting loadable networks: %w", err) + } + + var ( + networkSearchStr = getNetworkSearchStr(ctx) + matchingNetworkIDs = make([]string, 0, len(networks)) + ) + + for _, creationParam := range creationParams { + if networkSearchStr == "" || creationParam.Matches(networkSearchStr) { + matchingNetworkIDs = append(matchingNetworkIDs, creationParam.ID) + } + } + + if len(matchingNetworkIDs) == 0 { + return nil, ErrNoMatchingNetworks + } else if len(matchingNetworkIDs) > 1 { + return nil, ErrMultipleMatchingNetworks + } + + return networks[matchingNetworkIDs[0]], nil +} + +func alreadyJoined( + ctx context.Context, + networks map[string]network.Network, + creationParams bootstrap.CreationParams, +) ( + bool, error, +) { + for networkID, network := range networks { + existingCreationParams, err := network.GetNetworkCreationParams(ctx) + if err != nil { + return false, fmt.Errorf( + "getting creation params of network %q: %w", networkID, err, + ) + } else if existingCreationParams.Conflicts(creationParams) { + return true, nil + } + } + + return false, nil +} diff --git a/go/daemon/rpc.go b/go/daemon/rpc.go index 1dfa14a..ac35969 100644 --- a/go/daemon/rpc.go +++ b/go/daemon/rpc.go @@ -25,6 +25,13 @@ type RPC interface { // All network.RPC methods are automatically implemented by Daemon using the // currently joined network. If no network is joined then any call to these // methods will return ErrNoNetwork. + // + // All calls to these methods must be accompanied with a context produced by + // WithNetwork, in order to choose the network. These methods may return + // these errors, in addition to those documented on the individual methods: + // + // - ErrNoMatchingNetworks + // - ErrMultipleMatchingNetworks network.RPC }