366 lines
8.4 KiB
Go
366 lines
8.4 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"isle/daemon"
|
|
"isle/daemon/daecommon"
|
|
"slices"
|
|
"strings"
|
|
|
|
"golang.org/x/exp/maps"
|
|
)
|
|
|
|
// vpnFirewallConfigChangeStagerName is the key under which staged VPN
// firewall changes are stored in, read back from, and deleted from the change
// stager by the 'add', 'remove', 'commit', 'reset' and 'show' sub-commands.
const vpnFirewallConfigChangeStagerName = "vpn-firewall-config"
|
|
|
|
// vpnFirewallGetConfigWithStaged returns the network config along with any
|
|
// staged firewall configuration changes, if there are any.
|
|
func vpnFirewallGetConfig(
|
|
ctx subCmdCtx, daemonRPC daemon.RPC,
|
|
) (
|
|
daecommon.NetworkConfig, error,
|
|
) {
|
|
config, err := daemonRPC.GetConfig(ctx)
|
|
if err != nil {
|
|
return daecommon.NetworkConfig{}, err
|
|
}
|
|
|
|
var firewallConfig daecommon.ConfigFirewall
|
|
if ok, err := ctx.getChangeStager().get(
|
|
&firewallConfig, vpnFirewallConfigChangeStagerName,
|
|
); err != nil {
|
|
return daecommon.NetworkConfig{}, fmt.Errorf(
|
|
"getting staged VPN firewall config: %w", err,
|
|
)
|
|
} else if ok {
|
|
config.VPN.Firewall = firewallConfig
|
|
}
|
|
|
|
return config, nil
|
|
}
|
|
|
|
func vpnFirewallRuleSetToFn(
|
|
str string,
|
|
) (
|
|
func(*daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule,
|
|
error,
|
|
) {
|
|
switch strings.ToLower(str) {
|
|
case "outbound":
|
|
return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
|
|
return &c.Outbound
|
|
}, nil
|
|
case "inbound":
|
|
return func(c *daecommon.ConfigFirewall) *[]daecommon.ConfigFirewallRule {
|
|
return &c.Inbound
|
|
}, nil
|
|
default:
|
|
return nil, fmt.Errorf("must be 'inbound' or 'outbound'")
|
|
}
|
|
}
|
|
|
|
var subCmdVPNFirewallAdd = subCmd{
|
|
name: "add",
|
|
descr: "Add a new firewall rule to the staged configuration",
|
|
do: func(ctx subCmdCtx) error {
|
|
to := ctx.flags.String(
|
|
"to",
|
|
"",
|
|
"Which set of rules to add to, either 'inbound' or 'outbound'",
|
|
)
|
|
|
|
var rule daecommon.ConfigFirewallRule
|
|
ctx.flags.StringVar(
|
|
&rule.Port, "port", "any", "Port number or range to allow",
|
|
)
|
|
|
|
ctx.flags.StringVar(
|
|
&rule.Proto,
|
|
"proto",
|
|
"any",
|
|
"Protocol to allow. Can be 'tcp', 'udp', 'icmp', or 'any'",
|
|
)
|
|
|
|
ctx.flags.StringVar(
|
|
&rule.Host,
|
|
"host",
|
|
"",
|
|
"Name of host to allow. Defaults to 'any' if --groups is not given",
|
|
)
|
|
|
|
ctx.flags.StringSliceVar(
|
|
&rule.Groups,
|
|
"groups",
|
|
nil,
|
|
"One or more comma-separated group names to allow",
|
|
)
|
|
|
|
ctx, err := ctx.withParsedFlags(nil)
|
|
if err != nil {
|
|
return fmt.Errorf("parsing flags: %w", err)
|
|
}
|
|
|
|
if *to == "" {
|
|
return errors.New("--to is required")
|
|
}
|
|
|
|
ruleSetFn, err := vpnFirewallRuleSetToFn(*to)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid --to value %q: %w", *to, err)
|
|
}
|
|
|
|
if rule.Host != "" && len(rule.Groups) > 0 {
|
|
return fmt.Errorf("--host and --groups are mutually exclusive")
|
|
} else if rule.Host == "" && len(rule.Groups) == 0 {
|
|
rule.Host = "any"
|
|
}
|
|
|
|
daemonRPC, err := ctx.newDaemonRPC()
|
|
if err != nil {
|
|
return fmt.Errorf("creating daemon RPC client: %w", err)
|
|
}
|
|
defer daemonRPC.Close()
|
|
|
|
config, err := vpnFirewallGetConfig(ctx, daemonRPC)
|
|
if err != nil {
|
|
return fmt.Errorf("getting network config: %w", err)
|
|
}
|
|
|
|
ruleSet := ruleSetFn(&config.VPN.Firewall)
|
|
*ruleSet = append(*ruleSet, rule)
|
|
|
|
if err := config.Validate(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ctx.getChangeStager().set(
|
|
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
|
); err != nil {
|
|
return fmt.Errorf("staging changes: %w", err)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
|
|
var subCmdVPNFirewallCommit = subCmd{
|
|
name: "commit",
|
|
descr: "Commit all changes which were staged using 'add' or 'remove'",
|
|
do: func(ctx subCmdCtx) error {
|
|
ctx, err := ctx.withParsedFlags(nil)
|
|
if err != nil {
|
|
return fmt.Errorf("parsing flags: %w", err)
|
|
}
|
|
|
|
var firewallConfig daecommon.ConfigFirewall
|
|
ok, err := ctx.getChangeStager().get(
|
|
&firewallConfig, vpnFirewallConfigChangeStagerName,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("checking for staged changes: %w", err)
|
|
} else if !ok {
|
|
return errors.New("no changes staged, use 'add' or 'remove' to stage changes")
|
|
}
|
|
|
|
daemonRPC, err := ctx.newDaemonRPC()
|
|
if err != nil {
|
|
return fmt.Errorf("creating daemon RPC client: %w", err)
|
|
}
|
|
defer daemonRPC.Close()
|
|
|
|
config, err := daemonRPC.GetConfig(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("getting network config: %w", err)
|
|
}
|
|
|
|
config.VPN.Firewall = firewallConfig
|
|
|
|
return daemonRPC.SetConfig(ctx, config)
|
|
},
|
|
}
|
|
|
|
var subCmdVPNFirewallRemove = subCmd{
|
|
name: "remove",
|
|
descr: "Remove one or more firewall rules from the staged configuration",
|
|
do: func(ctx subCmdCtx) error {
|
|
from := ctx.flags.String(
|
|
"from",
|
|
"",
|
|
"Which set of rules to remove from, either 'inbound' or 'outbound'",
|
|
)
|
|
|
|
indexes := ctx.flags.IntSlice(
|
|
"indexes",
|
|
nil,
|
|
"Comma-separated indexes of rules to remove, as returned by 'show --staged'",
|
|
)
|
|
|
|
ctx, err := ctx.withParsedFlags(nil)
|
|
if err != nil {
|
|
return fmt.Errorf("parsing flags: %w", err)
|
|
}
|
|
|
|
if *from == "" || len(*indexes) == 0 {
|
|
return errors.New("--from and --indexes are required")
|
|
}
|
|
|
|
ruleSetFn, err := vpnFirewallRuleSetToFn(*from)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid --from value %q: %w", *from, err)
|
|
}
|
|
|
|
daemonRPC, err := ctx.newDaemonRPC()
|
|
if err != nil {
|
|
return fmt.Errorf("creating daemon RPC client: %w", err)
|
|
}
|
|
defer daemonRPC.Close()
|
|
|
|
indexSet := map[int]struct{}{}
|
|
for _, index := range *indexes {
|
|
if index < 0 {
|
|
return fmt.Errorf("invalid index %d", index)
|
|
}
|
|
indexSet[index] = struct{}{}
|
|
}
|
|
|
|
config, err := vpnFirewallGetConfig(ctx, daemonRPC)
|
|
if err != nil {
|
|
return fmt.Errorf("getting network config: %w", err)
|
|
}
|
|
|
|
var (
|
|
ruleSet = ruleSetFn(&config.VPN.Firewall)
|
|
filteredRuleSet = make(
|
|
[]daecommon.ConfigFirewallRule, 0, len(*ruleSet),
|
|
)
|
|
)
|
|
|
|
for i, rule := range *ruleSet {
|
|
if _, remove := indexSet[i]; remove {
|
|
delete(indexSet, i)
|
|
continue
|
|
}
|
|
filteredRuleSet = append(filteredRuleSet, rule)
|
|
}
|
|
|
|
if len(indexSet) > 0 {
|
|
invalidIndexes := maps.Keys(indexSet)
|
|
slices.Sort(invalidIndexes)
|
|
return fmt.Errorf("invalid index(es): %v", invalidIndexes)
|
|
}
|
|
|
|
*ruleSet = filteredRuleSet
|
|
|
|
if err := config.Validate(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := ctx.getChangeStager().set(
|
|
config.VPN.Firewall, vpnFirewallConfigChangeStagerName,
|
|
); err != nil {
|
|
return fmt.Errorf("staging changes: %w", err)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
|
|
var subCmdVPNFirewallReset = subCmd{
|
|
name: "reset",
|
|
descr: "Discard all changes which have been staged",
|
|
do: func(ctx subCmdCtx) error {
|
|
return ctx.getChangeStager().del(vpnFirewallConfigChangeStagerName)
|
|
},
|
|
}
|
|
|
|
// firewallRuleView is the display form of a single firewall rule. It combines
// the rule's own fields, inlined into the YAML output, with the rule's
// position within its rule set. That position is the index accepted by
// 'remove --indexes'.
type firewallRuleView struct {
	Index                        int `yaml:"index"`
	daecommon.ConfigFirewallRule `yaml:",inline"`
}
|
|
|
|
func newFirewallRuleViews(
|
|
rules []daecommon.ConfigFirewallRule,
|
|
) []firewallRuleView {
|
|
views := make([]firewallRuleView, len(rules))
|
|
for i := range rules {
|
|
views[i] = firewallRuleView{
|
|
Index: i,
|
|
ConfigFirewallRule: rules[i],
|
|
}
|
|
}
|
|
return views
|
|
}
|
|
|
|
// firewallView is the value returned (and rendered as YAML) by the 'show'
// sub-command: both rule sets, with every rule annotated with its index. Field
// order here determines the order of the keys in the output.
type firewallView struct {
	Outbound []firewallRuleView `yaml:"outbound"`
	Inbound  []firewallRuleView `yaml:"inbound"`
}
|
|
|
|
func newFirewallView(firewallConfig daecommon.ConfigFirewall) firewallView {
|
|
return firewallView{
|
|
Outbound: newFirewallRuleViews(firewallConfig.Outbound),
|
|
Inbound: newFirewallRuleViews(firewallConfig.Inbound),
|
|
}
|
|
}
|
|
|
|
var subCmdVPNFirewallShow = subCmd{
|
|
name: "show",
|
|
descr: "Shows the currently configured firewall rules",
|
|
do: doWithOutput(func(ctx subCmdCtx) (any, error) {
|
|
staged := ctx.flags.Bool(
|
|
"staged",
|
|
false,
|
|
"Return the firewall configuration with staged changes included",
|
|
)
|
|
|
|
ctx, err := ctx.withParsedFlags(nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing flags: %w", err)
|
|
}
|
|
|
|
var (
|
|
firewallConfig daecommon.ConfigFirewall
|
|
foundStaged bool
|
|
)
|
|
if *staged {
|
|
var err error
|
|
if foundStaged, err = ctx.getChangeStager().get(
|
|
&firewallConfig, vpnFirewallConfigChangeStagerName,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("checking for staged changes: %w", err)
|
|
}
|
|
}
|
|
|
|
if !foundStaged {
|
|
daemonRPC, err := ctx.newDaemonRPC()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating daemon RPC client: %w", err)
|
|
}
|
|
defer daemonRPC.Close()
|
|
|
|
config, err := daemonRPC.GetConfig(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("getting network config: %w", err)
|
|
}
|
|
|
|
firewallConfig = config.VPN.Firewall
|
|
}
|
|
|
|
return newFirewallView(firewallConfig), nil
|
|
}),
|
|
}
|
|
|
|
// subCmdVPNFirewall groups the sub-commands for managing this host's VPN
// firewall:
//   - 'add' and 'remove' edit rules in a staging area.
//   - 'reset' discards the staged changes.
//   - 'show' inspects the rules, live or staged.
//   - 'commit' pushes the staged firewall config to the daemon.
var subCmdVPNFirewall = subCmd{
	name:  "firewall",
	descr: "Sub-commands related to this host's VPN firewall",
	do: func(ctx subCmdCtx) error {
		return ctx.doSubCmd(
			subCmdVPNFirewallAdd,
			subCmdVPNFirewallCommit,
			subCmdVPNFirewallRemove,
			subCmdVPNFirewallReset,
			subCmdVPNFirewallShow,
		)
	},
}
|