diff --git a/pkg/daemon/daemon.go b/pkg/daemon/daemon.go index 735f5be0..8755712f 100644 --- a/pkg/daemon/daemon.go +++ b/pkg/daemon/daemon.go @@ -15,18 +15,18 @@ import ( "github.com/stv0g/cunicu/pkg/wg" "go.uber.org/zap" "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/stv0g/cunicu/pkg/signaling" ) type Daemon struct { - *watcher.Watcher - // Shared Backend *signaling.MultiBackend - Client *wgctrl.Client + client *wgctrl.Client Config *config.Config + watcher *watcher.Watcher devices []device.Device interfaces map[*core.Interface]*Interface @@ -56,19 +56,13 @@ func New(cfg *config.Config) (*Daemon, error) { d.logger = zap.L().Named("daemon") - // Initialize some defaults configuration settings at runtime - if err = config.InitDefaults(); err != nil { - return nil, fmt.Errorf("failed to initialize defaults: %w", err) - } - // Create WireGuard netlink socket - d.Client, err = wgctrl.New() - if err != nil { + if d.client, err = wgctrl.New(); err != nil { return nil, fmt.Errorf("failed to create WireGuard client: %w", err) } // Create watcher - if d.Watcher, err = watcher.New(d.Client, cfg.WatchInterval, cfg.InterfaceFilter); err != nil { + if d.watcher, err = watcher.New(d.client, cfg.WatchInterval, cfg.InterfaceFilter); err != nil { return nil, fmt.Errorf("failed to initialize watcher: %w", err) } @@ -78,13 +72,11 @@ func New(cfg *config.Config) (*Daemon, error) { urls = append(urls, &u.URL) } - if d.Backend, err = signaling.NewMultiBackend(urls, &signaling.BackendConfig{ - OnReady: []signaling.BackendReadyHandler{}, - }); err != nil { + if d.Backend, err = signaling.NewMultiBackend(urls, &signaling.BackendConfig{}); err != nil { return nil, fmt.Errorf("failed to initialize signaling backend: %w", err) } - d.Watcher.OnInterface(d) + d.watcher.OnInterface(d) return d, nil } @@ -100,13 +92,9 @@ func (d *Daemon) Run() error { signals := util.SetupSignals(util.SigUpdate) - d.logger.Debug("Started initial synchronization") - if err := d.Watcher.Sync(); err != nil { - d.logger.Fatal("Initial synchronization failed", zap.Error(err)) - } - d.logger.Debug("Finished initial synchronization") + go d.watcher.Watch() - go d.Watcher.Watch() + d.watcher.Sync() out: for { @@ -142,7 +130,7 @@ func (d *Daemon) Restart() { } func (d *Daemon) Sync() error { - if err := d.Watcher.Sync(); err != nil { + if err := d.watcher.Sync(); err != nil { return err } @@ -156,13 +144,7 @@ func (d *Daemon) Sync() error { } func (d *Daemon) Close() error { - for _, dev := range d.devices { - if err := dev.Close(); err != nil { - return fmt.Errorf("failed to delete device: %w", err) - } - } - - if err := d.Watcher.Close(); err != nil { + if err := d.watcher.Close(); err != nil { return fmt.Errorf("failed to close watcher: %w", err) } @@ -172,7 +154,13 @@ func (d *Daemon) Close() error { } } - if err := d.Client.Close(); err != nil { + for _, dev := range d.devices { + if err := dev.Close(); err != nil { + return fmt.Errorf("failed to delete device: %w", err) + } + } + + if err := d.client.Close(); err != nil { return fmt.Errorf("failed to close WireGuard client: %w", err) } @@ -189,7 +177,7 @@ func (d *Daemon) CreateDevicesFromArgs() error { var devs wg.DeviceList var err error - if devs, err = d.Client.Devices(); err != nil { + if devs, err = d.client.Devices(); err != nil { return fmt.Errorf("failed to get existing WireGuard devices: %w", err) } @@ -227,7 +215,7 @@ func (d *Daemon) InterfaceByCore(ci *core.Interface) *Interface { } func (d *Daemon) InterfaceByName(name string) *Interface { - ci := d.Watcher.InterfaceByName(name) + ci := d.watcher.InterfaceByName(name) if ci == nil { return nil } @@ -236,7 +224,7 @@ func (d *Daemon) InterfaceByName(name string) *Interface { } func (d *Daemon) InterfaceByPublicKey(pk crypto.Key) *Interface { - ci := d.Watcher.InterfaceByPublicKey(pk) + ci := d.watcher.InterfaceByPublicKey(pk) if ci == nil { return nil } @@ -245,7 +233,7 @@ func (d *Daemon) InterfaceByPublicKey(pk crypto.Key) *Interface { } func (d *Daemon) InterfaceByIndex(idx int) *Interface { - ci := d.Watcher.InterfaceByIndex(idx) + ci := d.watcher.InterfaceByIndex(idx) if ci == nil { return nil } @@ -262,3 +250,11 @@ func (d *Daemon) ForEachInterface(cb func(i *Interface) error) error { return nil } + +func (d *Daemon) ConfigureDevice(name string, cfg wgtypes.Config) error { + if err := d.client.ConfigureDevice(name, cfg); err != nil { + return err + } + + return d.watcher.Sync() +} diff --git a/pkg/daemon/feature/autocfg/autocfg.go b/pkg/daemon/feature/autocfg/autocfg.go index a4654d62..fa20f1ab 100644 --- a/pkg/daemon/feature/autocfg/autocfg.go +++ b/pkg/daemon/feature/autocfg/autocfg.go @@ -132,7 +132,7 @@ func (a *Interface) configureWireGuardInterface() error { } if cfg.PrivateKey != nil || cfg.ListenPort != nil { - if err := a.Daemon.Client.ConfigureDevice(a.Name(), cfg); err != nil { + if err := a.Daemon.ConfigureDevice(a.Name(), cfg); err != nil { return fmt.Errorf("failed to configure device: %w", err) } } diff --git a/pkg/daemon/feature/autocfg/handlers.go b/pkg/daemon/feature/autocfg/handlers.go index 0381ac55..c8d2f4c6 100644 --- a/pkg/daemon/feature/autocfg/handlers.go +++ b/pkg/daemon/feature/autocfg/handlers.go @@ -10,7 +10,6 @@ import ( ) func (a *Interface) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) { - // Update link-local addresses in case the interface key has changed if mod&core.InterfaceModifiedPrivateKey != 0 { oldPk := crypto.Key(old.PublicKey) diff --git a/pkg/daemon/feature/epdisc/peer.go b/pkg/daemon/feature/epdisc/peer.go index 60de18de..2ec5909c 100644 --- a/pkg/daemon/feature/epdisc/peer.go +++ b/pkg/daemon/feature/epdisc/peer.go @@ -11,7 +11,6 @@ import ( "github.com/pion/ice/v2" "go.uber.org/zap" - "github.com/stv0g/cunicu/pkg/config" "github.com/stv0g/cunicu/pkg/core" "github.com/stv0g/cunicu/pkg/crypto" "github.com/stv0g/cunicu/pkg/daemon/feature/epdisc/proxy" @@ -28,18 +27,14 @@ import ( type Peer struct { *core.Peer - Discovery *Interface - - config *config.InterfaceSettings - agent *ice.Agent - backend signaling.Backend - proxy proxy.Proxy - connectionState util.AtomicEnum[icex.ConnectionState] - lastStateChange time.Time - lastEndpoint *net.UDPAddr - restarts uint - credentials protoepdisc.Credentials - + intf *Interface + agent *ice.Agent + proxy proxy.Proxy + connectionState util.AtomicEnum[icex.ConnectionState] + lastStateChange time.Time + lastEndpoint *net.UDPAddr + restarts uint + credentials protoepdisc.Credentials signalingMessages chan *signaling.Message connectionStateChanges chan icex.ConnectionState @@ -50,8 +45,8 @@ func NewPeer(cp *core.Peer, e *Interface) (*Peer, error) { var err error p := &Peer{ - Peer: cp, - Discovery: e, + Peer: cp, + intf: e, signalingMessages: make(chan *signaling.Message, 32), connectionStateChanges: make(chan icex.ConnectionState, 32), @@ -65,7 +60,7 @@ func NewPeer(cp *core.Peer, e *Interface) (*Peer, error) { // Initialize signaling channel kp := p.PublicPrivateKeyPair() - if _, err := p.backend.Subscribe(context.Background(), kp, p); err != nil { + if _, err := p.intf.Daemon.Backend.Subscribe(context.Background(), kp, p); err != nil { // TODO: Attempt retry? return nil, fmt.Errorf("failed to subscribe to offers: %w", err) } @@ -98,7 +93,7 @@ func (p *Peer) Resubscribe(ctx context.Context, skOld crypto.Key) error { // Create new subscription kpNew := p.PublicPrivateKeyPair() - if _, err := p.backend.Subscribe(ctx, kpNew, p); err != nil { + if _, err := p.intf.Daemon.Backend.Subscribe(ctx, kpNew, p); err != nil { return fmt.Errorf("failed to subscribe to offers: %w", err) } @@ -108,7 +103,7 @@ func (p *Peer) Resubscribe(ctx context.Context, skOld crypto.Key) error { Theirs: p.PublicKey(), } - if _, err := p.backend.Unsubscribe(ctx, kpOld, p); err != nil { + if _, err := p.intf.Daemon.Backend.Unsubscribe(ctx, kpOld, p); err != nil { return fmt.Errorf("failed to unsubscribe from offers: %w", err) } @@ -174,7 +169,7 @@ func (p *Peer) sendCredentials(need bool) error { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - if err := p.backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { + if err := p.intf.Daemon.Backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { return err } @@ -191,7 +186,7 @@ func (p *Peer) sendCandidate(c ice.Candidate) error { ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) defer cancel() - if err := p.backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { + if err := p.intf.Daemon.Backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { return err } @@ -211,7 +206,7 @@ func (p *Peer) createAgent() error { // Prepare ICE agent configuration pk := p.Interface.PublicKey() - acfg, err := p.config.AgentConfig(context.Background(), &pk) + acfg, err := p.intf.Settings.AgentConfig(context.Background(), &pk) if err != nil { return fmt.Errorf("failed to generate ICE agent configuration: %w", err) } @@ -219,11 +214,11 @@ func (p *Peer) createAgent() error { // Do not use WireGuard interfaces for ICE origFilter := acfg.InterfaceFilter acfg.InterfaceFilter = func(name string) bool { - return origFilter(name) && p.Discovery.Daemon.InterfaceByName(name) == nil + return origFilter(name) && p.intf.Daemon.InterfaceByName(name) == nil } - acfg.UDPMux = p.Discovery.udpMux - acfg.UDPMuxSrflx = p.Discovery.udpMuxSrflx + acfg.UDPMux = p.intf.udpMux + acfg.UDPMuxSrflx = p.intf.udpMuxSrflx acfg.LoggerFactory = log.NewPionLoggerFactory(p.logger) p.credentials = protoepdisc.NewCredentials() @@ -353,7 +348,7 @@ func (p *Peer) setConnectionState(new icex.ConnectionState) icex.ConnectionState zap.String("new", strings.ToLower(new.String())), zap.String("previous", strings.ToLower(prev.String()))) - for _, h := range p.Discovery.onConnectionStateChange { + for _, h := range p.intf.onConnectionStateChange { h.OnConnectionStateChange(p, new, prev) } @@ -371,7 +366,7 @@ func (p *Peer) setConnectionStateIf(prev, new icex.ConnectionState) bool { zap.String("new", strings.ToLower(new.String())), zap.String("previous", strings.ToLower(prev.String()))) - for _, h := range p.Discovery.onConnectionStateChange { + for _, h := range p.intf.onConnectionStateChange { h.OnConnectionStateChange(p, new, prev) } } diff --git a/pkg/daemon/feature/hsync/hsync.go b/pkg/daemon/feature/hsync/hsync.go index c525f352..1a230846 100644 --- a/pkg/daemon/feature/hsync/hsync.go +++ b/pkg/daemon/feature/hsync/hsync.go @@ -20,7 +20,7 @@ func init() { daemon.Features["hsync"] = &daemon.FeaturePlugin{ New: New, Description: "Hosts synchronization", - Order: 40, + Order: 100, } } diff --git a/pkg/watcher/watcher.go b/pkg/watcher/watcher.go index aeeb8119..52d9ed66 100644 --- a/pkg/watcher/watcher.go +++ b/pkg/watcher/watcher.go @@ -58,6 +58,7 @@ type Watcher struct { events chan InterfaceEvent errors chan error stop chan any + manual chan any // Settings filter InterfaceFilterFunc @@ -80,6 +81,7 @@ func New(client *wgctrl.Client, interval time.Duration, filter InterfaceFilterFu events: make(chan InterfaceEvent, 16), errors: make(chan error, 16), stop: make(chan any), + manual: make(chan any, 16), logger: zap.L().Named("watcher"), }, nil @@ -117,18 +119,23 @@ func (w *Watcher) Watch() { out: for { select { + case <-w.manual: + logger.Debug("Start manual interface synchronization") + if err := w.sync(); err != nil { + w.logger.Error("Synchronization failed", zap.Error(err)) + } + // We still a need periodic sync we can not (yet) monitor WireGuard interfaces // for changes via a netlink socket (patch is pending) case <-ticker.C: - logger.Debug("Started periodic interface synchronization") - if err := w.Sync(); err != nil { + logger.Debug("Start periodic interface synchronization") + if err := w.sync(); err != nil { w.logger.Error("Synchronization failed", zap.Error(err)) } - logger.Debug("Completed periodic interface synchronization") case event := <-w.events: logger.Debug("Received interface event", zap.Any("event", event)) - if err := w.Sync(); err != nil { + if err := w.sync(); err != nil { w.logger.Error("Synchronization failed", zap.Error(err)) } @@ -142,6 +149,12 @@ out: } func (w *Watcher) Sync() error { + w.manual <- nil + + return nil +} + +func (w *Watcher) sync() error { var err error var new = []*wgtypes.Device{} var old = w.devices