mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-09 10:30:10 +08:00
Improved readability of send/receive code
This commit is contained in:
257
src/receive.go
257
src/receive.go
@@ -128,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
// read next datagram
|
||||
|
||||
size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes
|
||||
size, raddr, err := conn.ReadFromUDP(buffer[:])
|
||||
|
||||
if err != nil {
|
||||
break
|
||||
@@ -222,7 +222,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
}
|
||||
|
||||
func (device *Device) RoutineDecryption() {
|
||||
var elem *QueueInboundElement
|
||||
|
||||
var nonce [chacha20poly1305.NonceSize]byte
|
||||
|
||||
logDebug := device.log.Debug
|
||||
@@ -230,50 +230,51 @@ func (device *Device) RoutineDecryption() {
|
||||
|
||||
for {
|
||||
select {
|
||||
case elem = <-device.queue.decryption:
|
||||
case <-device.signal.stop:
|
||||
logDebug.Println("Routine, decryption worker, stopped")
|
||||
return
|
||||
}
|
||||
|
||||
// check if dropped
|
||||
case elem := <-device.queue.decryption:
|
||||
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
// check if dropped
|
||||
|
||||
// split message into fields
|
||||
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
|
||||
// decrypt with key-pair
|
||||
|
||||
var err error
|
||||
copy(nonce[4:], counter)
|
||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||
elem.keyPair.receive.mutex.RLock()
|
||||
if elem.keyPair.receive.aead == nil {
|
||||
// very unlikely (the key was deleted during queuing)
|
||||
elem.Drop()
|
||||
} else {
|
||||
elem.packet, err = elem.keyPair.receive.aead.Open(
|
||||
elem.buffer[:0],
|
||||
nonce[:],
|
||||
content,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
elem.Drop()
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
|
||||
// split message into fields
|
||||
|
||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
||||
content := elem.packet[MessageTransportOffsetContent:]
|
||||
|
||||
// decrypt with key-pair
|
||||
|
||||
copy(nonce[4:], counter)
|
||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
||||
elem.keyPair.receive.mutex.RLock()
|
||||
if elem.keyPair.receive.aead == nil {
|
||||
// very unlikely (the key was deleted during queuing)
|
||||
elem.Drop()
|
||||
} else {
|
||||
var err error
|
||||
elem.packet, err = elem.keyPair.receive.aead.Open(
|
||||
elem.buffer[:0],
|
||||
nonce[:],
|
||||
content,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
elem.Drop()
|
||||
}
|
||||
}
|
||||
|
||||
elem.keyPair.receive.mutex.RUnlock()
|
||||
elem.mutex.Unlock()
|
||||
}
|
||||
elem.keyPair.receive.mutex.RUnlock()
|
||||
elem.mutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
/* Handles incomming packets related to handshake
|
||||
*
|
||||
*
|
||||
*/
|
||||
func (device *Device) RoutineHandshake() {
|
||||
|
||||
@@ -473,7 +474,6 @@ func (device *Device) RoutineHandshake() {
|
||||
}
|
||||
|
||||
func (peer *Peer) RoutineSequentialReceiver() {
|
||||
var elem *QueueInboundElement
|
||||
|
||||
device := peer.device
|
||||
|
||||
@@ -483,118 +483,119 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
|
||||
|
||||
for {
|
||||
// wait for decryption
|
||||
|
||||
select {
|
||||
case <-peer.signal.stop:
|
||||
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
|
||||
return
|
||||
case elem = <-peer.queue.inbound:
|
||||
}
|
||||
elem.mutex.Lock()
|
||||
|
||||
// process packet
|
||||
case elem := <-peer.queue.inbound:
|
||||
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
// wait for decryption
|
||||
|
||||
// check for replay
|
||||
|
||||
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
|
||||
continue
|
||||
}
|
||||
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
peer.TimerAnyAuthenticatedPacketReceived()
|
||||
peer.KeepKeyFreshReceiving()
|
||||
|
||||
// check if using new key-pair
|
||||
|
||||
kp := &peer.keyPairs
|
||||
kp.mutex.Lock()
|
||||
if kp.next == elem.keyPair {
|
||||
peer.TimerHandshakeComplete()
|
||||
if kp.previous != nil {
|
||||
device.DeleteKeyPair(kp.previous)
|
||||
}
|
||||
kp.previous = kp.current
|
||||
kp.current = kp.next
|
||||
kp.next = nil
|
||||
}
|
||||
kp.mutex.Unlock()
|
||||
|
||||
// check for keep-alive
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
logDebug.Println("Received keep-alive from", peer.String())
|
||||
continue
|
||||
}
|
||||
peer.TimerDataReceived()
|
||||
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
elem.mutex.Lock()
|
||||
if elem.IsDropped() {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
// check for replay
|
||||
|
||||
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
peer.TimerAnyAuthenticatedPacketTraversal()
|
||||
peer.TimerAnyAuthenticatedPacketReceived()
|
||||
peer.KeepKeyFreshReceiving()
|
||||
|
||||
// verify IPv4 source
|
||||
// check if using new key-pair
|
||||
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.routingTable.LookupIPv4(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
kp := &peer.keyPairs
|
||||
kp.mutex.Lock()
|
||||
if kp.next == elem.keyPair {
|
||||
peer.TimerHandshakeComplete()
|
||||
if kp.previous != nil {
|
||||
device.DeleteKeyPair(kp.previous)
|
||||
}
|
||||
kp.previous = kp.current
|
||||
kp.current = kp.next
|
||||
kp.next = nil
|
||||
}
|
||||
kp.mutex.Unlock()
|
||||
|
||||
// check for keep-alive
|
||||
|
||||
if len(elem.packet) == 0 {
|
||||
logDebug.Println("Received keep-alive from", peer.String())
|
||||
continue
|
||||
}
|
||||
peer.TimerDataReceived()
|
||||
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case ipv4.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv4 source
|
||||
|
||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
||||
if device.routingTable.LookupIPv4(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.routingTable.LookupIPv6(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
logInfo.Println("Packet with invalid IP version from", peer.String())
|
||||
continue
|
||||
}
|
||||
|
||||
case ipv6.Version:
|
||||
// write to tun
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < ipv6.HeaderLen {
|
||||
continue
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
_, err := device.tun.device.Write(elem.packet)
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
if err != nil {
|
||||
logError.Println("Failed to write packet to TUN device:", err)
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += ipv6.HeaderLen
|
||||
if int(length) > len(elem.packet) {
|
||||
continue
|
||||
}
|
||||
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||
if device.routingTable.LookupIPv6(src) != peer {
|
||||
logInfo.Println("Packet with unallowed source IP from", peer.String())
|
||||
continue
|
||||
}
|
||||
|
||||
default:
|
||||
logInfo.Println("Packet with invalid IP version from", peer.String())
|
||||
continue
|
||||
}
|
||||
|
||||
// write to tun
|
||||
|
||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||
_, err := device.tun.device.Write(elem.packet)
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
if err != nil {
|
||||
logError.Println("Failed to write packet to TUN device:", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user