daemon: keep more member fields private

Signed-off-by: Steffen Vogel <post@steffenvogel.de>
This commit is contained in:
Steffen Vogel
2022-10-01 22:47:15 +02:00
parent 568594ae10
commit b7dba86a9b
6 changed files with 70 additions and 67 deletions

View File

@@ -15,18 +15,18 @@ import (
"github.com/stv0g/cunicu/pkg/wg" "github.com/stv0g/cunicu/pkg/wg"
"go.uber.org/zap" "go.uber.org/zap"
"golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/stv0g/cunicu/pkg/signaling" "github.com/stv0g/cunicu/pkg/signaling"
) )
type Daemon struct { type Daemon struct {
*watcher.Watcher
// Shared // Shared
Backend *signaling.MultiBackend Backend *signaling.MultiBackend
Client *wgctrl.Client client *wgctrl.Client
Config *config.Config Config *config.Config
watcher *watcher.Watcher
devices []device.Device devices []device.Device
interfaces map[*core.Interface]*Interface interfaces map[*core.Interface]*Interface
@@ -56,19 +56,13 @@ func New(cfg *config.Config) (*Daemon, error) {
d.logger = zap.L().Named("daemon") d.logger = zap.L().Named("daemon")
// Initialize some defaults configuration settings at runtime
if err = config.InitDefaults(); err != nil {
return nil, fmt.Errorf("failed to initialize defaults: %w", err)
}
// Create WireGuard netlink socket // Create WireGuard netlink socket
d.Client, err = wgctrl.New() if d.client, err = wgctrl.New(); err != nil {
if err != nil {
return nil, fmt.Errorf("failed to create WireGuard client: %w", err) return nil, fmt.Errorf("failed to create WireGuard client: %w", err)
} }
// Create watcher // Create watcher
if d.Watcher, err = watcher.New(d.Client, cfg.WatchInterval, cfg.InterfaceFilter); err != nil { if d.watcher, err = watcher.New(d.client, cfg.WatchInterval, cfg.InterfaceFilter); err != nil {
return nil, fmt.Errorf("failed to initialize watcher: %w", err) return nil, fmt.Errorf("failed to initialize watcher: %w", err)
} }
@@ -78,13 +72,11 @@ func New(cfg *config.Config) (*Daemon, error) {
urls = append(urls, &u.URL) urls = append(urls, &u.URL)
} }
if d.Backend, err = signaling.NewMultiBackend(urls, &signaling.BackendConfig{ if d.Backend, err = signaling.NewMultiBackend(urls, &signaling.BackendConfig{}); err != nil {
OnReady: []signaling.BackendReadyHandler{},
}); err != nil {
return nil, fmt.Errorf("failed to initialize signaling backend: %w", err) return nil, fmt.Errorf("failed to initialize signaling backend: %w", err)
} }
d.Watcher.OnInterface(d) d.watcher.OnInterface(d)
return d, nil return d, nil
} }
@@ -100,13 +92,9 @@ func (d *Daemon) Run() error {
signals := util.SetupSignals(util.SigUpdate) signals := util.SetupSignals(util.SigUpdate)
d.logger.Debug("Started initial synchronization") go d.watcher.Watch()
if err := d.Watcher.Sync(); err != nil {
d.logger.Fatal("Initial synchronization failed", zap.Error(err))
}
d.logger.Debug("Finished initial synchronization")
go d.Watcher.Watch() d.watcher.Sync()
out: out:
for { for {
@@ -142,7 +130,7 @@ func (d *Daemon) Restart() {
} }
func (d *Daemon) Sync() error { func (d *Daemon) Sync() error {
if err := d.Watcher.Sync(); err != nil { if err := d.watcher.Sync(); err != nil {
return err return err
} }
@@ -156,13 +144,7 @@ func (d *Daemon) Sync() error {
} }
func (d *Daemon) Close() error { func (d *Daemon) Close() error {
for _, dev := range d.devices { if err := d.watcher.Close(); err != nil {
if err := dev.Close(); err != nil {
return fmt.Errorf("failed to delete device: %w", err)
}
}
if err := d.Watcher.Close(); err != nil {
return fmt.Errorf("failed to close watcher: %w", err) return fmt.Errorf("failed to close watcher: %w", err)
} }
@@ -172,7 +154,13 @@ func (d *Daemon) Close() error {
} }
} }
if err := d.Client.Close(); err != nil { for _, dev := range d.devices {
if err := dev.Close(); err != nil {
return fmt.Errorf("failed to delete device: %w", err)
}
}
if err := d.client.Close(); err != nil {
return fmt.Errorf("failed to close WireGuard client: %w", err) return fmt.Errorf("failed to close WireGuard client: %w", err)
} }
@@ -189,7 +177,7 @@ func (d *Daemon) CreateDevicesFromArgs() error {
var devs wg.DeviceList var devs wg.DeviceList
var err error var err error
if devs, err = d.Client.Devices(); err != nil { if devs, err = d.client.Devices(); err != nil {
return fmt.Errorf("failed to get existing WireGuard devices: %w", err) return fmt.Errorf("failed to get existing WireGuard devices: %w", err)
} }
@@ -227,7 +215,7 @@ func (d *Daemon) InterfaceByCore(ci *core.Interface) *Interface {
} }
func (d *Daemon) InterfaceByName(name string) *Interface { func (d *Daemon) InterfaceByName(name string) *Interface {
ci := d.Watcher.InterfaceByName(name) ci := d.watcher.InterfaceByName(name)
if ci == nil { if ci == nil {
return nil return nil
} }
@@ -236,7 +224,7 @@ func (d *Daemon) InterfaceByName(name string) *Interface {
} }
func (d *Daemon) InterfaceByPublicKey(pk crypto.Key) *Interface { func (d *Daemon) InterfaceByPublicKey(pk crypto.Key) *Interface {
ci := d.Watcher.InterfaceByPublicKey(pk) ci := d.watcher.InterfaceByPublicKey(pk)
if ci == nil { if ci == nil {
return nil return nil
} }
@@ -245,7 +233,7 @@ func (d *Daemon) InterfaceByPublicKey(pk crypto.Key) *Interface {
} }
func (d *Daemon) InterfaceByIndex(idx int) *Interface { func (d *Daemon) InterfaceByIndex(idx int) *Interface {
ci := d.Watcher.InterfaceByIndex(idx) ci := d.watcher.InterfaceByIndex(idx)
if ci == nil { if ci == nil {
return nil return nil
} }
@@ -262,3 +250,11 @@ func (d *Daemon) ForEachInterface(cb func(i *Interface) error) error {
return nil return nil
} }
func (d *Daemon) ConfigureDevice(name string, cfg wgtypes.Config) error {
if err := d.client.ConfigureDevice(name, cfg); err != nil {
return err
}
return d.watcher.Sync()
}

View File

@@ -132,7 +132,7 @@ func (a *Interface) configureWireGuardInterface() error {
} }
if cfg.PrivateKey != nil || cfg.ListenPort != nil { if cfg.PrivateKey != nil || cfg.ListenPort != nil {
if err := a.Daemon.Client.ConfigureDevice(a.Name(), cfg); err != nil { if err := a.Daemon.ConfigureDevice(a.Name(), cfg); err != nil {
return fmt.Errorf("failed to configure device: %w", err) return fmt.Errorf("failed to configure device: %w", err)
} }
} }

View File

@@ -10,7 +10,6 @@ import (
) )
func (a *Interface) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) { func (a *Interface) OnInterfaceModified(i *core.Interface, old *wg.Device, mod core.InterfaceModifier) {
// Update link-local addresses in case the interface key has changed // Update link-local addresses in case the interface key has changed
if mod&core.InterfaceModifiedPrivateKey != 0 { if mod&core.InterfaceModifiedPrivateKey != 0 {
oldPk := crypto.Key(old.PublicKey) oldPk := crypto.Key(old.PublicKey)

View File

@@ -11,7 +11,6 @@ import (
"github.com/pion/ice/v2" "github.com/pion/ice/v2"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/stv0g/cunicu/pkg/config"
"github.com/stv0g/cunicu/pkg/core" "github.com/stv0g/cunicu/pkg/core"
"github.com/stv0g/cunicu/pkg/crypto" "github.com/stv0g/cunicu/pkg/crypto"
"github.com/stv0g/cunicu/pkg/daemon/feature/epdisc/proxy" "github.com/stv0g/cunicu/pkg/daemon/feature/epdisc/proxy"
@@ -28,18 +27,14 @@ import (
type Peer struct { type Peer struct {
*core.Peer *core.Peer
Discovery *Interface intf *Interface
agent *ice.Agent
config *config.InterfaceSettings proxy proxy.Proxy
agent *ice.Agent connectionState util.AtomicEnum[icex.ConnectionState]
backend signaling.Backend lastStateChange time.Time
proxy proxy.Proxy lastEndpoint *net.UDPAddr
connectionState util.AtomicEnum[icex.ConnectionState] restarts uint
lastStateChange time.Time credentials protoepdisc.Credentials
lastEndpoint *net.UDPAddr
restarts uint
credentials protoepdisc.Credentials
signalingMessages chan *signaling.Message signalingMessages chan *signaling.Message
connectionStateChanges chan icex.ConnectionState connectionStateChanges chan icex.ConnectionState
@@ -50,8 +45,8 @@ func NewPeer(cp *core.Peer, e *Interface) (*Peer, error) {
var err error var err error
p := &Peer{ p := &Peer{
Peer: cp, Peer: cp,
Discovery: e, intf: e,
signalingMessages: make(chan *signaling.Message, 32), signalingMessages: make(chan *signaling.Message, 32),
connectionStateChanges: make(chan icex.ConnectionState, 32), connectionStateChanges: make(chan icex.ConnectionState, 32),
@@ -65,7 +60,7 @@ func NewPeer(cp *core.Peer, e *Interface) (*Peer, error) {
// Initialize signaling channel // Initialize signaling channel
kp := p.PublicPrivateKeyPair() kp := p.PublicPrivateKeyPair()
if _, err := p.backend.Subscribe(context.Background(), kp, p); err != nil { if _, err := p.intf.Daemon.Backend.Subscribe(context.Background(), kp, p); err != nil {
// TODO: Attempt retry? // TODO: Attempt retry?
return nil, fmt.Errorf("failed to subscribe to offers: %w", err) return nil, fmt.Errorf("failed to subscribe to offers: %w", err)
} }
@@ -98,7 +93,7 @@ func (p *Peer) Resubscribe(ctx context.Context, skOld crypto.Key) error {
// Create new subscription // Create new subscription
kpNew := p.PublicPrivateKeyPair() kpNew := p.PublicPrivateKeyPair()
if _, err := p.backend.Subscribe(ctx, kpNew, p); err != nil { if _, err := p.intf.Daemon.Backend.Subscribe(ctx, kpNew, p); err != nil {
return fmt.Errorf("failed to subscribe to offers: %w", err) return fmt.Errorf("failed to subscribe to offers: %w", err)
} }
@@ -108,7 +103,7 @@ func (p *Peer) Resubscribe(ctx context.Context, skOld crypto.Key) error {
Theirs: p.PublicKey(), Theirs: p.PublicKey(),
} }
if _, err := p.backend.Unsubscribe(ctx, kpOld, p); err != nil { if _, err := p.intf.Daemon.Backend.Unsubscribe(ctx, kpOld, p); err != nil {
return fmt.Errorf("failed to unsubscribe from offers: %w", err) return fmt.Errorf("failed to unsubscribe from offers: %w", err)
} }
@@ -174,7 +169,7 @@ func (p *Peer) sendCredentials(need bool) error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
if err := p.backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { if err := p.intf.Daemon.Backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil {
return err return err
} }
@@ -191,7 +186,7 @@ func (p *Peer) sendCandidate(c ice.Candidate) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel() defer cancel()
if err := p.backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil { if err := p.intf.Daemon.Backend.Publish(ctx, p.PublicPrivateKeyPair(), msg); err != nil {
return err return err
} }
@@ -211,7 +206,7 @@ func (p *Peer) createAgent() error {
// Prepare ICE agent configuration // Prepare ICE agent configuration
pk := p.Interface.PublicKey() pk := p.Interface.PublicKey()
acfg, err := p.config.AgentConfig(context.Background(), &pk) acfg, err := p.intf.Settings.AgentConfig(context.Background(), &pk)
if err != nil { if err != nil {
return fmt.Errorf("failed to generate ICE agent configuration: %w", err) return fmt.Errorf("failed to generate ICE agent configuration: %w", err)
} }
@@ -219,11 +214,11 @@ func (p *Peer) createAgent() error {
// Do not use WireGuard interfaces for ICE // Do not use WireGuard interfaces for ICE
origFilter := acfg.InterfaceFilter origFilter := acfg.InterfaceFilter
acfg.InterfaceFilter = func(name string) bool { acfg.InterfaceFilter = func(name string) bool {
return origFilter(name) && p.Discovery.Daemon.InterfaceByName(name) == nil return origFilter(name) && p.intf.Daemon.InterfaceByName(name) == nil
} }
acfg.UDPMux = p.Discovery.udpMux acfg.UDPMux = p.intf.udpMux
acfg.UDPMuxSrflx = p.Discovery.udpMuxSrflx acfg.UDPMuxSrflx = p.intf.udpMuxSrflx
acfg.LoggerFactory = log.NewPionLoggerFactory(p.logger) acfg.LoggerFactory = log.NewPionLoggerFactory(p.logger)
p.credentials = protoepdisc.NewCredentials() p.credentials = protoepdisc.NewCredentials()
@@ -353,7 +348,7 @@ func (p *Peer) setConnectionState(new icex.ConnectionState) icex.ConnectionState
zap.String("new", strings.ToLower(new.String())), zap.String("new", strings.ToLower(new.String())),
zap.String("previous", strings.ToLower(prev.String()))) zap.String("previous", strings.ToLower(prev.String())))
for _, h := range p.Discovery.onConnectionStateChange { for _, h := range p.intf.onConnectionStateChange {
h.OnConnectionStateChange(p, new, prev) h.OnConnectionStateChange(p, new, prev)
} }
@@ -371,7 +366,7 @@ func (p *Peer) setConnectionStateIf(prev, new icex.ConnectionState) bool {
zap.String("new", strings.ToLower(new.String())), zap.String("new", strings.ToLower(new.String())),
zap.String("previous", strings.ToLower(prev.String()))) zap.String("previous", strings.ToLower(prev.String())))
for _, h := range p.Discovery.onConnectionStateChange { for _, h := range p.intf.onConnectionStateChange {
h.OnConnectionStateChange(p, new, prev) h.OnConnectionStateChange(p, new, prev)
} }
} }

View File

@@ -20,7 +20,7 @@ func init() {
daemon.Features["hsync"] = &daemon.FeaturePlugin{ daemon.Features["hsync"] = &daemon.FeaturePlugin{
New: New, New: New,
Description: "Hosts synchronization", Description: "Hosts synchronization",
Order: 40, Order: 100,
} }
} }

View File

@@ -58,6 +58,7 @@ type Watcher struct {
events chan InterfaceEvent events chan InterfaceEvent
errors chan error errors chan error
stop chan any stop chan any
manual chan any
// Settings // Settings
filter InterfaceFilterFunc filter InterfaceFilterFunc
@@ -80,6 +81,7 @@ func New(client *wgctrl.Client, interval time.Duration, filter InterfaceFilterFu
events: make(chan InterfaceEvent, 16), events: make(chan InterfaceEvent, 16),
errors: make(chan error, 16), errors: make(chan error, 16),
stop: make(chan any), stop: make(chan any),
manual: make(chan any, 16),
logger: zap.L().Named("watcher"), logger: zap.L().Named("watcher"),
}, nil }, nil
@@ -117,18 +119,23 @@ func (w *Watcher) Watch() {
out: out:
for { for {
select { select {
case <-w.manual:
logger.Debug("Start manual interface synchronization")
if err := w.sync(); err != nil {
w.logger.Error("Synchronization failed", zap.Error(err))
}
// We still a need periodic sync we can not (yet) monitor WireGuard interfaces // We still a need periodic sync we can not (yet) monitor WireGuard interfaces
// for changes via a netlink socket (patch is pending) // for changes via a netlink socket (patch is pending)
case <-ticker.C: case <-ticker.C:
logger.Debug("Started periodic interface synchronization") logger.Debug("Start periodic interface synchronization")
if err := w.Sync(); err != nil { if err := w.sync(); err != nil {
w.logger.Error("Synchronization failed", zap.Error(err)) w.logger.Error("Synchronization failed", zap.Error(err))
} }
logger.Debug("Completed periodic interface synchronization")
case event := <-w.events: case event := <-w.events:
logger.Debug("Received interface event", zap.Any("event", event)) logger.Debug("Received interface event", zap.Any("event", event))
if err := w.Sync(); err != nil { if err := w.sync(); err != nil {
w.logger.Error("Synchronization failed", zap.Error(err)) w.logger.Error("Synchronization failed", zap.Error(err))
} }
@@ -142,6 +149,12 @@ out:
} }
func (w *Watcher) Sync() error { func (w *Watcher) Sync() error {
w.manual <- nil
return nil
}
func (w *Watcher) sync() error {
var err error var err error
var new = []*wgtypes.Device{} var new = []*wgtypes.Device{}
var old = w.devices var old = w.devices