diff --git a/cmd/wice/daemon.go b/cmd/wice/daemon.go index c2c8f787..32da211e 100644 --- a/cmd/wice/daemon.go +++ b/cmd/wice/daemon.go @@ -73,7 +73,7 @@ func daemon(cmd *cobra.Command, args []string) { } // Create control socket server to manage daemon - _, err = socket.Listen("unix", cfg.GetString("socket"), cfg.GetBool("socket-wait"), daemon) + _, err = socket.Listen("unix", cfg.GetString("socket.path"), cfg.GetBool("socket.wait"), daemon) if err != nil { logger.Fatal("Failed to initialize control socket", zap.Error(err)) } diff --git a/internal/config/agent.go b/internal/config/agent.go new file mode 100644 index 00000000..402c9bec --- /dev/null +++ b/internal/config/agent.go @@ -0,0 +1,125 @@ +package config + +import ( + "fmt" + "io" + "regexp" + + "github.com/pion/ice/v2" + "gopkg.in/yaml.v3" + icex "riasc.eu/wice/internal/ice" +) + +func (c *Config) AgentConfig() (*ice.AgentConfig, error) { + cfg := &ice.AgentConfig{ + InsecureSkipVerify: c.GetBool("ice.insecure_skip_verify"), + Lite: c.GetBool("ice.lite"), + PortMin: uint16(c.GetUint("ice.port.min")), + PortMax: uint16(c.GetUint("ice.port.max")), + } + + interfaceFilterRegex, err := regexp.Compile(c.GetString("ice.interface_filter")) + if err != nil { + return nil, fmt.Errorf("invalid ice.interface_filter config: %w", err) + } + + cfg.InterfaceFilter = func(name string) bool { + return interfaceFilterRegex.Match([]byte(name)) + } + + // ICE URLS + cfg.Urls = []*ice.URL{} + for _, u := range c.GetStringSlice("ice.urls") { + up, err := ice.ParseURL(u) + if err != nil { + return nil, fmt.Errorf("failed to parse ice.url: %s: %w", u, err) + } + + cfg.Urls = append(cfg.Urls, up) + } + + // Add default STUN/TURN servers + // Set ICE credentials + u := c.GetString("ice.username") + p := c.GetString("ice.password") + for _, q := range cfg.Urls { + if u != "" { + q.Username = u + } + + if p != "" { + q.Password = p + } + } + + if c.IsSet("ice.nat_1to1_ips") { + cfg.NAT1To1IPs = c.GetStringSlice("ice.nat_1to1_ips") + } + + if c.IsSet("ice.max_binding_requests") { + i := uint16(c.GetInt("ice.max_binding_requests")) + cfg.MaxBindingRequests = &i + } + + if c.GetBool("ice.mdns") { + cfg.MulticastDNSMode = ice.MulticastDNSModeQueryAndGather + } + + if c.IsSet("ice.disconnected_timeout") { + to := c.GetDuration("ice.disconnected_timeout") + cfg.DisconnectedTimeout = &to + } + + if c.IsSet("ice.failed_timeout") { + to := c.GetDuration("ice.failed_timeout") + cfg.FailedTimeout = &to + } + + if c.IsSet("ice.keepalive_interval") { + to := c.GetDuration("ice.keepalive_interval") + cfg.KeepaliveInterval = &to + } + + if c.IsSet("ice.check_interval") { + to := c.GetDuration("ice.check_interval") + cfg.CheckInterval = &to + } + + // Filter candidate types + candidateTypes := []ice.CandidateType{} + for _, value := range c.GetStringSlice("ice.candidate_types") { + ct, err := icex.CandidateTypeFromString(value) + if err != nil { + return nil, err + } + + candidateTypes = append(candidateTypes, ct) + } + + if len(candidateTypes) > 0 { + cfg.CandidateTypes = candidateTypes + } + + // Filter network types + networkTypes := []ice.NetworkType{} + for _, value := range c.GetStringSlice("ice.network_types") { + ct, err := icex.NetworkTypeFromString(value) + if err != nil { + return nil, err + } + + networkTypes = append(networkTypes, ct) + } + + if len(networkTypes) > 0 { + cfg.NetworkTypes = networkTypes + } + + return cfg, nil +} + +func (c *Config) Dump(wr io.Writer) error { + enc := yaml.NewEncoder(wr) + enc.SetIndent(2) + return enc.Encode(c.AllSettings()) +} diff --git a/internal/config/agent_test.go b/internal/config/agent_test.go new file mode 100644 index 00000000..604f6142 --- /dev/null +++ b/internal/config/agent_test.go @@ -0,0 +1,124 @@ +package config_test + +import ( + "testing" + + "github.com/pion/ice/v2" + "riasc.eu/wice/internal/config" +) + +func TestParseArgsCandidateTypes(t *testing.T) { + config, err := config.Parse("--ice-candidate-type", "host", "--ice-candidate-type", "relay") + if err != nil { + t.Errorf("err got %v, want nil", err) + } + + agentConfig, err := config.AgentConfig() + if err != nil { + t.Errorf("Failed to get agent config: %s", err) + } + + if len(agentConfig.CandidateTypes) != 2 { + t.Fail() + } + + if agentConfig.CandidateTypes[0] != ice.CandidateTypeHost { + t.Fail() + } + + if agentConfig.CandidateTypes[1] != ice.CandidateTypeRelay { + t.Fail() + } +} + +func TestParseArgsInterfaceFilter(t *testing.T) { + config, err := config.Parse("--ice-interface-filter", "eth\\d+") + if err != nil { + t.Errorf("err got %v, want nil", err) + } + + agentConfig, err := config.AgentConfig() + if err != nil { + t.Errorf("Failed to get agent config: %s", err) + } + + if !agentConfig.InterfaceFilter("eth0") { + t.Fail() + } + + if agentConfig.InterfaceFilter("wifi0") { + t.Fail() + } +} + +func TestParseArgsInterfaceFilterFail(t *testing.T) { + config, err := config.Parse("--ice-interface-filter", "eth(") + if err != nil { + t.Fail() + } + + _, err = config.AgentConfig() + if err == nil { + t.Fail() + } +} + +func TestParseArgsDefault(t *testing.T) { + cfg, err := config.Parse() + if err != nil { + t.Fail() + } + + agentConfig, err := cfg.AgentConfig() + if err != nil { + t.FailNow() + } + + if len(agentConfig.Urls) != 1 { + t.FailNow() + } + + if agentConfig.Urls[0].String() != config.DefaultURL { + t.FailNow() + } + + if len(cfg.Backends) != 1 { + t.FailNow() + } + + if cfg.Backends[0].String() != config.DefaultBackend+":" { + t.FailNow() + } +} + +func TestParseArgsUrls(t *testing.T) { + config, err := config.Parse("--url", "stun:stun.riasc.eu", "--url", "turn:turn.riasc.eu") + if err != nil { + t.Errorf("err got %v, want nil", err) + } + + agentConfig, err := config.AgentConfig() + if err != nil { + t.FailNow() + } + + if len(agentConfig.Urls) != 2 { + t.Fail() + } + + if agentConfig.Urls[0].Host != "stun.riasc.eu" { + t.Fail() + } + + if agentConfig.Urls[0].Scheme != ice.SchemeTypeSTUN { + t.Fail() + } + + if agentConfig.Urls[1].Host != "turn.riasc.eu" { + t.Fail() + } + + if agentConfig.Urls[1].Scheme != ice.SchemeTypeTURN { + t.Fail() + } +} diff --git a/internal/config/config.go b/internal/config/config.go index d214a860..33b7df78 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,7 +3,6 @@ package config import ( "errors" "fmt" - "io" "net" "net/http" "net/url" @@ -13,39 +12,40 @@ import ( "time" "go.uber.org/zap" - "gopkg.in/yaml.v3" - icex "riasc.eu/wice/internal/ice" - "github.com/pion/ice/v2" "github.com/spf13/pflag" "github.com/spf13/viper" ) // Copied from pion/ice/agent_config.go const ( - // defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase - defaultCheckInterval = 200 * time.Millisecond + // DefaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase + DefaultCheckInterval = 200 * time.Millisecond // keepaliveInterval used to keep candidates alive - defaultKeepaliveInterval = 2 * time.Second + DefaultKeepaliveInterval = 2 * time.Second - // defaultDisconnectedTimeout is the default time till an Agent transitions disconnected - defaultDisconnectedTimeout = 5 * time.Second + // DefaultDisconnectedTimeout is the default time till an Agent transitions disconnected + DefaultDisconnectedTimeout = 5 * time.Second // defaultRestartInterval is the default time an Agent waits before it attempts an ICE restart - defaultRestartTimeout = 5 * time.Second + DefaultRestartTimeout = 5 * time.Second - // defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected - defaultFailedTimeout = 5 * time.Second + // DefaultFailedTimeout is the default time till an Agent transitions to failed after disconnected + DefaultFailedTimeout = 5 * time.Second // max binding request before considering a pair failed - defaultMaxBindingRequests = 7 + DefaultMaxBindingRequests = 7 - defaultWatchInterval = time.Second + DefaultWatchInterval = time.Second DefaultSocketPath = "/var/run/wice.sock" - defaultWireguardConfigPath = "/etc/wireguard" + DefaultWireguardConfigPath = "/etc/wireguard" + + DefaultURL = "stun:l.google.com:19302" + + DefaultBackend = "p2p" ) type Config struct { @@ -57,6 +57,8 @@ type Config struct { WireguardInterfaces []string ConfigFiles []string + WireguardInterfaceFilter *regexp.Regexp + Backends []*url.URL } @@ -83,22 +85,22 @@ func NewConfig(flags *pflag.FlagSet) *Config { flags: flags, } - c.SetDefault("ice.urls", []string{"stun:l.google.com:19302"}) - c.SetDefault("backends", []string{"p2p"}) - c.SetDefault("watch_interval", defaultWatchInterval) + c.SetDefault("ice.urls", []string{DefaultURL}) + c.SetDefault("backends", []string{DefaultBackend}) + c.SetDefault("watch_interval", DefaultWatchInterval) c.SetDefault("socket.path", DefaultSocketPath) - c.SetDefault("ice.max_binding_requests", defaultMaxBindingRequests) + c.SetDefault("ice.max_binding_requests", DefaultMaxBindingRequests) - c.SetDefault("ice.check_interval", defaultCheckInterval) - c.SetDefault("ice.disconnected_timout", defaultDisconnectedTimeout) - c.SetDefault("ice.failed_timeout", defaultFailedTimeout) - c.SetDefault("ice.restart_timeout", defaultRestartTimeout) - c.SetDefault("ice.keepalive_interval", defaultKeepaliveInterval) + c.SetDefault("ice.check_interval", DefaultCheckInterval) + c.SetDefault("ice.disconnected_timeout", DefaultDisconnectedTimeout) + c.SetDefault("ice.failed_timeout", DefaultFailedTimeout) + c.SetDefault("ice.restart_timeout", DefaultRestartTimeout) + c.SetDefault("ice.keepalive_interval", DefaultKeepaliveInterval) c.SetDefault("ice.nat_1to1_ips", []net.IP{}) - c.SetDefault("wg.config.path", defaultWireguardConfigPath) + c.SetDefault("wg.config.path", DefaultWireguardConfigPath) flags.StringP("config-domain", "A", "", "Perform auto-configuration via DNS") flags.StringSliceVarP(&c.ConfigFiles, "config", "c", []string{}, "Path of configuration files") @@ -156,7 +158,7 @@ func NewConfig(flags *pflag.FlagSet) *Config { "ice-max-binding-requests": "ice.max_binding_requests", "ice-insecure-skip-verify": "ice.insecure_skip_verify", "ice-interface-filter": "ice.interface_filter", - "ice-disconnected-timout": "ice.disconnected_timout", + "ice-disconnected-timout": "ice.disconnected_timeout", "ice-failed-timeout": "ice.failed_timeout", "ice-keepalive-interval": "ice.keepalive_interval", "ice-check-interval": "ice.check_interval", @@ -178,6 +180,8 @@ func NewConfig(flags *pflag.FlagSet) *Config { } func (c *Config) Setup(args []string) error { + // We cant to this in NewConfig since its called by init() + // at which time the logging system is not initialized yet. c.logger = zap.L().Named("config") c.WireguardInterfaces = args @@ -252,6 +256,12 @@ func (c *Config) MergeRemoteConfig(url *url.URL) error { } func (c *Config) Load() error { + var err error + + c.WireguardInterfaceFilter, err = regexp.Compile(c.GetString("wg.interface_filter")) + if err != nil { + return fmt.Errorf("invalid regular expression for setting 'wg.interface_filter': %w", err) + } // Backends c.Backends = []*url.URL{} @@ -272,115 +282,3 @@ func (c *Config) Load() error { return nil } - -func (c *Config) AgentConfig() (*ice.AgentConfig, error) { - cfg := &ice.AgentConfig{ - InsecureSkipVerify: c.GetBool("ice.insecure_skip_verify"), - Lite: c.GetBool("ice.lite"), - PortMin: uint16(c.GetUint("ice.port.min")), - PortMax: uint16(c.GetUint("ice.port.max")), - } - - interfaceFilterRegex, err := regexp.Compile(c.GetString("ice.interface_filter")) - if err != nil { - return nil, fmt.Errorf("invalid ice.interface_filter config: %w", err) - } - - cfg.InterfaceFilter = func(name string) bool { - return interfaceFilterRegex.Match([]byte(name)) - } - - // ICE URLS - cfg.Urls = []*ice.URL{} - for _, u := range c.GetStringSlice("ice.urls") { - up, err := ice.ParseURL(u) - if err != nil { - return nil, fmt.Errorf("failed to parse ice.url: %s: %w", u, err) - } - - cfg.Urls = append(cfg.Urls, up) - } - - // Add default STUN/TURN servers - // Set ICE credentials - u := c.GetString("ice.username") - p := c.GetString("ice.password") - for _, q := range cfg.Urls { - if u != "" { - q.Username = u - } - - if p != "" { - q.Password = p - } - } - - if c.IsSet("ice.nat_1to1_ips") { - cfg.NAT1To1IPs = c.GetStringSlice("ice.nat_1to1_ips") - } - - if c.IsSet("ice.max_binding_requests") { - i := uint16(c.GetInt("ice.max_binding_requests")) - cfg.MaxBindingRequests = &i - } - - if c.GetBool("ice.mdns") { - cfg.MulticastDNSMode = ice.MulticastDNSModeQueryAndGather - } - - if c.IsSet("ice.disconnected_timeout") { - to := c.GetDuration("ice.disconnected_timeout") - cfg.DisconnectedTimeout = &to - } - - if c.IsSet("ice.failed_timeout") { - to := c.GetDuration("ice.failed_timeout") - cfg.FailedTimeout = &to - } - - if c.IsSet("ice.keepalive_interval") { - to := c.GetDuration("ice.keepalive_interval") - cfg.KeepaliveInterval = &to - } - - if c.IsSet("ice.check_interval") { - to := c.GetDuration("ice.check_interval") - cfg.CheckInterval = &to - } - - // Filter candidate types - candidateTypes := []ice.CandidateType{} - for _, value := range c.GetStringSlice("ice.candidate_types") { - ct, err := icex.CandidateTypeFromString(value) - if err != nil { - return nil, err - } - - candidateTypes = append(candidateTypes, ct) - } - - if len(candidateTypes) > 0 { - cfg.CandidateTypes = candidateTypes - } - - // Filter network types - networkTypes := []ice.NetworkType{} - for _, value := range c.GetStringSlice("ice.network_types") { - ct, err := icex.NetworkTypeFromString(value) - if err != nil { - return nil, err - } - - networkTypes = append(networkTypes, ct) - } - - if len(networkTypes) > 0 { - cfg.NetworkTypes = networkTypes - } - - return cfg, nil -} - -func (c *Config) Dump(wr io.Writer) error { - return yaml.NewEncoder(wr).Encode(c.AllSettings()) -} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 78765d70..9eed86cd 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -3,7 +3,6 @@ package config_test import ( "testing" - "github.com/pion/ice/v2" "riasc.eu/wice/internal/config" ) @@ -18,7 +17,7 @@ func TestParseArgsUser(t *testing.T) { } } -func TestParseArgsBackend(t *testing.T) { +func TestParseArgsBackends(t *testing.T) { config, err := config.Parse("--backend", "k8s", "--backend", "p2p") if err != nil { t.Errorf("err got %v, want nil", err) @@ -38,107 +37,3 @@ func TestParseArgsBackend(t *testing.T) { t.Fail() } } - -func TestParseArgsUrls(t *testing.T) { - config, err := config.Parse("--url", "stun:stun.riasc.eu", "--url", "turn:turn.riasc.eu") - if err != nil { - t.Errorf("err got %v, want nil", err) - } - - agentConfig, err := config.AgentConfig() - if err != nil { - t.FailNow() - } - - if len(agentConfig.Urls) != 2 { - t.Fail() - } - - if agentConfig.Urls[0].Host != "stun.riasc.eu" { - t.Fail() - } - - if agentConfig.Urls[0].Scheme != ice.SchemeTypeSTUN { - t.Fail() - } - - if agentConfig.Urls[1].Host != "turn.riasc.eu" { - t.Fail() - } - - if agentConfig.Urls[1].Scheme != ice.SchemeTypeTURN { - t.Fail() - } -} - -func TestParseArgsCandidateTypes(t *testing.T) { - config, err := config.Parse("--ice-candidate-type", "host", "--ice-candidate-type", "relay") - if err != nil { - t.Errorf("err got %v, want nil", err) - } - - agentConfig, err := config.AgentConfig() - if err != nil { - t.Errorf("Failed to get agent config: %s", err) - } - - if len(agentConfig.CandidateTypes) != 2 { - t.Fail() - } - - if agentConfig.CandidateTypes[0] != ice.CandidateTypeHost { - t.Fail() - } - - if agentConfig.CandidateTypes[1] != ice.CandidateTypeRelay { - t.Fail() - } -} - -func TestParseArgsInterfaceFilter(t *testing.T) { - config, err := config.Parse("--ice-interface-filter", "eth\\d+") - if err != nil { - t.Errorf("err got %v, want nil", err) - } - - agentConfig, err := config.AgentConfig() - if err != nil { - t.Errorf("Failed to get agent config: %s", err) - } - - if !agentConfig.InterfaceFilter("eth0") { - t.Fail() - } - - if agentConfig.InterfaceFilter("wifi0") { - t.Fail() - } -} - -func TestParseArgsInterfaceFilterFail(t *testing.T) { - config, err := config.Parse("--ice-interface-filter", "eth(") - if err != nil { - t.Fail() - } - - _, err = config.AgentConfig() - if err == nil { - t.Fail() - } -} - -func TestParseArgsDefault(t *testing.T) { - config, err := config.Parse() - if err != nil { - t.Fail() - } - - agentConfig, err := config.AgentConfig() - if err != nil { - t.FailNow() - } - - if len(agentConfig.Urls) != 1 { - t.Fail() - } -} diff --git a/pkg/daemon.go b/pkg/daemon.go index 96c5fedb..876b46a6 100644 --- a/pkg/daemon.go +++ b/pkg/daemon.go @@ -193,7 +193,7 @@ func (d *Daemon) SyncAllInterfaces() error { keepInterfaces := intf.InterfaceList{} for _, device := range devices { - if !d.Config.WireguardInterfaceFilter.Match([]byte(device.Name)) { + if !d.Config.WireguardInterfaceFilter.MatchString(device.Name) { continue // Skip interfaces which dont match the filter }