package main

import (
	"errors"
	"fmt"
	"isle/daemon"
	"isle/daemon/daecommon"
	"slices"
	"strings"

	"golang.org/x/exp/maps"
)

// vpnFirewallConfigChangeStagerName is the key under which staged VPN
// firewall configuration changes are stored in the change stager.
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"

// vpnFirewallGetStagedConfig loads the staged VPN firewall configuration from
// the change stager. The boolean result is false when nothing is staged, in
// which case the returned configuration is the zero value.
func vpnFirewallGetStagedConfig(
	ctx subCmdCtx,
) (
	daecommon.ConfigFirewall, bool, error,
) {
	var firewallConfig daecommon.ConfigFirewall
	ok, err := ctx.getChangeStager().get(
		&firewallConfig, vpnFirewallConfigChangeStagerName,
	)
	if err != nil || !ok {
		return daecommon.ConfigFirewall{}, false, err
	}
	return firewallConfig, true, nil
}

// vpnFirewallGetConfig returns the network config currently held by the
// daemon. If firewall changes have been staged, they replace the config's
// VPN firewall section in the returned value.
func vpnFirewallGetConfig(
	ctx subCmdCtx, daemonRPC daemon.RPC,
) (
	daecommon.NetworkConfig, error,
) {
	config, err := daemonRPC.GetConfig(ctx)
	if err != nil {
		return daecommon.NetworkConfig{}, err
	}

	firewallConfig, ok, err := vpnFirewallGetStagedConfig(ctx)
	if err != nil {
		return daecommon.NetworkConfig{}, fmt.Errorf(
			"getting staged VPN firewall config: %w", err,
		)
	}
	if ok {
		config.VPN.Firewall = firewallConfig
	}

	return config, nil
}

// vpnFirewallRuleSetToFn maps a rule-set name ("inbound" or "outbound",
// case-insensitive) to an accessor returning a pointer to that rule set
// within a ConfigFirewall. The pointer lets callers modify the set in place.
func vpnFirewallRuleSetToFn(
	str string,
) (
	func(*daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule, error,
) {
	switch strings.ToLower(str) {
	case "outbound":
		return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
			return &c.Outbound
		}, nil
	case "inbound":
		return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
			return &c.Inbound
		}, nil
	default:
		return nil, errors.New("must be 'inbound' or 'outbound'")
	}
}

// subCmdVPNFirewallAdd appends a single rule to the inbound or outbound rule
// set. It validates the resulting network config and stages the firewall
// section; nothing is sent to the daemon until 'commit'.
var subCmdVPNFirewallAdd = subCmd{
	name:  "add",
	descr: "Add a new firewall rule to the staged configuration",
	do: func(ctx subCmdCtx) error {
		to := ctx.flags.String(
			"to",
			"",
			"Which set of rules to add to, either 'inbound' or 'outbound'",
		)

		var rule daecommon.ConfigFirewallRule

		ctx.flags.StringVar(
			&rule.Port,
			"port",
			"any",
			"Port number or range to allow",
		)

		ctx.flags.StringVar(
			&rule.Proto,
			"proto",
			"any",
			"Protocol to allow. Can be 'tcp', 'udp', 'icmp', or 'any'",
		)

		ctx.flags.StringVar(
			&rule.Host,
			"host",
			"",
			"Name of host to allow. Defaults to 'any' if --groups is not given",
		)

		ctx.flags.StringSliceVar(
			&rule.Groups,
			"groups",
			nil,
			"One or more comma-separated group names to allow",
		)

		ctx, err := ctx.withParsedFlags(nil)
		if err != nil {
			return fmt.Errorf("parsing flags: %w", err)
		}

		if *to == "" {
			return errors.New("--to is required")
		}

		ruleSetFn, err := vpnFirewallRuleSetToFn(*to)
		if err != nil {
			return fmt.Errorf("invalid --to value %q: %w", *to, err)
		}

		switch {
		case rule.Host != "" && len(rule.Groups) > 0:
			return errors.New("--host and --groups are mutually exclusive")
		case rule.Host == "" && len(rule.Groups) == 0:
			// With neither a host nor groups the rule would match nothing,
			// so fall back to the documented default of "any".
			rule.Host = "any"
		}

		daemonRPC, err := ctx.newDaemonRPC()
		if err != nil {
			return fmt.Errorf("creating daemon RPC client: %w", err)
		}
		defer daemonRPC.Close()

		// Build on top of any previously staged changes so multiple 'add'
		// and 'remove' invocations accumulate.
		config, err := vpnFirewallGetConfig(ctx, daemonRPC)
		if err != nil {
			return fmt.Errorf("getting network config: %w", err)
		}

		ruleSet := ruleSetFn(&config.VPN.Firewall)
		*ruleSet = append(*ruleSet, rule)

		if err := config.Validate(); err != nil {
			return err
		}

		if err := ctx.getChangeStager().set(
			config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
		); err != nil {
			return fmt.Errorf("staging changes: %w", err)
		}

		return nil
	},
}

