mirror of
				https://git.zx2c4.com/wireguard-go
				synced 2025-10-31 20:02:37 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			527 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			527 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| /* SPDX-License-Identifier: MIT
 | |
|  *
 | |
|  * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
 | |
|  */
 | |
| 
 | |
| package device
 | |
| 
 | |
| import (
 | |
| 	"runtime"
 | |
| 	"sync"
 | |
| 	"sync/atomic"
 | |
| 	"time"
 | |
| 
 | |
| 	"golang.org/x/net/ipv4"
 | |
| 	"golang.org/x/net/ipv6"
 | |
| 
 | |
| 	"golang.zx2c4.com/wireguard/conn"
 | |
| 	"golang.zx2c4.com/wireguard/ratelimiter"
 | |
| 	"golang.zx2c4.com/wireguard/rwcancel"
 | |
| 	"golang.zx2c4.com/wireguard/tun"
 | |
| )
 | |
| 
 | |
// Device is the in-memory state of one WireGuard interface: its lifecycle
// state, UDP bind, static key pair, peer table, work queues and the TUN
// device it reads from. Instances are created by NewDevice.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that that intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice adds one count for RoutineReadFromTUN; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex (called device.state.mu in other comments)
		// protects state changes; changeState and Close hold it.
		sync.Mutex
	}

	net struct {
		// stopping is incremented by BindUpdate for the receive goroutines
		// and waited on by unsafeCloseBind after the bind has been closed.
		stopping sync.WaitGroup
		// The embedded RWMutex is taken by Bind, BindSetMark, BindUpdate and BindClose.
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel // set from startRouteListener; cancelled when the bind is closed
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
	}

	// staticIdentity is the device's own key pair; SetPrivateKey replaces
	// both keys while holding the embedded RWMutex.
	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
	}

	peers struct {
		empty        AtomicBool // empty reports whether len(keyMap) == 0
		sync.RWMutex            // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	allowedips    AllowedIPs    // removePeerLocked drops a peer's entries via RemoveByPeer
	indexTable    IndexTable    // initialised in NewDevice
	cookieChecker CookieChecker // re-initialised with the public key in SetPrivateKey

	rate struct {
		// underLoadUntil is a UnixNano timestamp; IsUnderLoad reads and
		// writes it atomically.
		underLoadUntil int64
		limiter        ratelimiter.Ratelimiter
	}

	// pool holds the WaitPools for buffers and queue elements.
	// NOTE(review): presumably filled in by PopulatePools, which is called
	// from NewDevice but defined elsewhere — confirm.
	pool struct {
		messageBuffers   *WaitPool
		inboundElements  *WaitPool
		outboundElements *WaitPool
	}

	// queue holds the device-wide work queues created in NewDevice.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device // the TUN device passed to NewDevice
		mtu    int32      // taken from device.MTU() in NewDevice, or DefaultMTU if that fails
	}

	ipcMutex sync.RWMutex // not used by any code in this part of the file
	closed   chan struct{} // closed at the end of Close; returned by Wait
	log      *Logger
}
 | |
| 
 | |
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// A device starts down (see NewDevice), may alternate between down and up,
// and never leaves closed once it has been reached (see changeState).
// Transitions:
//
//   down -----+
//     ↑↓      ↓
//     up -> closed
//
type deviceState uint32

//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota // initial state; also the result of a failed Up
	deviceStateUp                      // up, or in the process of coming up
	deviceStateClosed                  // terminal state set by Close
)
 | |
| 
 | |
| // deviceState returns device.state.state as a deviceState
 | |
| // See those docs for how to interpret this value.
 | |
| func (device *Device) deviceState() deviceState {
 | |
| 	return deviceState(atomic.LoadUint32(&device.state.state))
 | |
| }
 | |
| 
 | |
| // isClosed reports whether the device is closed (or is closing).
 | |
| // See device.state.state comments for how to interpret this value.
 | |
| func (device *Device) isClosed() bool {
 | |
| 	return device.deviceState() == deviceStateClosed
 | |
| }
 | |
| 
 | |
| // isUp reports whether the device is up (or is attempting to come up).
 | |
| // See device.state.state comments for how to interpret this value.
 | |
| func (device *Device) isUp() bool {
 | |
| 	return device.deviceState() == deviceStateUp
 | |
| }
 | |
| 
 | |
| // Must hold device.peers.Lock()
 | |
| func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
 | |
| 	// stop routing and processing of packets
 | |
| 	device.allowedips.RemoveByPeer(peer)
 | |
| 	peer.Stop()
 | |
| 
 | |
| 	// remove from peer map
 | |
