// isle/go/cmd/entrypoint/vpn_firewall.go (Go, 366 lines, 8.4 KiB)

package main
import (
	"errors"
	"fmt"
	"isle/daemon"
	"isle/daemon/daecommon"
	"slices"
	"strings"

	"golang.org/x/exp/maps"
)
// vpnFirewallConfigChangeStagerName is the name under which the pending
// (staged but not yet committed) VPN firewall configuration is stored in, read
// from, and deleted from the change stager.
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
// vpnFirewallGetConfigWithStaged returns the network config along with any
// staged firewall configuration changes, if there are any.
func vpnFirewallGetConfig(
ctx subCmdCtx, daemonRPC daemon.RPC,
) (
daecommon.NetworkConfig, error,
) {
config, err := daemonRPC.GetConfig(ctx)
2024-12-10 12:52:57 +00:00
if err != nil {
return daecommon.NetworkConfig{}, err
}
var firewallConfig daecommon.ConfigFirewall
2024-12-13 09:56:43 +00:00
if ok, err := ctx.getChangeStager().get(
2024-12-10 12:52:57 +00:00
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return daecommon.NetworkConfig{}, fmt.Errorf(
"getting staged VPN firewall config: %w", err,
)
} else if ok {
config.VPN.Firewall = firewallConfig
}
return config, nil
}
// vpnFirewallRuleSetToFn maps a user-supplied rule set name ("inbound" or
// "outbound", matched case-insensitively) to an accessor. The accessor returns
// a pointer to the corresponding rule slice inside a ConfigFirewall, so that
// callers can modify that slice in place.
//
// Any other name yields an error; callers are expected to add the offending
// value when reporting it.
func vpnFirewallRuleSetToFn(
	str string,
) (
	func(*daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule,
	error,
) {
	switch strings.ToLower(str) {
	case "outbound":
		return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
			return &c.Outbound
		}, nil
	case "inbound":
		return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
			return &c.Inbound
		}, nil
	default:
		return nil, errors.New("must be 'inbound' or 'outbound'")
	}
}
var subCmdVPNFirewallAdd = subCmd{
name: "add",
descr: "Add a new firewall rule to the staged configuration",
do: func(ctx subCmdCtx) error {
to := ctx.flags.String(
"to",
"",
"Which set of rules to add to, either 'inbound' or 'outbound'",
)
var rule daecommon.ConfigFirewallRule
ctx.flags.StringVar(
&rule.Port, "port", "any", "Port number or range to allow",
)
ctx.flags.StringVar(
&rule.Proto,
"proto",
"any",
"Protocol to allow. Can be 'tcp', 'udp', 'icmp', or 'any'",
)
ctx.flags.StringVar(
&rule.Host,
"host",
"",
"Name of host to allow. Defaults to 'any' if --groups is not given",
)
ctx.flags.StringSliceVar(
&rule.Groups,
"groups",
nil,
"One or more comma-separated group names to allow",
)
ctx, err := ctx.withParsedFlags(nil)
2024-12-10 12:52:57 +00:00
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
if *to == "" {
return errors.New("--to is required")
}
ruleSetFn, err := vpnFirewallRuleSetToFn(*to)
if err != nil {
return fmt.Errorf("invalid --to value %q: %w", *to, err)
}
if rule.Host != "" && len(rule.Groups) > 0 {
return fmt.Errorf("--host and --groups are mutually exclusive")
} else if rule.Host == "" && len(rule.Groups) == 0 {
rule.Host = "any"
}
daemonRPC, err := ctx.newDaemonRPC()
if err != nil {
return fmt.Errorf("creating daemon RPC client: %w", err)
}
defer daemonRPC.Close()
config, err := vpnFirewallGetConfig(ctx, daemonRPC)
2024-12-10 12:52:57 +00:00
if err != nil {
return fmt.Errorf("getting network config: %w", err)
}
ruleSet := ruleSetFn(&config.VPN.Firewall)
*ruleSet = append(*ruleSet, rule)
if err := config.Validate(); err != nil {
return err
}
2024-12-13 09:56:43 +00:00
if err := ctx.getChangeStager().set(
2024-12-10 12:52:57 +00:00
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
); err != nil {
return fmt.Errorf("staging changes: %w", err)
}
return nil
},
}
var subCmdVPNFirewallCommit = subCmd{
name: "commit",
descr: "Commit all changes which were staged using 'add' or 'remove'",
do: func(ctx subCmdCtx) error {
ctx, err := ctx.withParsedFlags(nil)
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
var firewallConfig daecommon.ConfigFirewall
2024-12-13 09:56:43 +00:00
ok, err := ctx.getChangeStager().get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
)
if err != nil {
return fmt.Errorf("checking for staged changes: %w", err)
} else if !ok {
return errors.New("no changes staged, use 'add' or 'remove' to stage changes")
}
daemonRPC, err := ctx.newDaemonRPC()
if err != nil {
return fmt.Errorf("creating daemon RPC client: %w", err)
}
defer daemonRPC.Close()
config, err := daemonRPC.GetConfig(ctx)
if err != nil {
return fmt.Errorf("getting network config: %w", err)
}
config.VPN.Firewall = firewallConfig
return daemonRPC.SetConfig(ctx, config)
},
}
2024-12-10 14:17:07 +00:00
var subCmdVPNFirewallRemove = subCmd{
name: "remove",
descr: "Remove one or more firewall rules from the staged configuration",
do: func(ctx subCmdCtx) error {
from := ctx.flags.String(
"from",
"",
"Which set of rules to remove from, either 'inbound' or 'outbound'",
)
indexes := ctx.flags.IntSlice(
"indexes",
nil,
"Comma-separated indexes of rules to remove, as returned by 'show --staged'",
2024-12-10 14:17:07 +00:00
)
ctx, err := ctx.withParsedFlags(nil)
2024-12-10 14:17:07 +00:00
if err != nil {
return fmt.Errorf("parsing flags: %w", err)
}
if *from == "" || len(*indexes) == 0 {
return errors.New("--from and --indexes are required")
}
ruleSetFn, err := vpnFirewallRuleSetToFn(*from)
if err != nil {
return fmt.Errorf("invalid --from value %q: %w", *from, err)
}
daemonRPC, err := ctx.newDaemonRPC()
if err != nil {
return fmt.Errorf("creating daemon RPC client: %w", err)
}
defer daemonRPC.Close()
2024-12-10 14:17:07 +00:00
indexSet := map[int]struct{}{}
for _, index := range *indexes {
if index < 0 {
return fmt.Errorf("invalid index %d", index)
}
indexSet[index] = struct{}{}
}
config, err := vpnFirewallGetConfig(ctx, daemonRPC)
2024-12-10 14:17:07 +00:00
if err != nil {
return fmt.Errorf("getting network config: %w", err)
}
var (
ruleSet = ruleSetFn(&config.VPN.Firewall)
filteredRuleSet = make(
[]daecommon.ConfigFirewallRule, 0, len(*ruleSet),
)
)
for i, rule := range *ruleSet {
if _, remove := indexSet[i]; remove {
delete(indexSet, i)
continue
}
filteredRuleSet = append(filteredRuleSet, rule)
}
if len(indexSet) > 0 {
invalidIndexes := maps.Keys(indexSet)
slices.Sort(invalidIndexes)
return fmt.Errorf("invalid index(es): %v", invalidIndexes)
}
*ruleSet = filteredRuleSet
if err := config.Validate(); err != nil {
return err
}
2024-12-13 09:56:43 +00:00
if err := ctx.getChangeStager().set(
2024-12-10 14:17:07 +00:00
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
); err != nil {
return fmt.Errorf("staging changes: %w", err)
}
return nil
},
}
var subCmdVPNFirewallReset = subCmd{
name: "reset",
descr: "Discard all changes which have been staged",
do: func(ctx subCmdCtx) error {
2024-12-13 09:56:43 +00:00
return ctx.getChangeStager().del(vpnFirewallConfigChangeStagerName)
},
}
// firewallRuleView is the output form of a single firewall rule, as produced
// for the 'show' sub-command. It pairs the rule with its position in its rule
// set; that position is the value 'remove --indexes' expects. The rule's own
// fields are inlined next to "index" when serialized to YAML.
type firewallRuleView struct {
Index int `yaml:"index"`
daecommon.ConfigFirewallRule `yaml:",inline"`
}
// newFirewallRuleViews wraps every rule in a firewallRuleView carrying the
// rule's position within rules. The result has one entry per rule, in the same
// order. It is never nil, even when rules is empty.
func newFirewallRuleViews(
	rules []daecommon.ConfigFirewallRule,
) []firewallRuleView {
	out := make([]firewallRuleView, 0, len(rules))
	for idx, r := range rules {
		out = append(out, firewallRuleView{Index: idx, ConfigFirewallRule: r})
	}
	return out
}
// firewallView is the value returned by the 'show' sub-command: the outbound
// and inbound rule sets, with each rule annotated with its index.
// NOTE(review): presumably rendered as YAML by doWithOutput, given the tags —
// confirm.
type firewallView struct {
Outbound []firewallRuleView `yaml:"outbound"`
Inbound []firewallRuleView `yaml:"inbound"`
}
// newFirewallView converts a firewall configuration into its display form,
// annotating every outbound and inbound rule with its index.
func newFirewallView(firewallConfig daecommon.ConfigFirewall) firewallView {
	var view firewallView
	view.Outbound = newFirewallRuleViews(firewallConfig.Outbound)
	view.Inbound = newFirewallRuleViews(firewallConfig.Inbound)
	return view
}
var subCmdVPNFirewallShow = subCmd{
name: "show",
descr: "Shows the currently configured firewall rules",
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
staged := ctx.flags.Bool(
"staged",
false,
"Return the firewall configuration with staged changes included",
)
ctx, err := ctx.withParsedFlags(nil)
if err != nil {
return nil, fmt.Errorf("parsing flags: %w", err)
}
var (
firewallConfig daecommon.ConfigFirewall
foundStaged bool
)
if *staged {
var err error
2024-12-13 09:56:43 +00:00
if foundStaged, err = ctx.getChangeStager().get(
&firewallConfig, vpnFirewallConfigChangeStagerName,
); err != nil {
return nil, fmt.Errorf("checking for staged changes: %w", err)
}
}
if !foundStaged {
daemonRPC, err := ctx.newDaemonRPC()
if err != nil {
return nil, fmt.Errorf("creating daemon RPC client: %w", err)
}
defer daemonRPC.Close()
config, err := daemonRPC.GetConfig(ctx)
if err != nil {
return nil, fmt.Errorf("getting network config: %w", err)
}
firewallConfig = config.VPN.Firewall
}
return newFirewallView(firewallConfig), nil
}),
}
var subCmdVPNFirewall = subCmd{
name: "firewall",
descr: "Sub-commands related to this host's VPN firewall",
do: func(ctx subCmdCtx) error {
return ctx.doSubCmd(
2024-12-10 12:52:57 +00:00
subCmdVPNFirewallAdd,
subCmdVPNFirewallCommit,
2024-12-10 14:17:07 +00:00
subCmdVPNFirewallRemove,
subCmdVPNFirewallReset,
subCmdVPNFirewallShow,
)
},
}