conn: introduce new package that splits out the Bind and Endpoint types

The sticky socket code stays in the device package for now,
as it reaches deeply into the peer list.

This is the first step in an effort to split some code out of
the very busy device package.

Signed-off-by: David Crawshaw <crawshaw@tailscale.com>
This commit is contained in:
David Crawshaw
2019-11-07 11:13:05 -05:00
committed by Jason A. Donenfeld
parent 6aefb61355
commit 203554620d
15 changed files with 562 additions and 452 deletions

View File

@@ -11,15 +11,14 @@ import (
"sync/atomic"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel"
"golang.zx2c4.com/wireguard/tun"
)
const (
DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
)
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
@@ -39,9 +38,10 @@ type Device struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
bind conn.Bind // bind interface
netlinkCancel *rwcancel.RWCancel
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
@@ -299,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 {
device.state.starting.Add(3)
device.state.stopping.Add(3)
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
}
device.state.starting.Add(2)
device.state.stopping.Add(2)
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
@@ -413,3 +415,127 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
}
device.peers.RUnlock()
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.netlinkCancel != nil {
netc.netlinkCancel.Cancel()
}
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = conn.CreateBind(netc.port)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
if err != nil {
netc.bind.Close()
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(2)
device.net.stopping.Add(2)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}