config: rewrite config sub-package and add more tests

Signed-off-by: Steffen Vogel <post@steffenvogel.de>
This commit is contained in:
Steffen Vogel
2022-05-04 13:54:06 +02:00
parent 70a9195654
commit 46391bb75c
6 changed files with 288 additions and 246 deletions

View File

@@ -73,7 +73,7 @@ func daemon(cmd *cobra.Command, args []string) {
}
// Create control socket server to manage daemon
_, err = socket.Listen("unix", cfg.GetString("socket"), cfg.GetBool("socket-wait"), daemon)
_, err = socket.Listen("unix", cfg.GetString("socket.path"), cfg.GetBool("socket.wait"), daemon)
if err != nil {
logger.Fatal("Failed to initialize control socket", zap.Error(err))
}

125
internal/config/agent.go Normal file
View File

@@ -0,0 +1,125 @@
package config
import (
"fmt"
"io"
"regexp"
"github.com/pion/ice/v2"
"gopkg.in/yaml.v3"
icex "riasc.eu/wice/internal/ice"
)
func (c *Config) AgentConfig() (*ice.AgentConfig, error) {
cfg := &ice.AgentConfig{
InsecureSkipVerify: c.GetBool("ice.insecure_skip_verify"),
Lite: c.GetBool("ice.lite"),
PortMin: uint16(c.GetUint("ice.port.min")),
PortMax: uint16(c.GetUint("ice.port.max")),
}
interfaceFilterRegex, err := regexp.Compile(c.GetString("ice.interface_filter"))
if err != nil {
return nil, fmt.Errorf("invalid ice.interface_filter config: %w", err)
}
cfg.InterfaceFilter = func(name string) bool {
return interfaceFilterRegex.Match([]byte(name))
}
// ICE URLS
cfg.Urls = []*ice.URL{}
for _, u := range c.GetStringSlice("ice.urls") {
up, err := ice.ParseURL(u)
if err != nil {
return nil, fmt.Errorf("failed to parse ice.url: %s: %w", u, err)
}
cfg.Urls = append(cfg.Urls, up)
}
// Add default STUN/TURN servers
// Set ICE credentials
u := c.GetString("ice.username")
p := c.GetString("ice.password")
for _, q := range cfg.Urls {
if u != "" {
q.Username = u
}
if p != "" {
q.Password = p
}
}
if c.IsSet("ice.nat_1to1_ips") {
cfg.NAT1To1IPs = c.GetStringSlice("ice.nat_1to1_ips")
}
if c.IsSet("ice.max_binding_requests") {
i := uint16(c.GetInt("ice.max_binding_requests"))
cfg.MaxBindingRequests = &i
}
if c.GetBool("ice.mdns") {
cfg.MulticastDNSMode = ice.MulticastDNSModeQueryAndGather
}
if c.IsSet("ice.disconnected_timeout") {
to := c.GetDuration("ice.disconnected_timeout")
cfg.DisconnectedTimeout = &to
}
if c.IsSet("ice.failed_timeout") {
to := c.GetDuration("ice.failed_timeout")
cfg.FailedTimeout = &to
}
if c.IsSet("ice.keepalive_interval") {
to := c.GetDuration("ice.keepalive_interval")
cfg.KeepaliveInterval = &to
}
if c.IsSet("ice.check_interval") {
to := c.GetDuration("ice.check_interval")
cfg.CheckInterval = &to
}
// Filter candidate types
candidateTypes := []ice.CandidateType{}
for _, value := range c.GetStringSlice("ice.candidate_types") {
ct, err := icex.CandidateTypeFromString(value)
if err != nil {
return nil, err
}
candidateTypes = append(candidateTypes, ct)
}
if len(candidateTypes) > 0 {
cfg.CandidateTypes = candidateTypes
}
// Filter network types
networkTypes := []ice.NetworkType{}
for _, value := range c.GetStringSlice("ice.network_types") {
ct, err := icex.NetworkTypeFromString(value)
if err != nil {
return nil, err
}
networkTypes = append(networkTypes, ct)
}
if len(networkTypes) > 0 {
cfg.NetworkTypes = networkTypes
}
return cfg, nil
}
func (c *Config) Dump(wr io.Writer) error {
enc := yaml.NewEncoder(wr)
enc.SetIndent(2)
return enc.Encode(c.AllSettings())
}

View File