// subCmdVPNFirewallCommit applies the staged firewall configuration to the
// daemon's network config. On success it discards the staged changes.
var subCmdVPNFirewallCommit = subCmd{
	name:  "commit",
	descr: "Commit all changes which were staged using 'add' or 'remove'",
	do: func(ctx subCmdCtx) error {
		ctx, err := ctx.withParsedFlags(nil)
		if err != nil {
			return fmt.Errorf("parsing flags: %w", err)
		}

		firewallConfig, ok, err := vpnFirewallGetStagedConfig(ctx)
		if err != nil {
			return fmt.Errorf("checking for staged changes: %w", err)
		} else if !ok {
			return errors.New("no changes staged, use 'add' or 'remove' to stage changes")
		}

		daemonRPC, err := ctx.newDaemonRPC()
		if err != nil {
			return fmt.Errorf("creating daemon RPC client: %w", err)
		}
		defer daemonRPC.Close()

		config, err := daemonRPC.GetConfig(ctx)
		if err != nil {
			return fmt.Errorf("getting network config: %w", err)
		}

		config.VPN.Firewall = firewallConfig
		if err := daemonRPC.SetConfig(ctx, config); err != nil {
			return err
		}

		// The staged changes are now live. Drop them so later 'add'/'remove'
		// calls start from the daemon's current config rather than this
		// snapshot, which could otherwise revert changes made after the commit.
		if err := ctx.getChangeStager().del(
			vpnFirewallConfigChangeStagerName,
		); err != nil {
			return fmt.Errorf("discarding committed staged changes: %w", err)
		}

		return nil
	},
}

// subCmdVPNFirewallRemove removes the rules at the given indexes from the
// inbound or outbound rule set and stages the result. An error is returned if
// any index is negative or beyond the end of the rule set.
var subCmdVPNFirewallRemove = subCmd{
	name:  "remove",
	descr: "Remove one or more firewall rules from the staged configuration",
	do: func(ctx subCmdCtx) error {
		from := ctx.flags.String(
			"from",
			"",
			"Which set of rules to remove from, either 'inbound' or 'outbound'",
		)

		indexes := ctx.flags.IntSlice(
			"indexes",
			nil,
			"Comma-separated indexes of rules to remove, as returned by 'show --staged'",
		)

		ctx, err := ctx.withParsedFlags(nil)
		if err != nil {
			return fmt.Errorf("parsing flags: %w", err)
		}

		if *from == "" || len(*indexes) == 0 {
			return errors.New("--from and --indexes are required")
		}

		ruleSetFn, err := vpnFirewallRuleSetToFn(*from)
		if err != nil {
			return fmt.Errorf("invalid --from value %q: %w", *from, err)
		}

		// Validate the indexes before talking to the daemon. The set also
		// de-duplicates repeated indexes.
		indexSet := make(map[int]struct{}, len(*indexes))
		for _, index := range *indexes {
			if index < 0 {
				return fmt.Errorf("invalid index %d", index)
			}
			indexSet[index] = struct{}{}
		}

		daemonRPC, err := ctx.newDaemonRPC()
		if err != nil {
			return fmt.Errorf("creating daemon RPC client: %w", err)
		}
		defer daemonRPC.Close()

		config, err := vpnFirewallGetConfig(ctx, daemonRPC)
		if err != nil {
			return fmt.Errorf("getting network config: %w", err)
		}

		var (
			ruleSet         = ruleSetFn(&config.VPN.Firewall)
			filteredRuleSet = make(
				[]daecommon.ConfigFirewallRule, 0, len(*ruleSet),
			)
		)

		for i, rule := range *ruleSet {
			if _, remove := indexSet[i]; remove {
				// Consume matched indexes; whatever remains afterwards
				// points past the end of the rule set.
				delete(indexSet, i)
				continue
			}
			filteredRuleSet = append(filteredRuleSet, rule)
		}

		if len(indexSet) > 0 {
			invalidIndexes := maps.Keys(indexSet)
			slices.Sort(invalidIndexes)
			return fmt.Errorf("invalid index(es): %v", invalidIndexes)
		}

		*ruleSet = filteredRuleSet

		if err := config.Validate(); err != nil {
			return err
		}

		if err := ctx.getChangeStager().set(
			config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
		); err != nil {
			return fmt.Errorf("staging changes: %w", err)
		}

		return nil
	},
}