| 	delete(device.peers.keyMap, key)
 | |
| 	device.peers.empty.Set(len(device.peers.keyMap) == 0)
 | |
| }
 | |
| 
 | |
| // changeState attempts to change the device state to match want.
 | |
| func (device *Device) changeState(want deviceState) (err error) {
 | |
| 	device.state.Lock()
 | |
| 	defer device.state.Unlock()
 | |
| 	old := device.deviceState()
 | |
| 	if old == deviceStateClosed {
 | |
| 		// once closed, always closed
 | |
| 		device.log.Verbosef("Interface closed, ignored requested state %s", want)
 | |
| 		return nil
 | |
| 	}
 | |
| 	switch want {
 | |
| 	case old:
 | |
| 		return nil
 | |
| 	case deviceStateUp:
 | |
| 		atomic.StoreUint32(&device.state.state, uint32(deviceStateUp))
 | |
| 		err = device.upLocked()
 | |
| 		if err == nil {
 | |
| 			break
 | |
| 		}
 | |
| 		fallthrough // up failed; bring the device all the way back down
 | |
| 	case deviceStateDown:
 | |
| 		atomic.StoreUint32(&device.state.state, uint32(deviceStateDown))
 | |
| 		errDown := device.downLocked()
 | |
| 		if err == nil {
 | |
| 			err = errDown
 | |
| 		}
 | |
| 	}
 | |
| 	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
 | |
| 	return
 | |
| }
 | |
| 
 | |
| // upLocked attempts to bring the device up and reports whether it succeeded.
 | |
| // The caller must hold device.state.mu and is responsible for updating device.state.state.
 | |
