Files
cunicu/pkg/daemon.go
Steffen Vogel 0d0dc2334e replace syscall package by x/sys/unix
Signed-off-by: Steffen Vogel <post@steffenvogel.de>
2022-05-05 13:10:45 +02:00

298 lines
6.8 KiB
Go

package pkg
import (
"fmt"
"sync"
"time"
"github.com/cilium/ebpf/rlimit"
"go.uber.org/zap"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/wgctrl"
"riasc.eu/wice/internal"
"riasc.eu/wice/internal/config"
"riasc.eu/wice/pkg/intf"
"riasc.eu/wice/pkg/pb"
"riasc.eu/wice/pkg/signaling"
"go.uber.org/zap/zapio"
)
// Daemon bundles the state of a running daemon instance: the signaling
// backend, the Wireguard control client, the managed interfaces and the
// event distribution machinery used by Run and ListenEvents.
type Daemon struct {
// Backend is the signaling backend created by NewDaemon. It is handed to
// every interface that the daemon creates.
Backend signaling.Backend
// Client is the wgctrl client used to enumerate Wireguard devices.
Client *wgctrl.Client
// Config is the configuration the daemon was created with.
Config *config.Config
// Interfaces is the list of interfaces currently managed by the daemon.
Interfaces intf.InterfaceList
// InterfaceLock is not acquired by any method in this file.
// NOTE(review): presumably meant to be held by external readers/writers of
// Interfaces -- confirm against callers.
InterfaceLock sync.RWMutex
// Events is the channel that Run drains and re-broadcasts to all listeners.
Events chan *pb.Event
// eventListeners is the set of channels registered via ListenEvents.
// Only the keys are used; the values are always nil.
eventListeners map[chan *pb.Event]interface{}
// eventListenersLock guards eventListeners.
eventListenersLock sync.Mutex
// stop is closed by Stop to make Run leave its main loop.
stop chan interface{}
logger *zap.Logger
}
// NewDaemon creates a Daemon from the given configuration.
//
// It lifts the memlock rlimit (required for loading eBPF programs), creates
// the signaling backend (a single backend if exactly one is configured, a
// multi-backend otherwise) and opens a wgctrl client. If "wg.userspace" is
// not set in cfg, it is set to true when the Wireguard kernel module is not
// available.
//
// An error is returned if any of these steps fails; in that case the returned
// Daemon is nil.
func NewDaemon(cfg *config.Config) (*Daemon, error) {
	var err error

	logger := zap.L().Named("daemon")

	// Disable memlock for loading eBPF programs.
	// This is done before anything else is allocated so that a failure can be
	// reported as a plain error without leaving resources behind.
	if err := rlimit.RemoveMemlock(); err != nil {
		return nil, fmt.Errorf("failed to remove memlock: %w", err)
	}

	events := make(chan *pb.Event, 16)

	// Create backend
	var backend signaling.Backend
	if len(cfg.Backends) == 1 {
		backend, err = signaling.NewBackend(&signaling.BackendConfig{
			URI: cfg.Backends[0],
		}, events)
	} else {
		backend, err = signaling.NewMultiBackend(cfg.Backends, &signaling.BackendConfig{}, events)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to initialize signaling backend: %w", err)
	}

	// Create Wireguard netlink socket
	// NOTE(review): the backend created above is not released if this fails;
	// confirm whether signaling.Backend offers a way to shut it down.
	client, err := wgctrl.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Wireguard client: %w", err)
	}

	d := &Daemon{
		Config:         cfg,
		Client:         client,
		Backend:        backend,
		Interfaces:     intf.InterfaceList{},
		Events:         events,
		eventListeners: map[chan *pb.Event]interface{}{},
		stop:           make(chan interface{}),
		logger:         logger,
	}

	// Check if Wireguard interface can be created by the kernel
	if !cfg.IsSet("wg.userspace") {
		cfg.Set("wg.userspace", !intf.WireguardModuleExists())
	}

	return d, nil
}
// Run creates the configured interfaces, starts watching for userspace and
// kernel Wireguard interfaces and then blocks in the daemon's main loop.
//
// The loop
//   - re-syncs all interfaces periodically ("watch_interval"), on every
//     interface event and on SIGUSR1,
//   - broadcasts every event received on d.Events to all channels registered
//     via ListenEvents,
//   - logs errors reported by the interface watchers.
//
// Run returns nil once Stop has been called or any signal other than SIGUSR1
// is received. It returns an error if "watch_interval" is not positive, or if
// the interfaces or watchers could not be set up.
func (d *Daemon) Run() error {
	// time.NewTicker panics on a non-positive interval, so reject a bad
	// configuration up front instead of crashing further down.
	interval := d.Config.GetDuration("watch_interval")
	if interval <= 0 {
		return fmt.Errorf("invalid watch_interval: %s", interval)
	}

	ifEvents := make(chan intf.InterfaceEvent, 16)
	watchErrs := make(chan error, 16)

	signals := internal.SetupSignals()

	if err := d.CreateInterfacesFromArgs(); err != nil {
		return fmt.Errorf("failed to create interfaces: %w", err)
	}

	if err := intf.WatchWireguardUserspaceInterfaces(ifEvents, watchErrs); err != nil {
		return fmt.Errorf("failed to watch userspace interfaces: %w", err)
	}

	if err := intf.WatchWireguardKernelInterfaces(ifEvents, watchErrs); err != nil {
		return fmt.Errorf("failed to watch kernel interfaces: %w", err)
	}

	// syncAll runs a full interface sync and reports a failure instead of
	// silently dropping it. A failed sync does not terminate the loop; the
	// next trigger simply retries.
	syncAll := func() {
		if err := d.SyncAllInterfaces(); err != nil {
			d.logger.Error("Failed to sync interfaces", zap.Error(err))
		}
	}

	d.logger.Debug("Starting initial interface sync")
	syncAll()

	ticker := time.NewTicker(interval)
	defer ticker.Stop() // release the ticker once the loop is left

out:
	for {
		select {
		// We still need a periodic sync as we can not (yet) monitor Wireguard
		// interfaces for changes via a netlink socket (patch is pending).
		case <-ticker.C:
			d.logger.Debug("Starting periodic interface sync")
			syncAll()

		case <-d.stop:
			d.logger.Info("Received stop request")
			break out

		case event := <-d.Events:
			if event.Time == nil {
				event.Time = pb.TimeNow()
			}

			// NOTE: the sends below block while the lock is held, so a
			// listener which stops draining its channel stalls this loop.
			d.eventListenersLock.Lock()
			for ch := range d.eventListeners {
				ch <- event
			}
			// Take the count while still holding the lock; reading the map
			// after Unlock would race with ListenEvents.
			numListeners := len(d.eventListeners)
			d.eventListenersLock.Unlock()

			event.Log(d.logger, "Broadcasted event", zap.Int("listeners", numListeners))

		case event := <-ifEvents:
			d.logger.Debug("Received interface event", zap.String("event", event.String()))
			syncAll()

		case err := <-watchErrs:
			d.logger.Error("Failed to watch for interface changes", zap.Error(err))

		case sig := <-signals:
			d.logger.Debug("Received signal", zap.String("signal", sig.String()))
			switch sig {
			case unix.SIGUSR1:
				// SIGUSR1 requests a manual re-sync
				syncAll()
			default:
				// Every other delivered signal terminates the daemon
				break out
			}
		}
	}

	return nil
}
// Close releases the resources held by the daemon: it closes all managed
// interfaces and then the Wireguard control client opened by NewDaemon.
//
// Both steps are always attempted. If closing the interfaces fails, that
// error is returned; otherwise an error from closing the client is returned.
func (d *Daemon) Close() error {
	ifErr := d.Interfaces.Close()

	// The client is closed even if the interfaces failed to close, so that
	// its underlying handle is not leaked.
	var clientErr error
	if d.Client != nil {
		clientErr = d.Client.Close()
	}

	if ifErr != nil {
		return fmt.Errorf("failed to close interface: %w", ifErr)
	}
	if clientErr != nil {
		return fmt.Errorf("failed to close Wireguard client: %w", clientErr)
	}

	return nil
}
// GetInterfaceByName looks up a managed interface by its name.
// It returns the first interface in d.Interfaces whose Name() equals name,
// or nil if there is none.
func (d *Daemon) GetInterfaceByName(name string) intf.Interface {
	for idx := range d.Interfaces {
		if candidate := d.Interfaces[idx]; candidate.Name() == name {
			return candidate
		}
	}

	return nil
}
// SyncAllInterfaces reconciles the daemon's interface list with the Wireguard
// devices currently reported by the wgctrl client.
//
// Devices whose name matches the configured interface filter are either
// added (if unknown) or synced (if already managed). Managed interfaces which
// were not seen in this pass are closed, dropped from d.Interfaces and
// announced with an INTERFACE_REMOVED event on d.Events.
//
// All failures are reported through the logger's Fatal method; the function
// itself always returns nil.
func (d *Daemon) SyncAllInterfaces() error {
	devices, err := d.Client.Devices()
	if err != nil {
		d.logger.Fatal("Failed to list Wireguard interfaces", zap.Error(err))
	}

	var (
		seen     = intf.InterfaceList{} // interfaces encountered in this pass
		retained = intf.InterfaceList{} // interfaces which survive this pass
	)

	// First pass: add new devices and sync the ones we already know
	for _, dev := range devices {
		if matches := d.Config.WireguardInterfaceFilter.MatchString(dev.Name); !matches {
			// Skip interfaces which dont match the filter
			continue
		}

		current := d.GetInterfaceByName(dev.Name)
		if current != nil {
			// Existing interface
			d.logger.Debug("Sync existing interface", zap.String("intf", dev.Name))

			if err := current.Sync(dev); err != nil {
				d.logger.Fatal("Failed to sync interface",
					zap.Error(err),
					zap.String("intf", dev.Name),
				)
			}
		} else {
			// New interface
			d.logger.Info("Adding new interface", zap.String("intf", dev.Name))

			i, err := intf.NewInterface(dev, d.Client, d.Backend, d.Events, d.Config)
			if err != nil {
				d.logger.Fatal("Failed to create new interface",
					zap.Error(err),
					zap.String("intf", dev.Name),
				)
			}

			current = &i
			d.Interfaces = append(d.Interfaces, &i)
		}

		seen = append(seen, current)
	}

	// Second pass: drop every managed interface which was not seen above
	for _, managed := range d.Interfaces {
		if seen.GetByName(managed.Name()) != nil {
			retained = append(retained, managed)
			continue
		}

		d.logger.Info("Removing vanished interface", zap.String("intf", managed.Name()))

		if err := managed.Close(); err != nil {
			d.logger.Fatal("Failed to close interface", zap.Error(err))
		}

		d.Events <- &pb.Event{
			Type:      pb.Event_INTERFACE_REMOVED,
			Interface: managed.Name(),
		}
	}

	d.Interfaces = retained

	return nil
}
// CreateInterfacesFromArgs creates a Wireguard interface for every name
// listed in d.Config.WireguardInterfaces and appends it to d.Interfaces.
//
// Names for which a device already exists are skipped with a warning.
// Depending on the "wg.userspace" setting, a userspace or a kernel interface
// is created. If debug logging is enabled, the configuration of each new
// interface is dumped to the log.
//
// It returns an error if the existing devices can not be listed or if an
// interface can not be created. Interfaces created before the failure remain
// in d.Interfaces.
func (d *Daemon) CreateInterfacesFromArgs() error {
	var devs intf.Devices
	devs, err := d.Client.Devices()
	if err != nil {
		return fmt.Errorf("failed to list Wireguard devices: %w", err)
	}

	for _, interfName := range d.Config.WireguardInterfaces {
		dev := devs.GetByName(interfName)
		if dev != nil {
			d.logger.Warn("Interface already exists. Skipping..", zap.Any("intf", interfName))
			continue
		}

		var interf intf.Interface
		if d.Config.GetBool("wg.userspace") {
			interf, err = intf.CreateUserInterface(interfName, d.Client, d.Backend, d.Events, d.Config)
		} else {
			interf, err = intf.CreateKernelInterface(interfName, d.Client, d.Backend, d.Events, d.Config)
		}
		if err != nil {
			return fmt.Errorf("failed to create Wireguard device: %w", err)
		}

		if d.logger.Core().Enabled(zap.DebugLevel) {
			d.logger.Debug("Initialized interface:")

			// The writer logs at Info unless a level is given, so ask for
			// Debug explicitly to match the guard above.
			w := &zapio.Writer{Log: d.logger, Level: zap.DebugLevel}
			interf.DumpConfig(w)

			// The writer buffers until it sees a newline; closing it flushes
			// a trailing line which is not newline-terminated.
			if err := w.Close(); err != nil {
				d.logger.Warn("Failed to flush interface config dump", zap.Error(err))
			}
		}

		d.Interfaces = append(d.Interfaces, interf)
	}

	return nil
}
// Stop asks Run to leave its main loop by closing the daemon's stop channel.
//
// Calling Stop again after it has returned is a no-op instead of a
// "close of closed channel" panic. Stop always returns nil.
//
// NOTE: the check below is not atomic, so Stop must not be called from
// several goroutines at the same time.
func (d *Daemon) Stop() error {
	select {
	case <-d.stop:
		// Nothing is ever sent on the channel, so a successful receive
		// means it has already been closed by an earlier Stop.
	default:
		close(d.stop)
	}

	return nil
}
// ListenEvents registers a new event listener and returns its channel.
// The channel is buffered (100 events) and receives every event which Run
// broadcasts from then on.
func (d *Daemon) ListenEvents() chan *pb.Event {
	d.eventListenersLock.Lock()
	defer d.eventListenersLock.Unlock()

	listener := make(chan *pb.Event, 100)

	// The map is used as a set; only the key matters
	d.eventListeners[listener] = nil

	return listener
}