// subCmdVPNFirewallReset discards any staged firewall changes without
// applying them.
var subCmdVPNFirewallReset = subCmd{
	name:  "reset",
	descr: "Discard all changes which have been staged",
	do: func(ctx subCmdCtx) error {
		// Parse flags like every sibling command does, so --help and
		// unknown flags are handled instead of silently ignored.
		ctx, err := ctx.withParsedFlags(nil)
		if err != nil {
			return fmt.Errorf("parsing flags: %w", err)
		}

		return ctx.getChangeStager().del(vpnFirewallConfigChangeStagerName)
	},
}

// firewallRuleView is a firewall rule annotated with its position in its rule
// set. The index is what 'remove --indexes' expects.
type firewallRuleView struct {
	Index                        int `yaml:"index"`
	daecommon.ConfigFirewallRule `yaml:",inline"`
}

// newFirewallRuleViews pairs each rule with its index in rules. The result is
// never nil, even when rules is empty.
func newFirewallRuleViews(
	rules []daecommon.ConfigFirewallRule,
) []firewallRuleView {
	views := make([]firewallRuleView, len(rules))
	for i := range rules {
		views[i] = firewallRuleView{
			Index:              i,
			ConfigFirewallRule: rules[i],
		}
	}
	return views
}

// firewallView is the output shape of 'show'. The user-configured rules are
// indexed; the internal rule sets come from
// NetworkConfig.InternalFirewallRules and are omitted when empty.
type firewallView struct {
	Outbound         []firewallRuleView             `yaml:"outbound"`
	Inbound          []firewallRuleView             `yaml:"inbound"`
	InternalOutbound []daecommon.ConfigFirewallRule `yaml:"internal_outbound,omitempty"`
	InternalInbound  []daecommon.ConfigFirewallRule `yaml:"internal_inbound,omitempty"`
}

// newFirewallView builds a firewallView from the firewall section of
// networkConfig and its internal firewall rules.
func newFirewallView(networkConfig daecommon.NetworkConfig) firewallView {
	var (
		firewallConfig                    = networkConfig.VPN.Firewall
		internalOutbound, internalInbound = networkConfig.InternalFirewallRules()
	)

	return firewallView{
		Outbound:         newFirewallRuleViews(firewallConfig.Outbound),
		Inbound:          newFirewallRuleViews(firewallConfig.Inbound),
		InternalOutbound: internalOutbound,
		InternalInbound:  internalInbound,
	}
}

// subCmdVPNFirewallShow outputs the daemon's current firewall rules. With
// --staged, any staged firewall changes are substituted in.
var subCmdVPNFirewallShow = subCmd{
	name:  "show",
	descr: "Shows the currently configured firewall rules",
	do: doWithOutput(func(ctx subCmdCtx) (any, error) {
		staged := ctx.flags.Bool(
			"staged",
			false,
			"Return the firewall configuration with staged changes included",
		)

		ctx, err := ctx.withParsedFlags(nil)
		if err != nil {
			return nil, fmt.Errorf("parsing flags: %w", err)
		}

		daemonRPC, err := ctx.newDaemonRPC()
		if err != nil {
			return nil, fmt.Errorf("creating daemon RPC client: %w", err)
		}
		defer daemonRPC.Close()

		config, err := daemonRPC.GetConfig(ctx)
		if err != nil {
			return nil, fmt.Errorf("getting network config: %w", err)
		}

		if *staged {
			firewallConfig, foundStaged, err := vpnFirewallGetStagedConfig(ctx)
			if err != nil {
				return nil, fmt.Errorf("checking for staged changes: %w", err)
			}
			if foundStaged {
				config.VPN.Firewall = firewallConfig
			}
		}

		return newFirewallView(config), nil
	}),
}

// subCmdVPNFirewall groups the VPN firewall sub-commands.
var subCmdVPNFirewall = subCmd{
	name:  "firewall",
	descr: "Sub-commands related to this host's VPN firewall",
	do: func(ctx subCmdCtx) error {
		return ctx.doSubCmd(
			subCmdVPNFirewallAdd,
			subCmdVPNFirewallCommit,
			subCmdVPNFirewallRemove,
			subCmdVPNFirewallReset,
			subCmdVPNFirewallShow,
		)
	},
}