Improved readability of send/receive code

This commit is contained in:
Mathias Hall-Andersen
2017-09-09 15:03:01 +02:00
parent 89d0045214
commit f212795e51
2 changed files with 199 additions and 239 deletions

View File

@@ -128,7 +128,7 @@ func (device *Device) RoutineReceiveIncomming() {
// read next datagram
size, raddr, err := conn.ReadFromUDP(buffer[:]) // Blocks sometimes
size, raddr, err := conn.ReadFromUDP(buffer[:])
if err != nil {
break
@@ -222,7 +222,7 @@ func (device *Device) RoutineReceiveIncomming() {
}
func (device *Device) RoutineDecryption() {
var elem *QueueInboundElement
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
@@ -230,50 +230,51 @@ func (device *Device) RoutineDecryption() {
for {
select {
case elem = <-device.queue.decryption:
case <-device.signal.stop:
logDebug.Println("Routine, decryption worker, stopped")
return
}
// check if dropped
case elem := <-device.queue.decryption:
if elem.IsDropped() {
continue
}
// check if dropped
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt with key-pair
var err error
copy(nonce[4:], counter)
elem.counter = binary.LittleEndian.Uint64(counter)
elem.keyPair.receive.mutex.RLock()
if elem.keyPair.receive.aead == nil {
// very unlikely (the key was deleted during queuing)
elem.Drop()
} else {
elem.packet, err = elem.keyPair.receive.aead.Open(
elem.buffer[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.Drop()
if elem.IsDropped() {
continue
}
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// decrypt with key-pair
copy(nonce[4:], counter)
elem.counter = binary.LittleEndian.Uint64(counter)
elem.keyPair.receive.mutex.RLock()
if elem.keyPair.receive.aead == nil {
// very unlikely (the key was deleted during queuing)
elem.Drop()
} else {
var err error
elem.packet, err = elem.keyPair.receive.aead.Open(
elem.buffer[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.Drop()
}
}
elem.keyPair.receive.mutex.RUnlock()
elem.mutex.Unlock()
}
elem.keyPair.receive.mutex.RUnlock()
elem.mutex.Unlock()
}
}
/* Handles incomming packets related to handshake
*
*
*/
func (device *Device) RoutineHandshake() {
@@ -473,7 +474,6 @@ func (device *Device) RoutineHandshake() {
}
func (peer *Peer) RoutineSequentialReceiver() {
var elem *QueueInboundElement
device := peer.device
@@ -483,118 +483,119 @@ func (peer *Peer) RoutineSequentialReceiver() {
logDebug.Println("Routine, sequential receiver, started for peer", peer.id)
for {
// wait for decryption
select {
case <-peer.signal.stop:
logDebug.Println("Routine, sequential receiver, stopped for peer", peer.id)
return
case elem = <-peer.queue.inbound:
}
elem.mutex.Lock()
// process packet
case elem := <-peer.queue.inbound:
if elem.IsDropped() {
continue
}
// wait for decryption
// check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
continue
}
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving()
// check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.TimerHandshakeComplete()
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
}
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// check for keep-alive
if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String())
continue
}
peer.TimerDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
elem.mutex.Lock()
if elem.IsDropped() {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
// check for replay
if !elem.keyPair.replayFilter.ValidateCounter(elem.counter) {
continue
}
elem.packet = elem.packet[:length]
peer.TimerAnyAuthenticatedPacketTraversal()
peer.TimerAnyAuthenticatedPacketReceived()
peer.KeepKeyFreshReceiving()
// verify IPv4 source
// check if using new key-pair
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.TimerHandshakeComplete()
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
}
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// check for keep-alive
if len(elem.packet) == 0 {
logDebug.Println("Received keep-alive from", peer.String())
continue
}
peer.TimerDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.routingTable.LookupIPv4(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
continue
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer.String())
continue
}
case ipv6.Version:
// write to tun
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.routingTable.LookupIPv6(src) != peer {
logInfo.Println("Packet with unallowed source IP from", peer.String())
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer.String())
continue
}
// write to tun
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.packet)
device.PutMessageBuffer(elem.buffer)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
}
}
}