conn, device, tun: implement vectorized I/O plumbing

Accept packet vectors for reading and writing in the tun.Device and
conn.Bind interfaces, so that the internal plumbing between these
interfaces now passes a vector of packets. Vectors move untouched
between these interfaces, i.e. if 128 packets are received from
conn.Bind.Read(), 128 packets are passed to tun.Device.Write(). There is
no internal buffering.

Currently, existing implementations are only adjusted to have vectors
of length one. Subsequent patches will improve that.

Also, as a related fixup, use the unix and windows packages rather than
the syscall package when possible.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited
2023-03-02 14:48:02 -08:00
committed by Jason A. Donenfeld
parent 21636207a6
commit 3bb8fec7e4
25 changed files with 1046 additions and 514 deletions

View File

@@ -66,7 +66,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() {
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
@@ -79,20 +79,33 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
// receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var (
buffsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
buffs = make([][]byte, maxBatchSize)
err error
size int
endpoint conn.Endpoint
sizes = make([]int, maxBatchSize)
count int
endpoints = make([]conn.Endpoint, maxBatchSize)
deathSpiral int
elemsByPeer = make(map[*Peer]*[]*QueueInboundElement, maxBatchSize)
)
for {
size, endpoint, err = recv(buffer[:])
for i := range buffsArrs {
buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
}
defer func() {
for i := 0; i < maxBatchSize; i++ {
if buffsArrs[i] != nil {
device.PutMessageBuffer(buffsArrs[i])
}
}
}()
for {
count, err = recv(buffs, sizes, endpoints)
if err != nil {
device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) {
return
}
@@ -103,101 +116,122 @@ func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
if deathSpiral < 10 {
deathSpiral++
time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue
}
return
}
deathSpiral = 0
if size < MinMessageSize {
continue
}
// check size of packet
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
// handle each packet in the batch
for i, size := range sizes[:count] {
if size < MinMessageSize {
continue
}
// lookup key pair
// check size of packet
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
packet := buffsArrs[i][:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffsArrs[i]
elem.keypair = keypair
elem.endpoint = endpoints[i]
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
elemsForPeer, ok := elemsByPeer[peer]
if !ok {
elemsForPeer = device.GetInboundElementsSlice()
elemsByPeer[peer] = elemsForPeer
}
*elemsForPeer = append(*elemsForPeer, elem)
buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
if len(packet) != MessageInitiationSize {
continue
}
case MessageResponseType:
if len(packet) != MessageResponseSize {
continue
}
case MessageCookieReplyType:
if len(packet) != MessageCookieReplySize {
continue
}
default:
device.log.Verbosef("Received message with unknown type")
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffer
elem.keypair = keypair
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
// add to decryption queues
if peer.isRunning.Load() {
peer.queue.inbound.c <- elem
device.queue.decryption.c <- elem
buffer = device.GetMessageBuffer()
} else {
device.PutInboundElement(elem)
}
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
default:
device.log.Verbosef("Received message with unknown type")
}
if okay {
select {
case device.queue.handshake.c <- QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
buffer: buffsArrs[i],
packet: packet,
endpoint: endpoint,
endpoint: endpoints[i],
}:
buffer = device.GetMessageBuffer()
buffsArrs[i] = device.GetMessageBuffer()
buffs[i] = buffsArrs[i][:]
default:
}
}
for peer, elems := range elemsByPeer {
if peer.isRunning.Load() {
peer.queue.inbound.c <- elems
for _, elem := range *elems {
device.queue.decryption.c <- elem
}
} else {
for _, elem := range *elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
device.PutInboundElementsSlice(elems)
}
delete(elemsByPeer, peer)
}
}
}
@@ -393,7 +427,7 @@ func (device *Device) RoutineHandshake(id int) {
}
}
func (peer *Peer) RoutineSequentialReceiver() {
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
device := peer.device
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
@@ -401,89 +435,91 @@ func (peer *Peer) RoutineSequentialReceiver() {
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
for elem := range peer.queue.inbound.c {
if elem == nil {
buffs := make([][]byte, 0, maxBatchSize)
for elems := range peer.queue.inbound.c {
if elems == nil {
return
}
var err error
elem.Lock()
if elem.packet == nil {
// decryption failed
goto skip
}
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
goto skip
}
peer.SetEndpointFromPacket(elem.endpoint)
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
goto skip
}
peer.timersDataReceived()
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
goto skip
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
goto skip
for _, elem := range *elems {
elem.Lock()
if elem.packet == nil {
// decryption failed
continue
}
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
goto skip
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
goto skip
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
goto skip
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
goto skip
}
peer.SetEndpointFromPacket(elem.endpoint)
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
peer.rxBytes.Add(uint64(len(elem.packet) + MinMessageSize))
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packet to TUN device: %v", err)
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
}
peer.timersDataReceived()
switch elem.packet[0] >> 4 {
case 4:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
}
case 6:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.Lookup(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
}
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
continue
}
buffs = append(buffs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
}
if len(peer.queue.inbound.c) == 0 {
err = device.tun.device.Flush()
if err != nil {
peer.device.log.Errorf("Unable to flush packets: %v", err)
if len(buffs) > 0 {
_, err := device.tun.device.Write(buffs, MessageTransportOffsetContent)
if err != nil && !device.isClosed() {
device.log.Errorf("Failed to write packets to TUN device: %v", err)
}
}
skip:
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
for _, elem := range *elems {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
buffs = buffs[:0]
device.PutInboundElementsSlice(elems)
}
}