feat(daemon): Allow passing systemd socket to RPC server

Signed-off-by: Steffen Vogel <post@steffenvogel.de>
This commit is contained in:
Steffen Vogel
2025-01-02 00:36:55 +01:00
parent 76fd99d309
commit 44cea5bf5d
2 changed files with 63 additions and 11 deletions

View File

@@ -8,13 +8,17 @@ import (
"fmt"
"math/rand"
"net"
"os"
"strings"
"cunicu.li/cunicu/pkg/os/systemd"
)
var (
errInvalidPortRange = errors.New("minimal port must be larger than maximal port number")
errInvalidNetwork = errors.New("unsupported network")
errNoPortFound = errors.New("failed to find port")
errInvalidPortRange = errors.New("minimal port must be larger than maximal port number")
errInvalidNetwork = errors.New("unsupported network")
errNoPortFound = errors.New("failed to find port")
errNoSystemdListeners = errors.New("no file descriptors passed from systemd")
)
func FindRandomPortToListen(network string, mini, maxi int) (int, error) {
@@ -61,3 +65,57 @@ func canListenOnPort(network string, port int) bool {
return false
}
func Listen(socket string) (l net.Listener, err error) {
var network, address string
if p := strings.SplitN(socket, ":", 2); len(p) >= 2 {
network = p[0]
address = p[1]
} else if p[0] == "systemd" { //nolint:goconst
network = p[0]
} else {
network = "unix"
address = p[0]
}
switch {
case network == "systemd" && address == "":
sdListeners, err := systemd.Listeners()
if err != nil {
return nil, fmt.Errorf("failed to get listeners from systemd: %w", err)
}
if len(sdListeners) == 0 {
return nil, errNoSystemdListeners
}
l = sdListeners[0]
case network == "systemd" && address != "":
sdListeners, err := systemd.ListenersWithNames()
if err != nil {
return nil, fmt.Errorf("failed to get listeners from systemd: %w", err)
}
if ls, ok := sdListeners[address]; !ok || len(ls) == 0 {
return nil, fmt.Errorf("%w: with name %s", errNoSystemdListeners, address)
} else {
l = ls[0]
}
case network == "unix":
if err := os.RemoveAll(address); err != nil {
return nil, fmt.Errorf("failed to remove old socket: %w", err)
}
fallthrough
default:
if l, err = net.Listen(network, address); err != nil {
return nil, fmt.Errorf("failed to listen at %s: %w", socket, err)
}
}
return l, nil
}

View File

@@ -6,8 +6,6 @@ package rpc
import (
"context"
"fmt"
"net"
"os"
"sync"
"go.uber.org/zap"
@@ -16,6 +14,7 @@ import (
"cunicu.li/cunicu/pkg/daemon"
"cunicu.li/cunicu/pkg/log"
xnet "cunicu.li/cunicu/pkg/net"
rpcproto "cunicu.li/cunicu/pkg/proto/rpc"
"cunicu.li/cunicu/pkg/types"
)
@@ -52,12 +51,7 @@ func NewServer(d *daemon.Daemon, socket string) (*Server, error) {
s.signaling = NewSignalingServer(s, d.Backend)
s.epdisc = NewEndpointDiscoveryServer(s)
// Remove old unix sockets
if err := os.RemoveAll(socket); err != nil {
return nil, fmt.Errorf("failed to remove old socket: %w", err)
}
l, err := net.Listen("unix", socket)
l, err := xnet.Listen(socket)
if err != nil {
return nil, fmt.Errorf("failed to listen at %s: %w", socket, err)
}