@@ -0,0 +1,124 @@
package config_test
import (
"testing"
"github.com/pion/ice/v2"
"riasc.eu/wice/internal/config"
)
func TestParseArgsCandidateTypes(t *testing.T) {
config, err := config.Parse("--ice-candidate-type", "host", "--ice-candidate-type", "relay")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.Errorf("Failed to get agent config: %s", err)
}
if len(agentConfig.CandidateTypes) != 2 {
t.Fail()
}
if agentConfig.CandidateTypes[0] != ice.CandidateTypeHost {
t.Fail()
}
if agentConfig.CandidateTypes[1] != ice.CandidateTypeRelay {
t.Fail()
}
}
func TestParseArgsInterfaceFilter(t *testing.T) {
config, err := config.Parse("--ice-interface-filter", "eth\\d+")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.Errorf("Failed to get agent config: %s", err)
}
if !agentConfig.InterfaceFilter("eth0") {
t.Fail()
}
if agentConfig.InterfaceFilter("wifi0") {
t.Fail()
}
}
func TestParseArgsInterfaceFilterFail(t *testing.T) {
config, err := config.Parse("--ice-interface-filter", "eth(")
if err != nil {
t.Fail()
}
_, err = config.AgentConfig()
if err == nil {
t.Fail()
}
}
func TestParseArgsDefault(t *testing.T) {
cfg, err := config.Parse()
if err != nil {
t.Fail()
}
agentConfig, err := cfg.AgentConfig()
if err != nil {
t.FailNow()
}
if len(agentConfig.Urls) != 1 {
t.FailNow()
}
if agentConfig.Urls[0].String() != config.DefaultURL {
t.FailNow()
}
if len(cfg.Backends) != 1 {
t.FailNow()
}
if cfg.Backends[0].String() != config.DefaultBackend+":" {
t.FailNow()
}
}
func TestParseArgsUrls(t *testing.T) {
config, err := config.Parse("--url", "stun:stun.riasc.eu", "--url", "turn:turn.riasc.eu")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.FailNow()
}
if len(agentConfig.Urls) != 2 {
t.Fail()
}
if agentConfig.Urls[0].Host != "stun.riasc.eu" {
t.Fail()
}
if agentConfig.Urls[0].Scheme != ice.SchemeTypeSTUN {
t.Fail()
}
if agentConfig.Urls[1].Host != "turn.riasc.eu" {
t.Fail()
}
if agentConfig.Urls[1].Scheme != ice.SchemeTypeTURN {
t.Fail()
}
}

View File

