Files
cunicu/pkg/device/device_kernel.go
Steffen Vogel 3bee839348 fix: Update copyright years
Signed-off-by: Steffen Vogel <post@steffenvogel.de>
2025-01-01 22:45:39 +01:00

139 lines
2.8 KiB
Go

// SPDX-FileCopyrightText: 2023-2025 Steffen Vogel <post@steffenvogel.de>
// SPDX-License-Identifier: Apache-2.0
package device
import (
"errors"
"fmt"
"net"
"go.uber.org/zap"
wgconn "golang.zx2c4.com/wireguard/conn"
wgdevice "golang.zx2c4.com/wireguard/device"
"cunicu.li/cunicu/pkg/link"
"cunicu.li/cunicu/pkg/log"
"cunicu.li/cunicu/pkg/wg"
)
// errNotWireGuardLink is the sentinel wrapped by FindKernelDevice when the
// requested link exists but its type is not link.TypeWireGuard.
var (
	errNotWireGuardLink = errors.New("link is not a WireGuard link")
)
// KernelDevice is a WireGuard device backed by a kernel WireGuard link.
//
// It embeds the link.Link it manages and additionally owns a wg.Bind whose
// received packets are written to the kernel by the doReceive workers.
type KernelDevice struct {
	link.Link

	// ListenPort is the device's listen port. While it is 0, BindUpdate
	// skips re-opening the bind.
	ListenPort int

	// bind is the userspace bind that receive workers read packets from.
	bind *wg.Bind
	// logger is scoped to this device (named "dev", tagged with name/type).
	logger *log.Logger
}
// NewKernelDevice creates a new WireGuard link called name and returns a
// KernelDevice wrapping it, together with a freshly created wg.Bind.
//
// If the link cannot be created, a nil device and an error wrapping the
// underlying cause are returned.
func NewKernelDevice(name string) (*KernelDevice, error) {
	devLogger := log.Global.Named("dev").With(
		zap.String("dev", name),
		zap.String("type", "kernel"),
	)

	wgLink, err := link.CreateWireGuardLink(name)
	if err != nil {
		return nil, fmt.Errorf("failed to create WireGuard link: %w", err)
	}

	dev := &KernelDevice{
		Link:   wgLink,
		logger: devLogger,
	}
	dev.bind = wg.NewBind(devLogger)

	return dev, nil
}
// FindKernelDevice looks up an existing link called name and wraps it in a
// KernelDevice with a freshly created wg.Bind.
//
// An error is returned if the link cannot be found, or (wrapping
// errNotWireGuardLink) if the link's type is not link.TypeWireGuard.
func FindKernelDevice(name string) (*KernelDevice, error) {
	devLogger := log.Global.Named("dev").With(
		zap.String("dev", name),
		zap.String("type", "kernel"),
	)

	wgLink, err := link.FindLink(name)

	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to find WireGuard link: %w", err)

	// TODO: Is this portable?
	case wgLink.Type() != link.TypeWireGuard:
		return nil, fmt.Errorf("%w: %s", errNotWireGuardLink, wgLink.Name())
	}

	dev := &KernelDevice{
		Link:   wgLink,
		logger: devLogger,
	}
	dev.bind = wg.NewBind(devLogger)

	return dev, nil
}
// Bind returns the wg.Bind owned by this device.
func (d *KernelDevice) Bind() *wg.Bind { return d.bind }
// BindUpdate closes and re-opens the device's bind, then starts one doReceive
// goroutine for every receive function the bind returns.
//
// Nothing happens (apart from a debug log) while ListenPort is still 0.
// Errors from closing or opening the bind are wrapped and returned.
//
// NOTE(review): the bind is opened with port 0 rather than ListenPort —
// presumably because the kernel link itself owns ListenPort; confirm.
// Previously started workers exit only if the closed bind makes their
// receive function return net.ErrClosed — verify against wg.Bind.Close.
func (d *KernelDevice) BindUpdate() error {
	if d.ListenPort == 0 {
		d.logger.Debug("Skip bind update as we no listen port yet")
		return nil
	}

	if err := d.bind.Close(); err != nil {
		return fmt.Errorf("failed to close bind: %w", err)
	}

	receiveFuncs, _, err := d.bind.Open(0)
	if err != nil {
		return fmt.Errorf("failed to open bind: %w", err)
	}

	for i := range receiveFuncs {
		go d.doReceive(receiveFuncs[i])
	}

	return nil
}
// doReceive is a receive worker. It repeatedly reads one packet at a time from
// rcvFn and writes it to the kernel through the wg.BindKernelConn attached to
// the endpoint the packet arrived on.
//
// BindUpdate starts one doReceive goroutine per receive function of the bind.
// The worker returns once rcvFn reports an error matching net.ErrClosed. Any
// other receive error, an empty read, an unexpected endpoint, or an endpoint
// without a kernel connection is logged (where applicable) and skipped.
func (d *KernelDevice) doReceive(rcvFn wgconn.ReceiveFunc) {
	d.logger.Debug("Receive worker started")

	// Handle exactly one packet per rcvFn call.
	batchSize := 1

	packets := make([][]byte, batchSize)
	sizes := make([]int, batchSize)
	eps := make([]wgconn.Endpoint, batchSize)

	// Buffer sized to wgdevice.MaxMessageSize, reused for every packet.
	packets[0] = make([]byte, wgdevice.MaxMessageSize)

	for {
		n, err := rcvFn(packets, sizes, eps)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				break
			}

			// NOTE(review): a persistent error other than net.ErrClosed makes
			// this loop spin and log continuously; consider a backoff.
			d.logger.Error("Failed to receive from bind", zap.Error(err))

			continue
		}

		if n == 0 || sizes[0] == 0 {
			continue
		}

		// Checked assertion: a forced assertion would panic on an unexpected
		// (or nil) endpoint, and a panic in this goroutine crashes the process.
		ep, ok := eps[0].(*wg.BindEndpoint)
		if !ok || ep == nil {
			d.logger.Error("Unexpected endpoint type", zap.String("type", fmt.Sprintf("%T", eps[0])))

			continue
		}

		kc, ok := ep.Conn.(wg.BindKernelConn)
		if !ok {
			d.logger.Error("No kernel connection found", zap.String("ep", ep.DstToString()))

			continue
		}

		if _, err := kc.WriteKernel(packets[0][:sizes[0]]); err != nil {
			d.logger.Error("Failed to write to kernel", zap.Error(err))
		}
	}

	d.logger.Debug("Receive worker stopped")
}