mirror of
https://git.zx2c4.com/wireguard-go
synced 2025-10-07 01:22:54 +08:00
Added source verification
This commit is contained in:
137
src/receive.go
137
src/receive.go
@@ -72,12 +72,48 @@ func addToHandshakeQueue(
|
||||
}
|
||||
}
|
||||
|
||||
/* Routine determining the busy state of the interface
|
||||
*
|
||||
* TODO: prehaps nicer to do this in response to events
|
||||
* TODO: more well reasoned definition of "busy"
|
||||
*/
|
||||
func (device *Device) RoutineBusyMonitor() {
|
||||
samples := 0
|
||||
interval := time.Second
|
||||
for timer := time.NewTimer(interval); ; {
|
||||
|
||||
select {
|
||||
case <-device.signal.stop:
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
// compute busy heuristic
|
||||
|
||||
if len(device.queue.handshake) > QueueHandshakeBusySize {
|
||||
samples += 1
|
||||
} else if samples > 0 {
|
||||
samples -= 1
|
||||
}
|
||||
samples %= 30
|
||||
busy := samples > 5
|
||||
|
||||
// update busy state
|
||||
|
||||
if busy {
|
||||
atomic.StoreInt32(&device.congestionState, CongestionStateUnderLoad)
|
||||
} else {
|
||||
atomic.StoreInt32(&device.congestionState, CongestionStateOkay)
|
||||
}
|
||||
|
||||
timer.Reset(interval)
|
||||
}
|
||||
}
|
||||
|
||||
func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
debugLog := device.log.Debug
|
||||
debugLog.Println("Routine, receive incomming, started")
|
||||
|
||||
errorLog := device.log.Error
|
||||
logDebug := device.log.Debug
|
||||
logDebug.Println("Routine, receive incomming, started")
|
||||
|
||||
var buffer []byte
|
||||
|
||||
@@ -122,33 +158,6 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
case MessageInitiationType, MessageResponseType:
|
||||
|
||||
// verify mac1
|
||||
|
||||
if !device.mac.CheckMAC1(packet) {
|
||||
debugLog.Println("Received packet with invalid mac1")
|
||||
return
|
||||
}
|
||||
|
||||
// check if busy, TODO: refine definition of "busy"
|
||||
|
||||
busy := len(device.queue.handshake) > QueueHandshakeBusySize
|
||||
if busy && !device.mac.CheckMAC2(packet, raddr) {
|
||||
sender := binary.LittleEndian.Uint32(packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.CreateMessageCookieReply(packet, sender, raddr)
|
||||
if err != nil {
|
||||
errorLog.Println("Failed to create cookie reply:", err)
|
||||
return
|
||||
}
|
||||
writer := bytes.NewBuffer(packet[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
packet = writer.Bytes()
|
||||
_, err = device.net.conn.WriteToUDP(packet, raddr)
|
||||
if err != nil {
|
||||
debugLog.Println("Failed to send cookie reply:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// add to handshake queue
|
||||
|
||||
addToHandshakeQueue(
|
||||
@@ -173,7 +182,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
reader := bytes.NewReader(packet)
|
||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
||||
if err != nil {
|
||||
debugLog.Println("Failed to decode cookie reply")
|
||||
logDebug.Println("Failed to decode cookie reply")
|
||||
return
|
||||
}
|
||||
device.ConsumeMessageCookieReply(&reply)
|
||||
@@ -218,7 +227,7 @@ func (device *Device) RoutineReceiveIncomming() {
|
||||
|
||||
default:
|
||||
// unknown message type
|
||||
debugLog.Println("Got unknown message from:", raddr)
|
||||
logDebug.Println("Got unknown message from:", raddr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -285,6 +294,38 @@ func (device *Device) RoutineHandshake() {
|
||||
|
||||
func() {
|
||||
|
||||
// verify mac1
|
||||
|
||||
if !device.mac.CheckMAC1(elem.packet) {
|
||||
logDebug.Println("Received packet with invalid mac1")
|
||||
return
|
||||
}
|
||||
|
||||
// verify mac2
|
||||
|
||||
busy := atomic.LoadInt32(&device.congestionState) == CongestionStateUnderLoad
|
||||
|
||||
if busy && !device.mac.CheckMAC2(elem.packet, elem.source) {
|
||||
sender := binary.LittleEndian.Uint32(elem.packet[4:8]) // "sender" always follows "type"
|
||||
reply, err := device.CreateMessageCookieReply(elem.packet, sender, elem.source)
|
||||
if err != nil {
|
||||
logError.Println("Failed to create cookie reply:", err)
|
||||
return
|
||||
}
|
||||
writer := bytes.NewBuffer(elem.packet[:0])
|
||||
binary.Write(writer, binary.LittleEndian, reply)
|
||||
elem.packet = writer.Bytes()
|
||||
_, err = device.net.conn.WriteToUDP(elem.packet, elem.source)
|
||||
if err != nil {
|
||||
logDebug.Println("Failed to send cookie reply:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ratelimit
|
||||
|
||||
// handle messages
|
||||
|
||||
switch elem.msgType {
|
||||
case MessageInitiationType:
|
||||
|
||||
@@ -321,12 +362,12 @@ func (device *Device) RoutineHandshake() {
|
||||
logError.Println("Failed to create response message:", err)
|
||||
return
|
||||
}
|
||||
|
||||
outElem := device.NewOutboundElement()
|
||||
writer := bytes.NewBuffer(outElem.data[:0])
|
||||
binary.Write(writer, binary.LittleEndian, response)
|
||||
elem.packet = writer.Bytes()
|
||||
peer.mac.AddMacs(elem.packet)
|
||||
device.log.Debug.Println(elem.packet)
|
||||
addToOutboundQueue(peer.queue.outbound, outElem)
|
||||
|
||||
case MessageResponseType:
|
||||
@@ -388,7 +429,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
}
|
||||
elem.mutex.Lock()
|
||||
|
||||
// process IP packet
|
||||
// process packet
|
||||
|
||||
func() {
|
||||
if elem.IsDropped() {
|
||||
@@ -407,30 +448,54 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
||||
return
|
||||
}
|
||||
|
||||
// strip padding
|
||||
// verify source and strip padding
|
||||
|
||||
switch elem.packet[0] >> 4 {
|
||||
case IPv4version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < IPv4headerSize {
|
||||
return
|
||||
}
|
||||
|
||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv4 source
|
||||
|
||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
||||
if device.routingTable.LookupIPv4(dst) != peer {
|
||||
return
|
||||
}
|
||||
|
||||
case IPv6version:
|
||||
|
||||
// strip padding
|
||||
|
||||
if len(elem.packet) < IPv6headerSize {
|
||||
return
|
||||
}
|
||||
|
||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
||||
length := binary.BigEndian.Uint16(field)
|
||||
length += IPv6headerSize
|
||||
elem.packet = elem.packet[:length]
|
||||
|
||||
// verify IPv6 source
|
||||
|
||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
||||
if device.routingTable.LookupIPv6(dst) != peer {
|
||||
return
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Debug.Println("Receieved packet with unknown IP version")
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(&peer.rxBytes, uint64(len(elem.packet)))
|
||||
addToInboundQueue(device.queue.inbound, elem)
|
||||
}()
|
||||
}
|
||||
|
Reference in New Issue
Block a user