| func (device *Device) upLocked() error {
 | |
| 	if err := device.BindUpdate(); err != nil {
 | |
| 		device.log.Errorf("Unable to update bind: %v", err)
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	device.peers.RLock()
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.Start()
 | |
| 		if atomic.LoadUint32(&peer.persistentKeepaliveInterval) > 0 {
 | |
| 			peer.SendKeepalive()
 | |
| 		}
 | |
| 	}
 | |
| 	device.peers.RUnlock()
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| // downLocked attempts to bring the device down.
 | |
| // The caller must hold device.state.mu and is responsible for updating device.state.state.
 | |
| func (device *Device) downLocked() error {
 | |
| 	err := device.BindClose()
 | |
| 	if err != nil {
 | |
| 		device.log.Errorf("Bind close failed: %v", err)
 | |
| 	}
 | |
| 
 | |
| 	device.peers.RLock()
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.Stop()
 | |
| 	}
 | |
| 	device.peers.RUnlock()
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (device *Device) Up() error {
 | |
| 	return device.changeState(deviceStateUp)
 | |
| }
 | |
| 
 | |
| func (device *Device) Down() error {
 | |
| 	return device.changeState(deviceStateDown)
 | |
| }
 | |
| 
 | |
| func (device *Device) IsUnderLoad() bool {
 | |
| 	// check if currently under load
 | |
| 	now := time.Now()
 | |
| 	underLoad := len(device.queue.handshake.c) >= UnderLoadQueueSize
 | |
| 	if underLoad {
 | |
| 		atomic.StoreInt64(&device.rate.underLoadUntil, now.Add(UnderLoadAfterTime).UnixNano())
 | |
| 		return true
 | |
| 	}
 | |
| 	// check if recently under load
 | |
| 	return atomic.LoadInt64(&device.rate.underLoadUntil) > now.UnixNano()
 | |
| }
 | |
| 
 | |
| func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
 | |
| 	// lock required resources
 | |
| 
 | |
| 	device.staticIdentity.Lock()
 | |
| 	defer device.staticIdentity.Unlock()
 | |
| 
 | |
| 	if sk.Equals(device.staticIdentity.privateKey) {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	device.peers.Lock()
 | |
| 	defer device.peers.Unlock()
 | |
| 
 | |
| 	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.handshake.mutex.RLock()
 | |
| 		lockedPeers = append(lockedPeers, peer)
 | |
| 	}
 | |
| 
 | |
| 	// remove peers with matching public keys
 | |
| 
 | |
| 	publicKey := sk.publicKey()
 | |
| 	for key, peer := range device.peers.keyMap {
 | |
| 		if peer.handshake.remoteStatic.Equals(publicKey) {
 | |
| 			peer.handshake.mutex.RUnlock()
 | |
| 			removePeerLocked(device, peer, key)
 | |
| 			peer.handshake.mutex.RLock()
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// update key material
 | |
| 
 | |
| 	device.staticIdentity.privateKey = sk
 | |
| 	device.staticIdentity.publicKey = publicKey
 | |
| 	device.cookieChecker.Init(publicKey)
 | |
| 
 | |
| 	// do static-static DH pre-computations
 | |
| 
 | |
| 	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		handshake := &peer.handshake
 | |
| 		handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
 | |
| 		expiredPeers = append(expiredPeers, peer)
 | |
| 	}
 | |
| 
 | |
| 	for _, peer := range lockedPeers {
 | |
| 		peer.handshake.mutex.RUnlock()
 | |
| 	}
 | |
| 	for _, peer := range expiredPeers {
 | |
| 		peer.ExpireCurrentKeypairs()
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
 | |
| 	device := new(Device)
 | |
| 	device.state.state = uint32(deviceStateDown)
 | |
| 	device.closed = make(chan struct{})
 | |
| 	device.log = logger
 | |
| 	device.tun.device = tunDevice
 | |
| 	mtu, err := device.tun.device.MTU()
 | |
| 	if err != nil {
 | |
| 		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
 | |
| 		mtu = DefaultMTU
 | |
| 	}
 | |
| 	device.tun.mtu = int32(mtu)
 | |
| 	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 | |
| 	device.rate.limiter.Init()
 | |
| 	device.indexTable.Init()
 | |
| 	device.PopulatePools()
 | |
| 
 | |
| 	// create queues
 | |
| 
 | |
| 	device.queue.handshake = newHandshakeQueue()
 | |
| 	device.queue.encryption = newOutboundQueue()
 | |
| 	device.queue.decryption = newInboundQueue()
 | |
| 
 | |
| 	// prepare net
 | |
| 
 | |
| 	device.net.port = 0
 | |
| 	device.net.bind = nil
 | |
| 
 | |
| 	// start workers
 | |
| 
 | |
| 	cpus := runtime.NumCPU()
 | |
| 	device.state.stopping.Wait()
 | |
| 	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
 | |
| 	for i := 0; i < cpus; i++ {
 | |
| 		go device.RoutineEncryption()
 | |
| 		go device.RoutineDecryption()
 | |
| 		go device.RoutineHandshake()
 | |
| 	}
 | |
| 
 | |
| 	device.state.stopping.Add(1)      // RoutineReadFromTUN
 | |
| 	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
 | |
| 	go device.RoutineReadFromTUN()
 | |
| 	go device.RoutineTUNEventReader()
 | |
| 
 | |
| 	return device
 | |
| }
 | |
| 
 | |
| func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 | |
| 	device.peers.RLock()
 | |
| 	defer device.peers.RUnlock()
 | |
| 
 | |
| 	return device.peers.keyMap[pk]
 | |
| }
 | |
| 
 | |
| func (device *Device) RemovePeer(key NoisePublicKey) {
 | |
| 	device.peers.Lock()
 | |
| 	defer device.peers.Unlock()
 | |
| 	// stop peer and remove from routing
 | |
| 
 | |
| 	peer, ok := device.peers.keyMap[key]
 | |
| 	if ok {
 | |
| 		removePeerLocked(device, peer, key)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func (device *Device) RemoveAllPeers() {
 | |
| 	device.peers.Lock()
 | |
| 	defer device.peers.Unlock()
 | |
| 
 | |
| 	for key, peer := range device.peers.keyMap {
 | |
| 		removePeerLocked(device, peer, key)
 | |
| 	}
 | |
| 
 | |
| 	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 | |
| }
 | |
| 
 | |
| func (device *Device) Close() {
 | |
| 	device.state.Lock()
 | |
| 	defer device.state.Unlock()
 | |
| 	if device.isClosed() {
 | |
| 		return
 | |
| 	}
 | |
| 	atomic.StoreUint32(&device.state.state, uint32(deviceStateClosed))
 | |
| 	device.log.Verbosef("Device closing")
 | |
| 
 | |
| 	device.tun.device.Close()
 | |
| 	device.downLocked()
 | |
| 
 | |
| 	// Remove peers before closing queues,
 | |
| 	// because peers assume that queues are active.
 | |
| 	device.RemoveAllPeers()
 | |
| 
 | |
| 	// We kept a reference to the encryption and decryption queues,
 | |
| 	// in case we started any new peers that might write to them.
 | |
| 	// No new peers are coming; we are done with these queues.
 | |
| 	device.queue.encryption.wg.Done()
 | |
| 	device.queue.decryption.wg.Done()
 | |
| 	device.queue.handshake.wg.Done()
 | |
| 	device.state.stopping.Wait()
 | |
| 
 | |
| 	device.rate.limiter.Close()
 | |
| 
 | |
| 	device.log.Verbosef("Device closed")
 | |
| 	close(device.closed)
 | |
| }
 | |
| 
 | |
| func (device *Device) Wait() chan struct{} {
 | |
| 	return device.closed
 | |
| }
 | |
| 
 | |
| func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
 | |
| 	if !device.isUp() {
 | |
| 		return
 | |
| 	}
 | |
| 
 | |
| 	device.peers.RLock()
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.keypairs.RLock()
 | |
| 		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
 | |
| 		peer.keypairs.RUnlock()
 | |
| 		if sendKeepalive {
 | |
| 			peer.SendKeepalive()
 | |
| 		}
 | |
| 	}
 | |
| 	device.peers.RUnlock()
 | |
| }
 | |
| 
 | |
| func unsafeCloseBind(device *Device) error {
 | |
| 	var err error
 | |
| 	netc := &device.net
 | |
| 	if netc.netlinkCancel != nil {
 | |
| 		netc.netlinkCancel.Cancel()
 | |
| 	}
 | |
| 	if netc.bind != nil {
 | |
| 		err = netc.bind.Close()
 | |
| 		netc.bind = nil
 | |
| 	}
 | |
| 	netc.stopping.Wait()
 | |
| 	return err
 | |
| }
 | |
| 
 | |
| func (device *Device) Bind() conn.Bind {
 | |
| 	device.net.Lock()
 | |
| 	defer device.net.Unlock()
 | |
| 	return device.net.bind
 | |
| }
 | |
| 
 | |
| func (device *Device) BindSetMark(mark uint32) error {
 | |
| 	device.net.Lock()
 | |
| 	defer device.net.Unlock()
 | |
| 
 | |
| 	// check if modified
 | |
| 	if device.net.fwmark == mark {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	// update fwmark on existing bind
 | |
| 	device.net.fwmark = mark
 | |
| 	if device.isUp() && device.net.bind != nil {
 | |
| 		if err := device.net.bind.SetMark(mark); err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// clear cached source addresses
 | |
| 	device.peers.RLock()
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.Lock()
 | |
| 		defer peer.Unlock()
 | |
| 		if peer.endpoint != nil {
 | |
| 			peer.endpoint.ClearSrc()
 | |
| 		}
 | |
| 	}
 | |
| 	device.peers.RUnlock()
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (device *Device) BindUpdate() error {
 | |
| 	device.net.Lock()
 | |
| 	defer device.net.Unlock()
 | |
| 
 | |
| 	// close existing sockets
 | |
| 	if err := unsafeCloseBind(device); err != nil {
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// open new sockets
 | |
| 	if !device.isUp() {
 | |
| 		return nil
 | |
| 	}
 | |
| 
 | |
| 	// bind to new port
 | |
| 	var err error
 | |
| 	netc := &device.net
 | |
| 	netc.bind, netc.port, err = conn.CreateBind(netc.port)
 | |
| 	if err != nil {
 | |
| 		netc.bind = nil
 | |
| 		netc.port = 0
 | |
| 		return err
 | |
| 	}
 | |
| 	netc.netlinkCancel, err = device.startRouteListener(netc.bind)
 | |
| 	if err != nil {
 | |
| 		netc.bind.Close()
 | |
| 		netc.bind = nil
 | |
| 		netc.port = 0
 | |
| 		return err
 | |
| 	}
 | |
| 
 | |
| 	// set fwmark
 | |
| 	if netc.fwmark != 0 {
 | |
| 		err = netc.bind.SetMark(netc.fwmark)
 | |
| 		if err != nil {
 | |
| 			return err
 | |
| 		}
 | |
| 	}
 | |
| 
 | |
| 	// clear cached source addresses
 | |
| 	device.peers.RLock()
 | |
| 	for _, peer := range device.peers.keyMap {
 | |
| 		peer.Lock()
 | |
| 		defer peer.Unlock()
 | |
| 		if peer.endpoint != nil {
 | |
| 			peer.endpoint.ClearSrc()
 | |
| 		}
 | |
| 	}
 | |
| 	device.peers.RUnlock()
 | |
| 
 | |
| 	// start receiving routines
 | |
| 	device.net.stopping.Add(2)
 | |
| 	device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
 | |
| 	device.queue.handshake.wg.Add(2)  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
 | |
| 	go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
 | |
| 	go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
 | |
| 
 | |
| 	device.log.Verbosef("UDP bind has been updated")
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (device *Device) BindClose() error {
 | |
| 	device.net.Lock()
 | |
| 	err := unsafeCloseBind(device)
 | |
| 	device.net.Unlock()
 | |
| 	return err
 | |
| }
 | 