@@ -3,7 +3,6 @@ package config
import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
@@ -13,39 +12,40 @@ import (
"time"
"go.uber.org/zap"
"gopkg.in/yaml.v3"
icex "riasc.eu/wice/internal/ice"
"github.com/pion/ice/v2"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
// Copied from pion/ice/agent_config.go
const (
// defaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase
defaultCheckInterval = 200 * time.Millisecond
// DefaultCheckInterval is the interval at which the agent performs candidate checks in the connecting phase
DefaultCheckInterval = 200 * time.Millisecond
// keepaliveInterval used to keep candidates alive
defaultKeepaliveInterval = 2 * time.Second
DefaultKeepaliveInterval = 2 * time.Second
// defaultDisconnectedTimeout is the default time till an Agent transitions disconnected
defaultDisconnectedTimeout = 5 * time.Second
// DefaultDisconnectedTimeout is the default time till an Agent transitions disconnected
DefaultDisconnectedTimeout = 5 * time.Second
// defaultRestartInterval is the default time an Agent waits before it attempts an ICE restart
defaultRestartTimeout = 5 * time.Second
DefaultRestartTimeout = 5 * time.Second
// defaultFailedTimeout is the default time till an Agent transitions to failed after disconnected
defaultFailedTimeout = 5 * time.Second
// DefaultFailedTimeout is the default time till an Agent transitions to failed after disconnected
DefaultFailedTimeout = 5 * time.Second
// max binding request before considering a pair failed
defaultMaxBindingRequests = 7
DefaultMaxBindingRequests = 7
defaultWatchInterval = time.Second
DefaultWatchInterval = time.Second
DefaultSocketPath = "/var/run/wice.sock"
defaultWireguardConfigPath = "/etc/wireguard"
DefaultWireguardConfigPath = "/etc/wireguard"
DefaultURL = "stun:l.google.com:19302"
DefaultBackend = "p2p"
)
type Config struct {
@@ -57,6 +57,8 @@ type Config struct {
WireguardInterfaces []string
ConfigFiles []string
WireguardInterfaceFilter *regexp.Regexp
Backends []*url.URL
}
@@ -83,22 +85,22 @@ func NewConfig(flags *pflag.FlagSet) *Config {
flags: flags,
}
c.SetDefault("ice.urls", []string{"stun:l.google.com:19302"})
c.SetDefault("backends", []string{"p2p"})
c.SetDefault("watch_interval", defaultWatchInterval)
c.SetDefault("ice.urls", []string{DefaultURL})
c.SetDefault("backends", []string{DefaultBackend})
c.SetDefault("watch_interval", DefaultWatchInterval)
c.SetDefault("socket.path", DefaultSocketPath)
c.SetDefault("ice.max_binding_requests", defaultMaxBindingRequests)
c.SetDefault("ice.max_binding_requests", DefaultMaxBindingRequests)
c.SetDefault("ice.check_interval", defaultCheckInterval)
c.SetDefault("ice.disconnected_timout", defaultDisconnectedTimeout)
c.SetDefault("ice.failed_timeout", defaultFailedTimeout)
c.SetDefault("ice.restart_timeout", defaultRestartTimeout)
c.SetDefault("ice.keepalive_interval", defaultKeepaliveInterval)
c.SetDefault("ice.check_interval", DefaultCheckInterval)
c.SetDefault("ice.disconnected_timeout", DefaultDisconnectedTimeout)
c.SetDefault("ice.failed_timeout", DefaultFailedTimeout)
c.SetDefault("ice.restart_timeout", DefaultRestartTimeout)
c.SetDefault("ice.keepalive_interval", DefaultKeepaliveInterval)
c.SetDefault("ice.nat_1to1_ips", []net.IP{})
c.SetDefault("wg.config.path", defaultWireguardConfigPath)
c.SetDefault("wg.config.path", DefaultWireguardConfigPath)
flags.StringP("config-domain", "A", "", "Perform auto-configuration via DNS")
flags.StringSliceVarP(&c.ConfigFiles, "config", "c", []string{}, "Path of configuration files")
@@ -156,7 +158,7 @@ func NewConfig(flags *pflag.FlagSet) *Config {
"ice-max-binding-requests": "ice.max_binding_requests",
"ice-insecure-skip-verify": "ice.insecure_skip_verify",
"ice-interface-filter": "ice.interface_filter",
"ice-disconnected-timout": "ice.disconnected_timout",
"ice-disconnected-timout": "ice.disconnected_timeout",
"ice-failed-timeout": "ice.failed_timeout",
"ice-keepalive-interval": "ice.keepalive_interval",
"ice-check-interval": "ice.check_interval",
@@ -178,6 +180,8 @@ func NewConfig(flags *pflag.FlagSet) *Config {
}
func (c *Config) Setup(args []string) error {
// We cant to this in NewConfig since its called by init()
// at which time the logging system is not initialized yet.
c.logger = zap.L().Named("config")
c.WireguardInterfaces = args
@@ -252,6 +256,12 @@ func (c *Config) MergeRemoteConfig(url *url.URL) error {
}
func (c *Config) Load() error {
var err error
c.WireguardInterfaceFilter, err = regexp.Compile(c.GetString("wg.interface_filter"))
if err != nil {
return fmt.Errorf("invalid regular expression for setting 'wg.interface_filter': %w", err)
}
// Backends
c.Backends = []*url.URL{}
@@ -272,115 +282,3 @@ func (c *Config) Load() error {
return nil
}
func (c *Config) AgentConfig() (*ice.AgentConfig, error) {
cfg := &ice.AgentConfig{
InsecureSkipVerify: c.GetBool("ice.insecure_skip_verify"),
Lite: c.GetBool("ice.lite"),
PortMin: uint16(c.GetUint("ice.port.min")),
PortMax: uint16(c.GetUint("ice.port.max")),
}
interfaceFilterRegex, err := regexp.Compile(c.GetString("ice.interface_filter"))
if err != nil {
return nil, fmt.Errorf("invalid ice.interface_filter config: %w", err)
}
cfg.InterfaceFilter = func(name string) bool {
return interfaceFilterRegex.Match([]byte(name))
}
// ICE URLS
cfg.Urls = []*ice.URL{}
for _, u := range c.GetStringSlice("ice.urls") {
up, err := ice.ParseURL(u)
if err != nil {
return nil, fmt.Errorf("failed to parse ice.url: %s: %w", u, err)
}
cfg.Urls = append(cfg.Urls, up)
}
// Add default STUN/TURN servers
// Set ICE credentials
u := c.GetString("ice.username")
p := c.GetString("ice.password")
for _, q := range cfg.Urls {
if u != "" {
q.Username = u
}
if p != "" {
q.Password = p
}
}
if c.IsSet("ice.nat_1to1_ips") {
cfg.NAT1To1IPs = c.GetStringSlice("ice.nat_1to1_ips")
}
if c.IsSet("ice.max_binding_requests") {
i := uint16(c.GetInt("ice.max_binding_requests"))
cfg.MaxBindingRequests = &i
}
if c.GetBool("ice.mdns") {
cfg.MulticastDNSMode = ice.MulticastDNSModeQueryAndGather
}
if c.IsSet("ice.disconnected_timeout") {
to := c.GetDuration("ice.disconnected_timeout")
cfg.DisconnectedTimeout = &to
}
if c.IsSet("ice.failed_timeout") {
to := c.GetDuration("ice.failed_timeout")
cfg.FailedTimeout = &to
}
if c.IsSet("ice.keepalive_interval") {
to := c.GetDuration("ice.keepalive_interval")
cfg.KeepaliveInterval = &to
}
if c.IsSet("ice.check_interval") {
to := c.GetDuration("ice.check_interval")
cfg.CheckInterval = &to
}
// Filter candidate types
candidateTypes := []ice.CandidateType{}
for _, value := range c.GetStringSlice("ice.candidate_types") {
ct, err := icex.CandidateTypeFromString(value)
if err != nil {
return nil, err
}
candidateTypes = append(candidateTypes, ct)
}
if len(candidateTypes) > 0 {
cfg.CandidateTypes = candidateTypes
}
// Filter network types
networkTypes := []ice.NetworkType{}
for _, value := range c.GetStringSlice("ice.network_types") {
ct, err := icex.NetworkTypeFromString(value)
if err != nil {
return nil, err
}
networkTypes = append(networkTypes, ct)
}
if len(networkTypes) > 0 {
cfg.NetworkTypes = networkTypes
}
return cfg, nil
}
func (c *Config) Dump(wr io.Writer) error {
return yaml.NewEncoder(wr).Encode(c.AllSettings())
}

View File

@@ -3,7 +3,6 @@ package config_test
import (
"testing"
"github.com/pion/ice/v2"
"riasc.eu/wice/internal/config"
)
@@ -18,7 +17,7 @@ func TestParseArgsUser(t *testing.T) {
}
}
func TestParseArgsBackend(t *testing.T) {
func TestParseArgsBackends(t *testing.T) {
config, err := config.Parse("--backend", "k8s", "--backend", "p2p")
if err != nil {
t.Errorf("err got %v, want nil", err)
@@ -38,107 +37,3 @@ func TestParseArgsBackend(t *testing.T) {
t.Fail()
}
}
func TestParseArgsUrls(t *testing.T) {
config, err := config.Parse("--url", "stun:stun.riasc.eu", "--url", "turn:turn.riasc.eu")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.FailNow()
}
if len(agentConfig.Urls) != 2 {
t.Fail()
}
if agentConfig.Urls[0].Host != "stun.riasc.eu" {
t.Fail()
}
if agentConfig.Urls[0].Scheme != ice.SchemeTypeSTUN {
t.Fail()
}
if agentConfig.Urls[1].Host != "turn.riasc.eu" {
t.Fail()
}
if agentConfig.Urls[1].Scheme != ice.SchemeTypeTURN {
t.Fail()
}
}
func TestParseArgsCandidateTypes(t *testing.T) {
config, err := config.Parse("--ice-candidate-type", "host", "--ice-candidate-type", "relay")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.Errorf("Failed to get agent config: %s", err)
}
if len(agentConfig.CandidateTypes) != 2 {
t.Fail()
}
if agentConfig.CandidateTypes[0] != ice.CandidateTypeHost {
t.Fail()
}
if agentConfig.CandidateTypes[1] != ice.CandidateTypeRelay {
t.Fail()
}
}
func TestParseArgsInterfaceFilter(t *testing.T) {
config, err := config.Parse("--ice-interface-filter", "eth\\d+")
if err != nil {
t.Errorf("err got %v, want nil", err)
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.Errorf("Failed to get agent config: %s", err)
}
if !agentConfig.InterfaceFilter("eth0") {
t.Fail()
}
if agentConfig.InterfaceFilter("wifi0") {
t.Fail()
}
}
func TestParseArgsInterfaceFilterFail(t *testing.T) {
config, err := config.Parse("--ice-interface-filter", "eth(")
if err != nil {
t.Fail()
}
_, err = config.AgentConfig()
if err == nil {
t.Fail()
}
}
func TestParseArgsDefault(t *testing.T) {
config, err := config.Parse()
if err != nil {
t.Fail()
}
agentConfig, err := config.AgentConfig()
if err != nil {
t.FailNow()
}
if len(agentConfig.Urls) != 1 {
t.Fail()
}
}

View File

@@ -193,7 +193,7 @@ func (d *Daemon) SyncAllInterfaces() error {
keepInterfaces := intf.InterfaceList{}
for _, device := range devices {
if !d.Config.WireguardInterfaceFilter.Match([]byte(device.Name)) {
if !d.Config.WireguardInterfaceFilter.MatchString(device.Name) {
continue // Skip interfaces which dont match the filter
}