mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-07 09:31:07 +08:00
device: use new model queues for handshakes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -48,15 +48,6 @@ func (elem *QueueInboundElement) clearPointers() {
|
||||
elem.endpoint = nil
|
||||
}
|
||||
|
||||
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, elem QueueHandshakeElement) bool {
|
||||
select {
|
||||
case queue <- elem:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/* Called when a new authenticated message has been received
|
||||
*
|
||||
* NOTE: Not thread safe, but called by sequential receiver!
|
||||
@@ -81,6 +72,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
|
||||
device.queue.decryption.wg.Done()
|
||||
device.queue.handshake.wg.Done()
|
||||
device.net.stopping.Done()
|
||||
}()
|
||||
|
||||
@@ -202,16 +194,15 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||
}
|
||||
|
||||
if okay {
|
||||
if (device.addToHandshakeQueue(
|
||||
device.queue.handshake,
|
||||
QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
},
|
||||
)) {
|
||||
select {
|
||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
||||
msgType: msgType,
|
||||
buffer: buffer,
|
||||
packet: packet,
|
||||
endpoint: endpoint,
|
||||
}:
|
||||
buffer = device.GetMessageBuffer()
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -251,34 +242,13 @@ func (device *Device) RoutineDecryption() {
|
||||
/* Handles incoming packets related to handshake
|
||||
*/
|
||||
func (device *Device) RoutineHandshake() {
|
||||
var elem QueueHandshakeElement
|
||||
var ok bool
|
||||
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: handshake worker - stopped")
|
||||
device.state.stopping.Done()
|
||||
if elem.buffer != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}()
|
||||
|
||||
device.log.Verbosef("Routine: handshake worker - started")
|
||||
|
||||
for {
|
||||
if elem.buffer != nil {
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
elem.buffer = nil
|
||||
}
|
||||
|
||||
select {
|
||||
case elem, ok = <-device.queue.handshake:
|
||||
case <-device.signals.stop:
|
||||
return
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for elem := range device.queue.handshake.c {
|
||||
|
||||
// handle cookie fields and ratelimiting
|
||||
|
||||
@@ -293,7 +263,7 @@ func (device *Device) RoutineHandshake() {
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
if err != nil {
|
||||
device.log.Verbosef("Failed to decode cookie reply")
|
||||
return
|
||||
goto skip
|
||||
}
|
||||
|
||||
// lookup peer from index
|
||||
@@ -301,7 +271,7 @@ func (device *Device) RoutineHandshake() {
|
||||
entry := device.indexTable.Lookup(reply.Receiver)
|
||||
|
||||
if entry.peer == nil {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// consume reply
|
||||
@@ -313,7 +283,7 @@ func (device *Device) RoutineHandshake() {
|
||||
}
|
||||
}
|
||||
|
||||
continue
|
||||
goto skip
|
||||
|
||||
case MessageInitiationType, MessageResponseType:
|
||||
|
||||
@@ -321,7 +291,7 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
||||
device.log.Verbosef("Received packet with invalid mac1")
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// endpoints destination address is the source of the datagram
|
||||
@@ -332,19 +302,19 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
||||
device.SendHandshakeCookie(&elem)
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// check ratelimiter
|
||||
|
||||
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Errorf("Invalid packet ended up in the handshake queue")
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// handle handshake initiation/response content
|
||||
@@ -359,7 +329,7 @@ func (device *Device) RoutineHandshake() {
|
||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to decode initiation message")
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// consume initiation
|
||||
@@ -367,7 +337,7 @@ func (device *Device) RoutineHandshake() {
|
||||
peer := device.ConsumeMessageInitiation(&msg)
|
||||
if peer == nil {
|
||||
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// update timers
|
||||
@@ -392,7 +362,7 @@ func (device *Device) RoutineHandshake() {
|
||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
||||
if err != nil {
|
||||
device.log.Errorf("Failed to decode response message")
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// consume response
|
||||
@@ -400,7 +370,7 @@ func (device *Device) RoutineHandshake() {
|
||||
peer := device.ConsumeMessageResponse(&msg)
|
||||
if peer == nil {
|
||||
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
// update endpoint
|
||||
@@ -420,13 +390,15 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
if err != nil {
|
||||
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
||||
continue
|
||||
goto skip
|
||||
}
|
||||
|
||||
peer.timersSessionDerived()
|
||||
peer.timersHandshakeComplete()
|
||||
peer.SendKeepalive()
|
||||
}
|
||||
skip:
|
||||
device.PutMessageBuffer(elem.buffer)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user