Files
cunicu/pkg/core/peer.go
Steffen Vogel 92a7ad2f7f daemon: use per-interface features
Signed-off-by: Steffen Vogel <post@steffenvogel.de>
2022-10-07 18:30:50 +02:00

310 lines
7.5 KiB
Go

package core
import (
"fmt"
"math/big"
"net"
"time"
"go.uber.org/zap"
"github.com/stv0g/cunicu/pkg/crypto"
"github.com/stv0g/cunicu/pkg/util"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
proto "github.com/stv0g/cunicu/pkg/proto"
coreproto "github.com/stv0g/cunicu/pkg/proto/core"
)
type SignalingState int
type Peer struct {
*wgtypes.Peer
Name string
Interface *Interface
LastReceiveTime time.Time
LastTransmitTime time.Time
onModified []PeerModifiedHandler
client *wgctrl.Client
logger *zap.Logger
}
// NewPeer creates a peer and initiates a new ICE agent
func NewPeer(wgp *wgtypes.Peer, i *Interface) (*Peer, error) {
p := &Peer{
Interface: i,
Peer: wgp,
onModified: []PeerModifiedHandler{},
client: i.client,
logger: zap.L().Named("peer").With(
zap.String("intf", i.Name()),
zap.Any("peer", wgp.PublicKey),
),
}
// We intentionally prune the AllowedIP list here for the initial sync
p.Peer.AllowedIPs = nil
return p, nil
}
// Getters
// String returns the peers public key as a base64-encoded string
func (p *Peer) String() string {
if p.Name != "" {
return fmt.Sprintf("[%s]%s", p.Name, p.PublicKey())
}
return p.PublicKey().String()
}
// PublicKey returns the Curve25199 public key of the WireGuard peer
func (p *Peer) PublicKey() crypto.Key {
return crypto.Key(p.Peer.PublicKey)
}
// PresharedKey returns the Curve25199 preshared key of the WireGuard peer
func (p *Peer) PresharedKey() crypto.Key {
return crypto.Key(p.Peer.PresharedKey)
}
// PublicKeyPair returns both the public key of the local (our) and remote peer (theirs)
func (p *Peer) PublicKeyPair() *crypto.PublicKeyPair {
return &crypto.PublicKeyPair{
Ours: p.Interface.PublicKey(),
Theirs: p.PublicKey(),
}
}
// PublicPrivateKeyPair returns both the public key of the local (our) and remote peer (theirs)
func (p *Peer) PublicPrivateKeyPair() *crypto.KeyPair {
return &crypto.KeyPair{
Ours: p.Interface.PrivateKey(),
Theirs: p.PublicKey(),
}
}
// IsControlling determines if the peer is controlling the ICE session
// by selecting the peer which has the smaller public key
func (p *Peer) IsControlling() bool {
var pkOur, pkTheir big.Int
pkOur.SetBytes(p.Interface.Device.PublicKey[:])
pkTheir.SetBytes(p.Peer.PublicKey[:])
return pkOur.Cmp(&pkTheir) == -1
}
// WireGuardConfig return the WireGuard peer configuration
func (p *Peer) WireGuardConfig() *wgtypes.PeerConfig {
cfg := &wgtypes.PeerConfig{
PublicKey: *(*wgtypes.Key)(&p.Peer.PublicKey),
Endpoint: p.Endpoint,
AllowedIPs: p.Peer.AllowedIPs,
}
if crypto.Key(p.Peer.PresharedKey).IsSet() {
cfg.PresharedKey = &p.Peer.PresharedKey
}
if p.PersistentKeepaliveInterval > 0 {
cfg.PersistentKeepaliveInterval = &p.PersistentKeepaliveInterval
}
return cfg
}
// OnModified registers a new handler which is called whenever the peer has been modified
func (p *Peer) OnModified(h PeerModifiedHandler) {
p.onModified = append(p.onModified, h)
}
// UpdateEndpoint sets a new endpoint for the WireGuard peer
func (p *Peer) UpdateEndpoint(addr *net.UDPAddr) error {
cfg := wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
PublicKey: p.Peer.PublicKey,
UpdateOnly: true,
ReplaceAllowedIPs: false,
Endpoint: addr,
},
},
}
if err := p.client.ConfigureDevice(p.Interface.Device.Name, cfg); err != nil {
return fmt.Errorf("failed to update peer endpoint: %w", err)
}
p.logger.Debug("Peer endpoint updated", zap.Any("ep", addr))
return nil
}
// SetPresharedKey sets a new preshared key for the WireGuard peer
func (p *Peer) SetPresharedKey(psk *crypto.Key) error {
cfg := wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
PublicKey: p.Peer.PublicKey,
UpdateOnly: true,
PresharedKey: (*wgtypes.Key)(psk),
},
},
}
if err := p.client.ConfigureDevice(p.Interface.Device.Name, cfg); err != nil {
return fmt.Errorf("failed to update peer preshared key: %w", err)
}
// TODO: Remove PSK from log
p.logger.Debug("Peer preshared key updated", zap.Any("psk", psk))
return nil
}
// AddAllowedIP adds a new IP network to the allowed ip list of the WireGuard peer
func (p *Peer) AddAllowedIP(a net.IPNet) error {
if util.SliceContains(p.AllowedIPs, func(n net.IPNet) bool {
return util.CmpNet(n, a) == 0
}) {
p.logger.Warn("Not adding already existing allowed IP", zap.Any("ip", a))
return nil
}
cfg := wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(p.PublicKey()),
AllowedIPs: []net.IPNet{a},
},
},
}
p.logger.Debug("Adding new allowed IP", zap.Any("ip", a))
return p.client.ConfigureDevice(p.Interface.Device.Name, cfg)
}
// RemoveAllowedIP removes a new IP network from the allowed ip list of the WireGuard peer
func (p *Peer) RemoveAllowedIP(a net.IPNet) error {
ips := util.SliceFilter(p.Peer.AllowedIPs, func(b net.IPNet) bool {
return util.CmpNet(a, b) != 0
})
// TODO: Check is net is in AllowedIPs before attempting removing it
cfg := wgtypes.Config{
Peers: []wgtypes.PeerConfig{
{
UpdateOnly: true,
PublicKey: wgtypes.Key(p.PublicKey()),
ReplaceAllowedIPs: true,
AllowedIPs: ips,
},
},
}
p.logger.Debug("Remove allowed IP", zap.Any("ip", a))
return p.client.ConfigureDevice(p.Interface.Device.Name, cfg)
}
func (p *Peer) Sync(new *wgtypes.Peer) (PeerModifier, []net.IPNet, []net.IPNet) {
old := p.Peer
mod := PeerModifiedNone
now := time.Now()
// Compare peer properties
if new.PresharedKey != old.PresharedKey {
mod |= PeerModifiedPresharedKey
}
if util.CmpEndpoint(new.Endpoint, old.Endpoint) != 0 {
mod |= PeerModifiedEndpoint
}
if new.PersistentKeepaliveInterval != old.PersistentKeepaliveInterval {
mod |= PeerModifiedKeepaliveInterval
}
if new.LastHandshakeTime != old.LastHandshakeTime {
mod |= PeerModifiedHandshakeTime
}
if new.ReceiveBytes != old.ReceiveBytes {
mod |= PeerModifiedReceiveBytes
p.LastReceiveTime = now
}
if new.TransmitBytes != old.TransmitBytes {
mod |= PeerModifiedTransmitBytes
p.LastTransmitTime = now
}
if new.ProtocolVersion != old.ProtocolVersion {
mod |= PeerModifiedProtocolVersion
}
// Find changes in AllowedIP list
ipsAdded, ipsRemoved, _ := util.SliceDiffFunc(old.AllowedIPs, new.AllowedIPs, util.CmpNet)
if len(ipsAdded) > 0 || len(ipsRemoved) > 0 {
mod |= PeerModifiedAllowedIPs
}
p.Peer = new
if mod != PeerModifiedNone {
p.logger.Info("Peer has been modified", zap.Strings("changes", mod.Strings()))
for _, h := range p.onModified {
h.OnPeerModified(p, old, mod, ipsAdded, ipsRemoved)
}
}
return mod, ipsAdded, ipsRemoved
}
func (p *Peer) Marshal() *coreproto.Peer {
allowedIPs := []string{}
for _, allowedIP := range p.AllowedIPs {
allowedIPs = append(allowedIPs, allowedIP.String())
}
q := &coreproto.Peer{
PublicKey: p.PublicKey().Bytes(),
PersistentKeepaliveInterval: uint32(p.PersistentKeepaliveInterval / time.Second),
TransmitBytes: p.TransmitBytes,
ReceiveBytes: p.ReceiveBytes,
AllowedIps: allowedIPs,
ProtocolVersion: uint32(p.ProtocolVersion),
}
if p.Endpoint != nil {
q.Endpoint = p.Endpoint.String()
}
if p.PresharedKey().IsSet() {
q.PresharedKey = p.PresharedKey().Bytes()
}
if !p.LastHandshakeTime.IsZero() {
q.LastHandshakeTimestamp = proto.Time(p.LastHandshakeTime)
}
if !p.LastReceiveTime.IsZero() {
q.LastReceiveTimestamp = proto.Time(p.LastReceiveTime)
}
if !p.LastTransmitTime.IsZero() {
q.LastTransmitTimestamp = proto.Time(p.LastTransmitTime)
}
return q
}