From 03c0e5427ffe54d6ff7589f45003102d2db3952e Mon Sep 17 00:00:00 2001 From: impact-eintr Date: Fri, 15 Oct 2021 19:17:11 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BF=99=E4=B8=AA=E4=BB=93=E5=BA=93=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E9=87=8F=E5=A4=AA=E5=A4=A7=E4=BA=86=20=E4=BB=A5?= =?UTF-8?q?=E8=AF=BB=E4=B8=BA=E4=B8=BB=20=E4=B8=8D=E5=86=99=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 1 - go.mod | 4 +- netstack/ilist/list.go | 207 + netstack/rand/rand.go | 27 + netstack/sleep/commit_amd64.s | 35 + netstack/sleep/commit_asm.go | 20 + netstack/sleep/commit_noasm.go | 42 + netstack/sleep/empty.s | 15 + netstack/sleep/sleep_test.go | 542 +++ netstack/sleep/sleep_unsafe.go | 395 ++ netstack/tcpip/buffer/prependable.go | 64 + netstack/tcpip/buffer/view.go | 146 + {tcpip => netstack/tcpip}/buffer/view_test.go | 0 netstack/tcpip/checker/checker.go | 588 +++ netstack/tcpip/header/arp.go | 107 + netstack/tcpip/header/checksum.go | 57 + netstack/tcpip/header/eth.go | 72 + netstack/tcpip/header/gue.go | 73 + netstack/tcpip/header/icmpv4.go | 108 + netstack/tcpip/header/icmpv6.go | 121 + netstack/tcpip/header/interfaces.go | 92 + netstack/tcpip/header/ipv4.go | 289 ++ netstack/tcpip/header/ipv6.go | 236 ++ netstack/tcpip/header/ipv6_fragment.go | 146 + netstack/tcpip/header/ipversion_test.go | 67 + netstack/tcpip/header/tcp.go | 536 +++ netstack/tcpip/header/tcp_test.go | 148 + netstack/tcpip/header/udp.go | 117 + netstack/tcpip/link/channel/channel.go | 125 + netstack/tcpip/link/fdbased/endpoint.go | 305 ++ netstack/tcpip/link/fdbased/endpoint_test.go | 341 ++ netstack/tcpip/link/loopback/loopback.go | 87 + .../tcpip/link/rawfile/blockingpoll_unsafe.go | 27 + netstack/tcpip/link/rawfile/errors.go | 70 + netstack/tcpip/link/rawfile/rawfile_unsafe.go | 145 + netstack/tcpip/link/sniffer/pcap.go | 66 + netstack/tcpip/link/sniffer/sniffer.go | 390 ++ netstack/tcpip/link/tuntap/tuntap.go | 141 + 
netstack/tcpip/network/arp/arp.go | 196 + netstack/tcpip/network/arp/arp_test.go | 149 + .../tcpip/network/fragmentation/frag_heap.go | 77 + .../network/fragmentation/frag_heap_test.go | 126 + .../network/fragmentation/fragmentation.go | 134 + .../fragmentation/fragmentation_test.go | 159 + .../network/fragmentation/reassembler.go | 118 + .../network/fragmentation/reassembler_list.go | 173 + .../network/fragmentation/reassembler_test.go | 105 + netstack/tcpip/network/hash/hash.go | 94 + netstack/tcpip/network/ip_test.go | 604 +++ netstack/tcpip/network/ipv4/icmp.go | 134 + netstack/tcpip/network/ipv4/ipv4.go | 280 ++ netstack/tcpip/network/ipv4/ipv4_test.go | 92 + netstack/tcpip/network/ipv6/icmp.go | 231 ++ netstack/tcpip/network/ipv6/icmp_test.go | 226 ++ netstack/tcpip/network/ipv6/ipv6.go | 187 + netstack/tcpip/ports/ports.go | 185 + netstack/tcpip/ports/ports_test.go | 145 + netstack/tcpip/seqnum/seqnum.go | 67 + netstack/tcpip/stack/linkaddrcache.go | 308 ++ netstack/tcpip/stack/linkaddrcache_test.go | 266 ++ netstack/tcpip/stack/nic.go | 668 +++ netstack/tcpip/stack/registration.go | 373 ++ netstack/tcpip/stack/route.go | 186 + netstack/tcpip/stack/stack.go | 980 +++++ netstack/tcpip/stack/stack_test.go | 849 ++++ netstack/tcpip/stack/transport_demuxer.go | 192 + netstack/tcpip/stack/transport_test.go | 422 ++ netstack/tcpip/tcpip.go | 834 ++++ netstack/tcpip/tcpip_test.go | 195 + netstack/tcpip/time.s | 15 + netstack/tcpip/time_unsafe.go | 42 + netstack/tcpip/transport/ping/endpoint.go | 717 ++++ .../tcpip/transport/ping/ping_packet_list.go | 173 + netstack/tcpip/transport/ping/protocol.go | 124 + netstack/tcpip/transport/tcp/accept.go | 443 ++ netstack/tcpip/transport/tcp/connect.go | 1158 ++++++ netstack/tcpip/transport/tcp/cubic.go | 255 ++ .../tcpip/transport/tcp/dual_stack_test.go | 560 +++ netstack/tcpip/transport/tcp/endpoint.go | 1637 ++++++++ netstack/tcpip/transport/tcp/forwarder.go | 174 + netstack/tcpip/transport/tcp/protocol.go | 241 ++ 
netstack/tcpip/transport/tcp/rcv.go | 251 ++ netstack/tcpip/transport/tcp/reno.go | 125 + netstack/tcpip/transport/tcp/sack.go | 103 + netstack/tcpip/transport/tcp/segment.go | 184 + netstack/tcpip/transport/tcp/segment_heap.go | 48 + netstack/tcpip/transport/tcp/segment_queue.go | 81 + netstack/tcpip/transport/tcp/snd.go | 797 ++++ netstack/tcpip/transport/tcp/tcp_sack_test.go | 350 ++ .../tcpip/transport/tcp/tcp_segment_list.go | 173 + netstack/tcpip/transport/tcp/tcp_test.go | 3565 +++++++++++++++++ .../tcpip/transport/tcp/tcp_timestamp_test.go | 316 ++ .../transport/tcp/testing/context/context.go | 960 +++++ netstack/tcpip/transport/tcp/timer.go | 142 + .../transport/tcpconntrack/tcp_conntrack.go | 348 ++ .../tcpconntrack/tcp_conntrack_test.go | 511 +++ netstack/tcpip/transport/udp/endpoint.go | 980 +++++ netstack/tcpip/transport/udp/protocol.go | 85 + .../tcpip/transport/udp/udp_packet_list.go | 174 + netstack/tcpip/transport/udp/udp_test.go | 928 +++++ netstack/tmutex/tmutex.go | 81 + netstack/tmutex/tmutex_test.go | 257 ++ netstack/waiter/waiter.go | 240 ++ netstack/waiter/waiter_test.go | 192 + sleep/sleep_unsafe.go | 4 - .../tcpip_bak}/buffer/prependable.go | 0 {tcpip => src_code/tcpip_bak}/buffer/view.go | 0 src_code/tcpip_bak/buffer/view_test.go | 235 ++ {tcpip => src_code/tcpip_bak}/header/eth.go | 0 {tcpip => src_code/tcpip_bak}/ilist/list.go | 0 .../tcpip_bak}/link/fdbased/endpoint.go | 0 .../link/rawfile/blockingpoll_unsafe.go | 0 .../tcpip_bak}/link/rawfile/errors.go | 0 .../tcpip_bak}/link/rawfile/rawfile_unsafe.go | 0 .../tcpip_bak}/link/tap1/main.go | 0 .../tcpip_bak}/link/tap2/main.go | 0 .../tcpip_bak}/link/tuntap/tuntap.go | 0 {tcpip => src_code/tcpip_bak}/ports/port.go | 0 .../tcpip_bak}/seqnum/seqnum.go | 0 .../tcpip_bak}/stack/linkaddresscache.go | 0 {tcpip => src_code/tcpip_bak}/stack/nic.go | 0 .../tcpip_bak}/stack/registration.go | 0 {tcpip => src_code/tcpip_bak}/stack/route.go | 0 {tcpip => src_code/tcpip_bak}/stack/stack.go | 0 
.../tcpip_bak}/stack/transport_demuxer.go | 0 {tcpip => src_code/tcpip_bak}/tcpip.go | 0 126 files changed, 31369 insertions(+), 7 deletions(-) create mode 100644 netstack/ilist/list.go create mode 100644 netstack/rand/rand.go create mode 100644 netstack/sleep/commit_amd64.s create mode 100644 netstack/sleep/commit_asm.go create mode 100644 netstack/sleep/commit_noasm.go create mode 100644 netstack/sleep/empty.s create mode 100644 netstack/sleep/sleep_test.go create mode 100644 netstack/sleep/sleep_unsafe.go create mode 100644 netstack/tcpip/buffer/prependable.go create mode 100644 netstack/tcpip/buffer/view.go rename {tcpip => netstack/tcpip}/buffer/view_test.go (100%) create mode 100644 netstack/tcpip/checker/checker.go create mode 100644 netstack/tcpip/header/arp.go create mode 100644 netstack/tcpip/header/checksum.go create mode 100644 netstack/tcpip/header/eth.go create mode 100644 netstack/tcpip/header/gue.go create mode 100644 netstack/tcpip/header/icmpv4.go create mode 100644 netstack/tcpip/header/icmpv6.go create mode 100644 netstack/tcpip/header/interfaces.go create mode 100644 netstack/tcpip/header/ipv4.go create mode 100644 netstack/tcpip/header/ipv6.go create mode 100644 netstack/tcpip/header/ipv6_fragment.go create mode 100644 netstack/tcpip/header/ipversion_test.go create mode 100644 netstack/tcpip/header/tcp.go create mode 100644 netstack/tcpip/header/tcp_test.go create mode 100644 netstack/tcpip/header/udp.go create mode 100644 netstack/tcpip/link/channel/channel.go create mode 100644 netstack/tcpip/link/fdbased/endpoint.go create mode 100644 netstack/tcpip/link/fdbased/endpoint_test.go create mode 100644 netstack/tcpip/link/loopback/loopback.go create mode 100644 netstack/tcpip/link/rawfile/blockingpoll_unsafe.go create mode 100644 netstack/tcpip/link/rawfile/errors.go create mode 100644 netstack/tcpip/link/rawfile/rawfile_unsafe.go create mode 100644 netstack/tcpip/link/sniffer/pcap.go create mode 100644 netstack/tcpip/link/sniffer/sniffer.go 
create mode 100644 netstack/tcpip/link/tuntap/tuntap.go create mode 100644 netstack/tcpip/network/arp/arp.go create mode 100644 netstack/tcpip/network/arp/arp_test.go create mode 100644 netstack/tcpip/network/fragmentation/frag_heap.go create mode 100644 netstack/tcpip/network/fragmentation/frag_heap_test.go create mode 100644 netstack/tcpip/network/fragmentation/fragmentation.go create mode 100644 netstack/tcpip/network/fragmentation/fragmentation_test.go create mode 100644 netstack/tcpip/network/fragmentation/reassembler.go create mode 100644 netstack/tcpip/network/fragmentation/reassembler_list.go create mode 100644 netstack/tcpip/network/fragmentation/reassembler_test.go create mode 100644 netstack/tcpip/network/hash/hash.go create mode 100644 netstack/tcpip/network/ip_test.go create mode 100644 netstack/tcpip/network/ipv4/icmp.go create mode 100644 netstack/tcpip/network/ipv4/ipv4.go create mode 100644 netstack/tcpip/network/ipv4/ipv4_test.go create mode 100644 netstack/tcpip/network/ipv6/icmp.go create mode 100644 netstack/tcpip/network/ipv6/icmp_test.go create mode 100644 netstack/tcpip/network/ipv6/ipv6.go create mode 100644 netstack/tcpip/ports/ports.go create mode 100644 netstack/tcpip/ports/ports_test.go create mode 100644 netstack/tcpip/seqnum/seqnum.go create mode 100644 netstack/tcpip/stack/linkaddrcache.go create mode 100644 netstack/tcpip/stack/linkaddrcache_test.go create mode 100644 netstack/tcpip/stack/nic.go create mode 100644 netstack/tcpip/stack/registration.go create mode 100644 netstack/tcpip/stack/route.go create mode 100644 netstack/tcpip/stack/stack.go create mode 100644 netstack/tcpip/stack/stack_test.go create mode 100644 netstack/tcpip/stack/transport_demuxer.go create mode 100644 netstack/tcpip/stack/transport_test.go create mode 100644 netstack/tcpip/tcpip.go create mode 100644 netstack/tcpip/tcpip_test.go create mode 100644 netstack/tcpip/time.s create mode 100644 netstack/tcpip/time_unsafe.go create mode 100644 
netstack/tcpip/transport/ping/endpoint.go create mode 100644 netstack/tcpip/transport/ping/ping_packet_list.go create mode 100644 netstack/tcpip/transport/ping/protocol.go create mode 100644 netstack/tcpip/transport/tcp/accept.go create mode 100644 netstack/tcpip/transport/tcp/connect.go create mode 100644 netstack/tcpip/transport/tcp/cubic.go create mode 100644 netstack/tcpip/transport/tcp/dual_stack_test.go create mode 100644 netstack/tcpip/transport/tcp/endpoint.go create mode 100644 netstack/tcpip/transport/tcp/forwarder.go create mode 100644 netstack/tcpip/transport/tcp/protocol.go create mode 100644 netstack/tcpip/transport/tcp/rcv.go create mode 100644 netstack/tcpip/transport/tcp/reno.go create mode 100644 netstack/tcpip/transport/tcp/sack.go create mode 100644 netstack/tcpip/transport/tcp/segment.go create mode 100644 netstack/tcpip/transport/tcp/segment_heap.go create mode 100644 netstack/tcpip/transport/tcp/segment_queue.go create mode 100644 netstack/tcpip/transport/tcp/snd.go create mode 100644 netstack/tcpip/transport/tcp/tcp_sack_test.go create mode 100644 netstack/tcpip/transport/tcp/tcp_segment_list.go create mode 100644 netstack/tcpip/transport/tcp/tcp_test.go create mode 100644 netstack/tcpip/transport/tcp/tcp_timestamp_test.go create mode 100644 netstack/tcpip/transport/tcp/testing/context/context.go create mode 100644 netstack/tcpip/transport/tcp/timer.go create mode 100644 netstack/tcpip/transport/tcpconntrack/tcp_conntrack.go create mode 100644 netstack/tcpip/transport/tcpconntrack/tcp_conntrack_test.go create mode 100644 netstack/tcpip/transport/udp/endpoint.go create mode 100644 netstack/tcpip/transport/udp/protocol.go create mode 100644 netstack/tcpip/transport/udp/udp_packet_list.go create mode 100644 netstack/tcpip/transport/udp/udp_test.go create mode 100644 netstack/tmutex/tmutex.go create mode 100644 netstack/tmutex/tmutex_test.go create mode 100644 netstack/waiter/waiter.go create mode 100644 netstack/waiter/waiter_test.go delete 
mode 100644 sleep/sleep_unsafe.go rename {tcpip => src_code/tcpip_bak}/buffer/prependable.go (100%) rename {tcpip => src_code/tcpip_bak}/buffer/view.go (100%) create mode 100644 src_code/tcpip_bak/buffer/view_test.go rename {tcpip => src_code/tcpip_bak}/header/eth.go (100%) rename {tcpip => src_code/tcpip_bak}/ilist/list.go (100%) rename {tcpip => src_code/tcpip_bak}/link/fdbased/endpoint.go (100%) rename {tcpip => src_code/tcpip_bak}/link/rawfile/blockingpoll_unsafe.go (100%) rename {tcpip => src_code/tcpip_bak}/link/rawfile/errors.go (100%) rename {tcpip => src_code/tcpip_bak}/link/rawfile/rawfile_unsafe.go (100%) rename {tcpip => src_code/tcpip_bak}/link/tap1/main.go (100%) rename {tcpip => src_code/tcpip_bak}/link/tap2/main.go (100%) rename {tcpip => src_code/tcpip_bak}/link/tuntap/tuntap.go (100%) rename {tcpip => src_code/tcpip_bak}/ports/port.go (100%) rename {tcpip => src_code/tcpip_bak}/seqnum/seqnum.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/linkaddresscache.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/nic.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/registration.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/route.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/stack.go (100%) rename {tcpip => src_code/tcpip_bak}/stack/transport_demuxer.go (100%) rename {tcpip => src_code/tcpip_bak}/tcpip.go (100%) diff --git a/README.md b/README.md index d3f99fc..aa2c950 100644 --- a/README.md +++ b/README.md @@ -283,7 +283,6 @@ tap0 Link encap:Ethernet HWaddr 22:e2:f2:93:ff:bf ``` - 删除网卡可以使用如下命令: diff --git a/go.mod b/go.mod index 0670dc1..1833b1f 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/impact-eintr/netstack +module tcpip -go 1.16 +go 1.14 diff --git a/netstack/ilist/list.go b/netstack/ilist/list.go new file mode 100644 index 0000000..392c63e --- /dev/null +++ b/netstack/ilist/list.go @@ -0,0 +1,207 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); 
+// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ilist provides the implementation of intrusive linked lists. +package ilist + +// Linker is the interface that objects must implement if they want to be added +// to and/or removed from List objects. +// +// N.B. When substituted in a template instantiation, Linker doesn't need to +// be an interface, and in most cases won't be. +type Linker interface { + Next() Element + Prev() Element + SetNext(Element) + SetPrev(Element) +} + +// Element the item that is used at the API level. +// +// N.B. Like Linker, this is unlikely to be an interface in most cases. +type Element interface { + Linker +} + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type ElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (ElementMapper) linkerFor(elem Element) Linker { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. 
+// } +// +// +stateify savable +type List struct { + head Element + tail Element +} + +// Reset resets list l to the empty state. +func (l *List) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *List) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *List) Front() Element { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *List) Back() Element { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *List) PushFront(e Element) { + ElementMapper{}.linkerFor(e).SetNext(l.head) + ElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + ElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *List) PushBack(e Element) { + ElementMapper{}.linkerFor(e).SetNext(nil) + ElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + ElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *List) PushBackList(m *List) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + ElementMapper{}.linkerFor(l.tail).SetNext(m.head) + ElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *List) InsertAfter(b, e Element) { + a := ElementMapper{}.linkerFor(b).Next() + ElementMapper{}.linkerFor(e).SetNext(a) + ElementMapper{}.linkerFor(e).SetPrev(b) + ElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + ElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. 
+func (l *List) InsertBefore(a, e Element) { + b := ElementMapper{}.linkerFor(a).Prev() + ElementMapper{}.linkerFor(e).SetNext(a) + ElementMapper{}.linkerFor(e).SetPrev(b) + ElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + ElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *List) Remove(e Element) { + prev := ElementMapper{}.linkerFor(e).Prev() + next := ElementMapper{}.linkerFor(e).Next() + + if prev != nil { + ElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + ElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type Entry struct { + next Element + prev Element +} + +// Next returns the entry that follows e in the list. +func (e *Entry) Next() Element { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *Entry) Prev() Element { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *Entry) SetNext(elem Element) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *Entry) SetPrev(elem Element) { + e.prev = elem +} diff --git a/netstack/rand/rand.go b/netstack/rand/rand.go new file mode 100644 index 0000000..d9ef4cc --- /dev/null +++ b/netstack/rand/rand.go @@ -0,0 +1,27 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package rand implements a cryptographically secure pseudorandom number +// generator. +package rand + +import "crypto/rand" + +// Reader is the default reader. +var Reader = rand.Reader + +// Read implements io.Reader.Read. +func Read(b []byte) (int, error) { + return rand.Read(b) +} diff --git a/netstack/sleep/commit_amd64.s b/netstack/sleep/commit_amd64.s new file mode 100644 index 0000000..8585ac6 --- /dev/null +++ b/netstack/sleep/commit_amd64.s @@ -0,0 +1,35 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "textflag.h" + +#define preparingG 1 + +// See commit_noasm.go for a description of commitSleep. +// +// func commitSleep(g uintptr, waitingG *uintptr) bool +TEXT ·commitSleep(SB),NOSPLIT,$0-24 + MOVQ waitingG+8(FP), CX + MOVQ g+0(FP), DX + + // Store the G in waitingG if it's still preparingG. If it's anything + // else it means a waker has aborted the sleep. 
+ MOVQ $preparingG, AX + LOCK + CMPXCHGQ DX, 0(CX) + + SETEQ AX + MOVB AX, ret+16(FP) + + RET diff --git a/netstack/sleep/commit_asm.go b/netstack/sleep/commit_asm.go new file mode 100644 index 0000000..adc81d5 --- /dev/null +++ b/netstack/sleep/commit_asm.go @@ -0,0 +1,20 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package sleep + +// See commit_noasm.go for a description of commitSleep. +func commitSleep(g uintptr, waitingG *uintptr) bool diff --git a/netstack/sleep/commit_noasm.go b/netstack/sleep/commit_noasm.go new file mode 100644 index 0000000..985c580 --- /dev/null +++ b/netstack/sleep/commit_noasm.go @@ -0,0 +1,42 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build !race +// +build !amd64 + +package sleep + +import "sync/atomic" + +// commitSleep signals to wakers that the given g is now sleeping. Wakers can +// then fetch it and wake it. 
+// +// The commit may fail if wakers have been asserted after our last check, in +// which case they will have set s.waitingG to zero. +// +// It is written in assembly because it is called from g0, so it doesn't have +// a race context. +func commitSleep(g uintptr, waitingG *uintptr) bool { + for { + // Check if the wait was aborted. + if atomic.LoadUintptr(waitingG) == 0 { + return false + } + + // Try to store the G so that wakers know who to wake. + if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) { + return true + } + } +} diff --git a/netstack/sleep/empty.s b/netstack/sleep/empty.s new file mode 100644 index 0000000..e07189f --- /dev/null +++ b/netstack/sleep/empty.s @@ -0,0 +1,15 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Empty assembly file so empty func definitions work. diff --git a/netstack/sleep/sleep_test.go b/netstack/sleep/sleep_test.go new file mode 100644 index 0000000..e4e4093 --- /dev/null +++ b/netstack/sleep/sleep_test.go @@ -0,0 +1,542 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sleep + +import ( + "math/rand" + "runtime" + "testing" + "time" +) + +// ZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state. +func ZeroWakerNotAsserted(t *testing.T) { + var w Waker + if w.IsAsserted() { + t.Fatalf("Zero waker is asserted") + } + + if w.Clear() { + t.Fatalf("Zero waker is asserted") + } +} + +// AssertedWakerAfterAssert tests that a waker properly reports its state as +// asserted once its Assert() method is called. +func AssertedWakerAfterAssert(t *testing.T) { + var w Waker + w.Assert() + if !w.IsAsserted() { + t.Fatalf("Asserted waker is not reported as such") + } + + if !w.Clear() { + t.Fatalf("Asserted waker is not reported as such") + } +} + +// AssertedWakerAfterTwoAsserts tests that a waker properly reports its state as +// asserted once its Assert() method is called twice. +func AssertedWakerAfterTwoAsserts(t *testing.T) { + var w Waker + w.Assert() + w.Assert() + if !w.IsAsserted() { + t.Fatalf("Asserted waker is not reported as such") + } + + if !w.Clear() { + t.Fatalf("Asserted waker is not reported as such") + } +} + +// NotAssertedWakerWithSleeper tests that a waker properly reports its state as +// not asserted after a sleeper is associated with it. 
+func NotAssertedWakerWithSleeper(t *testing.T) { + var w Waker + var s Sleeper + s.AddWaker(&w, 0) + if w.IsAsserted() { + t.Fatalf("Non-asserted waker is reported as asserted") + } + + if w.Clear() { + t.Fatalf("Non-asserted waker is reported as asserted") + } +} + +// NotAssertedWakerAfterWake tests that a waker properly reports its state as +// not asserted after a previous assert is consumed by a sleeper. That is, tests +// the "edge-triggered" behavior. +func NotAssertedWakerAfterWake(t *testing.T) { + var w Waker + var s Sleeper + s.AddWaker(&w, 0) + w.Assert() + s.Fetch(true) + if w.IsAsserted() { + t.Fatalf("Consumed waker is reported as asserted") + } + + if w.Clear() { + t.Fatalf("Consumed waker is reported as asserted") + } +} + +// AssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep if +// it's already asserted before being added. +func AssertedWakerBeforeAdd(t *testing.T) { + var w Waker + var s Sleeper + w.Assert() + s.AddWaker(&w, 0) + + if _, ok := s.Fetch(false); !ok { + t.Fatalf("Fetch failed even though asserted waker was added") + } +} + +// ClearedWaker tests that a waker properly reports its state as not asserted +// after it is cleared. +func ClearedWaker(t *testing.T) { + var w Waker + w.Assert() + w.Clear() + if w.IsAsserted() { + t.Fatalf("Cleared waker is reported as asserted") + } + + if w.Clear() { + t.Fatalf("Cleared waker is reported as asserted") + } +} + +// ClearedWakerWithSleeper tests that a waker properly reports its state as +// not asserted when it is cleared while it has a sleeper associated with it. 
+func ClearedWakerWithSleeper(t *testing.T) { + var w Waker + var s Sleeper + s.AddWaker(&w, 0) + w.Clear() + if w.IsAsserted() { + t.Fatalf("Cleared waker is reported as asserted") + } + + if w.Clear() { + t.Fatalf("Cleared waker is reported as asserted") + } +} + +// ClearedWakerAssertedWithSleeper tests that a waker properly reports its state +// as not asserted when it is cleared while it has a sleeper associated with it +// and has been asserted. +func ClearedWakerAssertedWithSleeper(t *testing.T) { + var w Waker + var s Sleeper + s.AddWaker(&w, 0) + w.Assert() + w.Clear() + if w.IsAsserted() { + t.Fatalf("Cleared waker is reported as asserted") + } + + if w.Clear() { + t.Fatalf("Cleared waker is reported as asserted") + } +} + +// TestBlock tests that a sleeper actually blocks waiting for the waker to +// assert its state. +func TestBlock(t *testing.T) { + var w Waker + var s Sleeper + + s.AddWaker(&w, 0) + + // Assert waker after one second. + before := time.Now() + go func() { + time.Sleep(1 * time.Second) + w.Assert() + }() + + // Fetch the result and make sure it took at least 500ms. + if _, ok := s.Fetch(true); !ok { + t.Fatalf("Fetch failed unexpectedly") + } + if d := time.Now().Sub(before); d < 500*time.Millisecond { + t.Fatalf("Duration was too short: %v", d) + } + + // Check that already-asserted waker completes inline. + w.Assert() + if _, ok := s.Fetch(true); !ok { + t.Fatalf("Fetch failed unexpectedly") + } + + // Check that fetch sleeps if waker had been asserted but was reset + // before Fetch is called. + w.Assert() + w.Clear() + before = time.Now() + go func() { + time.Sleep(1 * time.Second) + w.Assert() + }() + if _, ok := s.Fetch(true); !ok { + t.Fatalf("Fetch failed unexpectedly") + } + if d := time.Now().Sub(before); d < 500*time.Millisecond { + t.Fatalf("Duration was too short: %v", d) + } +} + +// TestNonBlock checks that a sleeper won't block if waker isn't asserted. 
+func TestNonBlock(t *testing.T) { + var w Waker + var s Sleeper + + // Don't block when there's no waker. + if _, ok := s.Fetch(false); ok { + t.Fatalf("Fetch succeeded when there is no waker") + } + + // Don't block when waker isn't asserted. + s.AddWaker(&w, 0) + if _, ok := s.Fetch(false); ok { + t.Fatalf("Fetch succeeded when waker was not asserted") + } + + // Don't block when waker was asserted, but isn't anymore. + w.Assert() + w.Clear() + if _, ok := s.Fetch(false); ok { + t.Fatalf("Fetch succeeded when waker was not asserted anymore") + } + + // Don't block when waker was consumed by previous Fetch(). + w.Assert() + if _, ok := s.Fetch(false); !ok { + t.Fatalf("Fetch failed even though waker was asserted") + } + + if _, ok := s.Fetch(false); ok { + t.Fatalf("Fetch succeeded when waker had been consumed") + } +} + +// TestMultiple checks that a sleeper can wait for and receives notifications +// from multiple wakers. +func TestMultiple(t *testing.T) { + s := Sleeper{} + w1 := Waker{} + w2 := Waker{} + + s.AddWaker(&w1, 0) + s.AddWaker(&w2, 1) + + w1.Assert() + w2.Assert() + + v, ok := s.Fetch(false) + if !ok { + t.Fatalf("Fetch failed when there are asserted wakers") + } + + if v != 0 && v != 1 { + t.Fatalf("Unexpected waker id: %v", v) + } + + want := 1 - v + v, ok = s.Fetch(false) + if !ok { + t.Fatalf("Fetch failed when there is an asserted waker") + } + + if v != want { + t.Fatalf("Unexpected waker id, got %v, want %v", v, want) + } +} + +// TestDoneFunction tests if calling Done() on a sleeper works properly. +func TestDoneFunction(t *testing.T) { + // Trivial case of no waker. + s := Sleeper{} + s.Done() + + // Cases when the sleeper has n wakers, but none are asserted. + for n := 1; n < 20; n++ { + s := Sleeper{} + w := make([]Waker, n) + for j := 0; j < n; j++ { + s.AddWaker(&w[j], j) + } + s.Done() + } + + // Cases when the sleeper has n wakers, and only the i-th one is + // asserted. 
+ for n := 1; n < 20; n++ { + for i := 0; i < n; i++ { + s := Sleeper{} + w := make([]Waker, n) + for j := 0; j < n; j++ { + s.AddWaker(&w[j], j) + } + w[i].Assert() + s.Done() + } + } + + // Cases when the sleeper has n wakers, and the i-th one is asserted + // and cleared. + for n := 1; n < 20; n++ { + for i := 0; i < n; i++ { + s := Sleeper{} + w := make([]Waker, n) + for j := 0; j < n; j++ { + s.AddWaker(&w[j], j) + } + w[i].Assert() + w[i].Clear() + s.Done() + } + } + + // Cases when the sleeper has n wakers, with a random number of them + // asserted. + for n := 1; n < 20; n++ { + for iters := 0; iters < 1000; iters++ { + s := Sleeper{} + w := make([]Waker, n) + for j := 0; j < n; j++ { + s.AddWaker(&w[j], j) + } + + // Pick the number of asserted elements, then assert + // random wakers. + asserted := rand.Int() % (n + 1) + for j := 0; j < asserted; j++ { + w[rand.Int()%n].Assert() + } + s.Done() + } + } +} + +// TestRace tests that multiple wakers can continuously send wake requests to +// the sleeper. +func TestRace(t *testing.T) { + const wakers = 100 + const wakeRequests = 10000 + + counts := make([]int, wakers) + w := make([]Waker, wakers) + s := Sleeper{} + + // Associate each waker and start goroutines that will assert them. + for i := range w { + s.AddWaker(&w[i], i) + go func(w *Waker) { + n := 0 + for n < wakeRequests { + if !w.IsAsserted() { + w.Assert() + n++ + } else { + runtime.Gosched() + } + } + }(&w[i]) + } + + // Wait for all wake up notifications from all wakers. + for i := 0; i < wakers*wakeRequests; i++ { + v, _ := s.Fetch(true) + counts[v]++ + } + + // Check that we got the right number for each. + for i, v := range counts { + if v != wakeRequests { + t.Errorf("Waker %v only got %v wakes", i, v) + } + } +} + +// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up +// from 4 wakers when at least one is already asserted. 
+func BenchmarkSleeperMultiSelect(b *testing.B) { + const count = 4 + s := Sleeper{} + w := make([]Waker, count) + for i := range w { + s.AddWaker(&w[i], i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w[count-1].Assert() + s.Fetch(true) + } +} + +// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length +// struct from one of 4 channels when at least one is ready. +func BenchmarkGoMultiSelect(b *testing.B) { + const count = 4 + ch := make([]chan struct{}, count) + for i := range ch { + ch[i] = make(chan struct{}, 1) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ch[count-1] <- struct{}{} + select { + case <-ch[0]: + case <-ch[1]: + case <-ch[2]: + case <-ch[3]: + } + } +} + +// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up +// from one waker that is already asserted. +func BenchmarkSleeperSingleSelect(b *testing.B) { + s := Sleeper{} + w := Waker{} + s.AddWaker(&w, 0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + w.Assert() + s.Fetch(true) + } +} + +// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length +// struct from a channel that already has it buffered. +func BenchmarkGoSingleSelect(b *testing.B) { + ch := make(chan struct{}, 1) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + ch <- struct{}{} + <-ch + } +} + +// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a +// channel that is already asserted. +func BenchmarkSleeperAssertNonWaiting(b *testing.B) { + w := Waker{} + w.Assert() + for i := 0; i < b.N; i++ { + w.Assert() + } + +} + +// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel +// that has already something written to it. 
+func BenchmarkGoAssertNonWaiting(b *testing.B) { + ch := make(chan struct{}, 1) + ch <- struct{}{} + for i := 0; i < b.N; i++ { + select { + case ch <- struct{}{}: + default: + } + } +} + +// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one +// waker channel while another goroutine wakes up the sleeper. This assumes that +// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine +// is allowed to go to sleep before the new goroutine has a chance to run). +func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) { + s := Sleeper{} + w := Waker{} + s.AddWaker(&w, 0) + for i := 0; i < b.N; i++ { + go func() { + w.Assert() + }() + s.Fetch(true) + } + +} + +// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one +// channel while another goroutine wakes up the sleeper. This assumes that a new +// goroutine doesn't run immediately (i.e., the creator of a new goroutine is +// allowed to go to sleep before the new goroutine has a chance to run). +func BenchmarkGoWaitOnSingleSelect(b *testing.B) { + ch := make(chan struct{}, 1) + for i := 0; i < b.N; i++ { + go func() { + ch <- struct{}{} + }() + <-ch + } +} + +// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4 +// wakers while another goroutine wakes up the sleeper. This assumes that a new +// goroutine doesn't run immediately (i.e., the creator of a new goroutine is +// allowed to go to sleep before the new goroutine has a chance to run). +func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) { + const count = 4 + s := Sleeper{} + w := make([]Waker, count) + for i := range w { + s.AddWaker(&w[i], i) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + w[count-1].Assert() + }() + s.Fetch(true) + } +} + +// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels +// while another goroutine wakes up the sleeper. 
This assumes that a new +// goroutine doesn't run immediately (i.e., the creator of a new goroutine is +// allowed to go to sleep before the new goroutine has a chance to run). +func BenchmarkGoWaitOnMultiSelect(b *testing.B) { + const count = 4 + ch := make([]chan struct{}, count) + for i := range ch { + ch[i] = make(chan struct{}, 1) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + go func() { + ch[count-1] <- struct{}{} + }() + select { + case <-ch[0]: + case <-ch[1]: + case <-ch[2]: + case <-ch[3]: + } + } +} diff --git a/netstack/sleep/sleep_unsafe.go b/netstack/sleep/sleep_unsafe.go new file mode 100644 index 0000000..f63cadc --- /dev/null +++ b/netstack/sleep/sleep_unsafe.go @@ -0,0 +1,395 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sleep allows goroutines to efficiently sleep on multiple sources of +// notifications (wakers). It offers O(1) complexity, which is different from +// multi-channel selects which have O(n) complexity (where n is the number of +// channels) and a considerable constant factor. +// +// It is similar to edge-triggered epoll waits, where the user registers each +// object of interest once, and then can repeatedly wait on all of them. +// +// A Waker object is used to wake a sleeping goroutine (G) up, or prevent it +// from going to sleep next. 
A Sleeper object is used to receive notifications +// from wakers, and if no notifications are available, to optionally sleep until +// one becomes available. +// +// A Waker can be associated with at most one Sleeper, but a Sleeper can be +// associated with multiple Wakers. A Sleeper has a list of asserted (ready) +// wakers; when Fetch() is called repeatedly, elements from this list are +// returned until the list becomes empty in which case the goroutine goes to +// sleep. When Assert() is called on a Waker, it adds itself to the Sleeper's +// asserted list and wakes the G up from its sleep if needed. +// +// Sleeper objects are expected to be used as follows, with just one goroutine +// executing this code: +// +// // One time set-up. +// s := sleep.Sleeper{} +// s.AddWaker(&w1, constant1) +// s.AddWaker(&w2, constant2) +// +// // Called repeatedly. +// for { +// switch id, _ := s.Fetch(true); id { +// case constant1: +// // Do work triggered by w1 being asserted. +// case constant2: +// // Do work triggered by w2 being asserted. +// } +// } +// +// And Waker objects are expected to call w.Assert() when they want the sleeper +// to wake up and perform work. +// +// The notifications are edge-triggered, which means that if a Waker calls +// Assert() several times before the sleeper has the chance to wake up, it will +// only be notified once and should perform all pending work (alternatively, it +// can also call Assert() on the waker, to ensure that it will wake up again). +// +// The "unsafeness" here is in the casts to/from unsafe.Pointer, which is safe +// when only one type is used for each unsafe.Pointer (which is the case here), +// we should just make sure that this remains the case in the future. 
// The usage of unsafe package could be confined to sharedWaker and
// sharedSleeper types that would hold pointers in atomic.Pointers, but the go
// compiler currently can't optimize these as well (it won't inline their
// method calls), which reduces performance.
package sleep

import (
	"sync/atomic"
	"unsafe"
)

const (
	// preparingG is stored in sleepers to indicate that they're preparing
	// to sleep. It is written to Sleeper.waitingG by nextWaker and is
	// distinguished from a real G value by enqueueAssertedWaker, which
	// only calls goready for values other than preparingG.
	preparingG = 1
)

var (
	// assertedSleeper is a sentinel sleeper. A pointer to it is stored in
	// wakers that are asserted.
	assertedSleeper Sleeper
)

// gopark has no body in this file; the go:linkname directive below binds it
// to runtime.gopark. nextWaker calls it to put the calling goroutine to sleep.
//
//go:linkname gopark runtime.gopark
func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason string, traceEv byte, traceskip int)

// goready has no body in this file; the go:linkname directive below binds it
// to runtime.goready. enqueueAssertedWaker calls it with the G value read from
// Sleeper.waitingG to wake the sleeping goroutine.
//
//go:linkname goready runtime.goready
func goready(g uintptr, traceskip int)

// Sleeper allows a goroutine to sleep and receive wake up notifications from
// Wakers in an efficient way.
//
// This is similar to edge-triggered epoll in that wakers are added to the
// sleeper once and the sleeper can then repeatedly sleep in O(1) time while
// waiting on all wakers.
//
// None of the methods in a Sleeper can be called concurrently. Wakers that have
// been added to a sleeper A can only be added to another sleeper after A.Done()
// returns. These restrictions allow this to be implemented lock-free.
//
// This struct is thread-compatible.
type Sleeper struct {
	// sharedList is a "stack" of asserted wakers. They atomically add
	// themselves to the front of this list as they become asserted.
	sharedList unsafe.Pointer

	// localList is a list of asserted wakers that is only accessible to the
	// waiter, and thus doesn't have to be accessed atomically. When
	// fetching more wakers, the waiter will first go through this list, and
	// only when it's empty will it atomically fetch wakers from
	// sharedList.
	localList *Waker

	// allWakers is a list with all wakers that have been added to this
	// sleeper. It is used during cleanup to remove associations.
	allWakers *Waker

	// waitingG holds the G that is sleeping, if any. It is used by wakers
	// to determine which G, if any, they should wake. It is zero when no
	// goroutine is waiting and preparingG while the sleeper is about to
	// sleep but has not committed yet.
	waitingG uintptr
}

// AddWaker associates the given waker to the sleeper. id is the value to be
// returned when the sleeper is woken by the given waker.
func (s *Sleeper) AddWaker(w *Waker, id int) {
	// Add the waker to the list of all wakers.
	w.allWakersNext = s.allWakers
	s.allWakers = w
	w.id = id

	// Try to associate the waker with the sleeper. If it's already
	// asserted, we simply enqueue it in the "ready" list. The loop retries
	// when w.s changes between the load and the compare-and-swap.
	for {
		p := (*Sleeper)(atomic.LoadPointer(&w.s))
		if p == &assertedSleeper {
			s.enqueueAssertedWaker(w)
			return
		}

		if atomic.CompareAndSwapPointer(&w.s, usleeper(p), usleeper(s)) {
			return
		}
	}
}

// nextWaker returns the next waker in the notification list, blocking if
// needed. It returns nil only when block is false and no waker is queued in
// either the local or the shared list.
func (s *Sleeper) nextWaker(block bool) *Waker {
	// Attempt to replenish the local list if it's currently empty.
	if s.localList == nil {
		for atomic.LoadPointer(&s.sharedList) == nil {
			// Fail request if caller requested that we
			// don't block.
			if !block {
				return nil
			}

			// Indicate to wakers that we're about to sleep,
			// this allows them to abort the wait by setting
			// waitingG back to zero (which we'll notice
			// before committing the sleep).
			atomic.StoreUintptr(&s.waitingG, preparingG)

			// Check if something was queued while we were
			// preparing to sleep. We need this interleaving
			// to avoid missing wake ups.
			if atomic.LoadPointer(&s.sharedList) != nil {
				atomic.StoreUintptr(&s.waitingG, 0)
				break
			}

			// Try to commit the sleep and report it to the
			// tracer as a select.
			//
			// gopark puts the caller to sleep and calls
			// commitSleep to decide whether to immediately
			// wake the caller up or to leave it sleeping.
			// (commitSleep is not defined in this file.)
			const traceEvGoBlockSelect = 24
			gopark(commitSleep, &s.waitingG, "sleeper", traceEvGoBlockSelect, 0)
		}

		// Pull the shared list out and reverse it in the local
		// list. Given that wakers push themselves in reverse
		// order, we fix things here.
		v := (*Waker)(atomic.SwapPointer(&s.sharedList, nil))
		for v != nil {
			cur := v
			v = v.next

			cur.next = s.localList
			s.localList = cur
		}
	}

	// Remove the waker in the front of the list.
	w := s.localList
	s.localList = w.next

	return w
}

// Fetch fetches the next wake-up notification. If a notification is immediately
// available, it is returned right away. Otherwise, the behavior depends on the
// value of 'block': if true, the current goroutine blocks until a notification
// arrives, then returns it; if false, returns 'ok' as false.
//
// When 'ok' is true, the value of 'id' corresponds to the id associated with
// the waker; when 'ok' is false, 'id' is undefined (the current implementation
// returns -1).
//
// N.B. This method is *not* thread-safe. Only one goroutine at a time is
// allowed to call this method.
func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
	for {
		w := s.nextWaker(block)
		if w == nil {
			return -1, false
		}

		// Reassociate the waker with the sleeper. If the waker was
		// still asserted we can return it, otherwise try the next one
		// (the waker was cleared after it had been queued).
		old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s)))
		if old == &assertedSleeper {
			return w.id, true
		}
	}
}

// Done is used to indicate that the caller won't use this Sleeper anymore. It
// removes the association with all wakers so that they can be safely reused
// by another sleeper after Done() returns.
func (s *Sleeper) Done() {
	// Remove all associations that we can, and build a list of the ones
	// we could not. An association can be removed right away from waker w
	// if w.s has a pointer to the sleeper, that is, the waker is not
	// asserted yet. By atomically switching w.s to nil, we guarantee that
	// subsequent calls to Assert() on the waker will not result in it being
	// queued to this sleeper.
	var pending *Waker
	w := s.allWakers
	for w != nil {
		next := w.allWakersNext
		for {
			t := atomic.LoadPointer(&w.s)
			if t != usleeper(s) {
				// Not removable right now: move the waker to
				// the pending list, reusing allWakersNext as
				// the link.
				w.allWakersNext = pending
				pending = w
				break
			}

			if atomic.CompareAndSwapPointer(&w.s, t, nil) {
				break
			}
		}
		w = next
	}

	// The associations that we could not remove are either asserted, or in
	// the process of being asserted, or have been asserted and cleared
	// before being pulled from the sleeper lists. We must wait for them all
	// to make it to the sleeper lists, so that we know that the wakers
	// won't do any more work towards waking this sleeper up.
	for pending != nil {
		pulled := s.nextWaker(true)

		// Remove the waker we just pulled from the list of associated
		// wakers.
		prev := &pending
		for w := *prev; w != nil; w = *prev {
			if pulled == w {
				*prev = w.allWakersNext
				break
			}
			prev = &w.allWakersNext
		}
	}
	s.allWakers = nil
}

// enqueueAssertedWaker pushes an asserted waker onto the front of sharedList,
// the "ready" list of wakers that want to notify the sleeper, and then clears
// waitingG, waking the waiting goroutine if one had already gone to sleep.
func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
	// Add the new waker to the front of the list.
	for {
		v := (*Waker)(atomic.LoadPointer(&s.sharedList))
		w.next = v
		if atomic.CompareAndSwapPointer(&s.sharedList, uwaker(v), uwaker(w)) {
			break
		}
	}

	for {
		// Nothing to do if there isn't a G waiting.
		g := atomic.LoadUintptr(&s.waitingG)
		if g == 0 {
			return
		}

		// Signal to the sleeper that a waker has been asserted. After a
		// successful swap the next iteration observes zero and returns.
		if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
			if g != preparingG {
				// We managed to get a G. Wake it up.
				goready(g, 0)
			}
		}
	}
}

// Waker represents a source of wake-up notifications to be sent to sleepers.
A +// waker can be associated with at most one sleeper at a time, and at any given +// time is either in asserted or non-asserted state. +// +// Once asserted, the waker remains so until it is manually cleared or a sleeper +// consumes its assertion (i.e., a sleeper wakes up or is prevented from going +// to sleep due to the waker). +// +// This struct is thread-safe, that is, its methods can be called concurrently +// by multiple goroutines. +type Waker struct { + // s is the sleeper that this waker can wake up. Only one sleeper at a + // time is allowed. This field can have three classes of values: + // nil -- the waker is not asserted: it either is not associated with + // a sleeper, or is queued to a sleeper due to being previously + // asserted. This is the zero value. + // &assertedSleeper -- the waker is asserted. + // otherwise -- the waker is not asserted, and is associated with the + // given sleeper. Once it transitions to asserted state, the + // associated sleeper will be woken. + s unsafe.Pointer + + // next is used to form a linked list of asserted wakers in a sleeper. + next *Waker + + // allWakersNext is used to form a linked list of all wakers associated + // to a given sleeper. + allWakersNext *Waker + + // id is the value to be returned to sleepers when they wake up due to + // this waker being asserted. + id int +} + +// Assert moves the waker to an asserted state, if it isn't asserted yet. When +// asserted, the waker will cause its matching sleeper to wake up. +func (w *Waker) Assert() { + // Nothing to do if the waker is already asserted. This check allows us + // to complete this case (already asserted) without any interlocked + // operations on x86. + if atomic.LoadPointer(&w.s) == usleeper(&assertedSleeper) { + return + } + + // Mark the waker as asserted, and wake up a sleeper if there is one. 
+ switch s := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(&assertedSleeper))); s { + case nil: + case &assertedSleeper: + default: + s.enqueueAssertedWaker(w) + } +} + +// Clear moves the waker to then non-asserted state and returns whether it was +// asserted before being cleared. +// +// N.B. The waker isn't removed from the "ready" list of a sleeper (if it +// happens to be in one), but the sleeper will notice that it is not asserted +// anymore and won't return it to the caller. +func (w *Waker) Clear() bool { + // Nothing to do if the waker is not asserted. This check allows us to + // complete this case (already not asserted) without any interlocked + // operations on x86. + if atomic.LoadPointer(&w.s) != usleeper(&assertedSleeper) { + return false + } + + // Try to store nil in the sleeper, which indicates that the waker is + // not asserted. + return atomic.CompareAndSwapPointer(&w.s, usleeper(&assertedSleeper), nil) +} + +// IsAsserted returns whether the waker is currently asserted (i.e., if it's +// currently in a state that would cause its matching sleeper to wake up). +func (w *Waker) IsAsserted() bool { + return (*Sleeper)(atomic.LoadPointer(&w.s)) == &assertedSleeper +} + +func usleeper(s *Sleeper) unsafe.Pointer { + return unsafe.Pointer(s) +} + +func uwaker(w *Waker) unsafe.Pointer { + return unsafe.Pointer(w) +} diff --git a/netstack/tcpip/buffer/prependable.go b/netstack/tcpip/buffer/prependable.go new file mode 100644 index 0000000..2b29bd8 --- /dev/null +++ b/netstack/tcpip/buffer/prependable.go @@ -0,0 +1,64 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package buffer + +// Prependable is a buffer that grows backwards, that is, more data can be +// prepended to it. It is useful when building networking packets, where each +// protocol adds its own headers to the front of the higher-level protocol +// header and payload; for example, TCP would prepend its header to the payload, +// then IP would prepend its own, then ethernet. +type Prependable struct { + // Buf is the buffer backing the prependable buffer. + buf View + + // usedIdx is the index where the used part of the buffer begins. + usedIdx int +} + +// NewPrependable allocates a new prependable buffer with the given size. +func NewPrependable(size int) Prependable { + return Prependable{buf: NewView(size), usedIdx: size} +} + +// NewPrependableFromView creates an entirely-used Prependable from a View. +// +// NewPrependableFromView takes ownership of v. Note that since the entire +// prependable is used, further attempts to call Prepend will note that size > +// p.usedIdx and return nil. +func NewPrependableFromView(v View) Prependable { + return Prependable{buf: v, usedIdx: 0} +} + +// View returns a View of the backing buffer that contains all prepended +// data so far. +func (p Prependable) View() View { + return p.buf[p.usedIdx:] +} + +// UsedLength returns the number of bytes used so far. +func (p Prependable) UsedLength() int { + return len(p.buf) - p.usedIdx +} + +// Prepend reserves the requested space in front of the buffer, returning a +// slice that represents the reserved space. 
+func (p *Prependable) Prepend(size int) []byte { + if size > p.usedIdx { + return nil + } + + p.usedIdx -= size + return p.View()[:size:size] +} diff --git a/netstack/tcpip/buffer/view.go b/netstack/tcpip/buffer/view.go new file mode 100644 index 0000000..697c5cd --- /dev/null +++ b/netstack/tcpip/buffer/view.go @@ -0,0 +1,146 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package buffer provides the implementation of a buffer view. +package buffer + +// View is a slice of a buffer, with convenience methods. +type View []byte + +// NewView allocates a new buffer and returns an initialized view that covers +// the whole buffer. +func NewView(size int) View { + return make(View, size) +} + +// NewViewFromBytes allocates a new buffer and copies in the given bytes. +func NewViewFromBytes(b []byte) View { + return append(View(nil), b...) +} + +// TrimFront removes the first "count" bytes from the visible section of the +// buffer. +func (v *View) TrimFront(count int) { + *v = (*v)[count:] +} + +// CapLength irreversibly reduces the length of the visible section of the +// buffer to the value specified. +func (v *View) CapLength(length int) { + // We also set the slice cap because if we don't, one would be able to + // expand the view back to include the region just excluded. We want to + // prevent that to avoid potential data leak if we have uninitialized + // data in excluded region. 
+ *v = (*v)[:length:length] +} + +// ToVectorisedView returns a VectorisedView containing the receiver. +func (v View) ToVectorisedView() VectorisedView { + return NewVectorisedView(len(v), []View{v}) +} + +// VectorisedView is a vectorised version of View using non contigous memory. +// It supports all the convenience methods supported by View. +// +// +stateify savable +type VectorisedView struct { + views []View + size int +} + +// NewVectorisedView creates a new vectorised view from an already-allocated slice +// of View and sets its size. +func NewVectorisedView(size int, views []View) VectorisedView { + return VectorisedView{views: views, size: size} +} + +// TrimFront removes the first "count" bytes of the vectorised view. +func (vv *VectorisedView) TrimFront(count int) { + for count > 0 && len(vv.views) > 0 { + if count < len(vv.views[0]) { + vv.size -= count + vv.views[0].TrimFront(count) + return + } + count -= len(vv.views[0]) + vv.RemoveFirst() + } +} + +// CapLength irreversibly reduces the length of the vectorised view. +func (vv *VectorisedView) CapLength(length int) { + if length < 0 { + length = 0 + } + if vv.size < length { + return + } + vv.size = length + for i := range vv.views { + v := &vv.views[i] + if len(*v) >= length { + if length == 0 { + vv.views = vv.views[:i] + } else { + v.CapLength(length) + vv.views = vv.views[:i+1] + } + return + } + length -= len(*v) + } +} + +// Clone returns a clone of this VectorisedView. +// If the buffer argument is large enough to contain all the Views of this VectorisedView, +// the method will avoid allocations and use the buffer to store the Views of the clone. +func (vv VectorisedView) Clone(buffer []View) VectorisedView { + return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size} +} + +// First returns the first view of the vectorised view. 
+func (vv VectorisedView) First() View { + if len(vv.views) == 0 { + return nil + } + return vv.views[0] +} + +// RemoveFirst removes the first view of the vectorised view. +func (vv *VectorisedView) RemoveFirst() { + if len(vv.views) == 0 { + return + } + vv.size -= len(vv.views[0]) + vv.views = vv.views[1:] +} + +// Size returns the size in bytes of the entire content stored in the vectorised view. +func (vv VectorisedView) Size() int { + return vv.size +} + +// ToView returns a single view containing the content of the vectorised view. +func (vv VectorisedView) ToView() View { + u := make([]byte, 0, vv.size) + for _, v := range vv.views { + u = append(u, v...) + } + return u +} + +// Views returns the slice containing the all views. +func (vv VectorisedView) Views() []View { + return vv.views +} diff --git a/tcpip/buffer/view_test.go b/netstack/tcpip/buffer/view_test.go similarity index 100% rename from tcpip/buffer/view_test.go rename to netstack/tcpip/buffer/view_test.go diff --git a/netstack/tcpip/checker/checker.go b/netstack/tcpip/checker/checker.go new file mode 100644 index 0000000..88df0bb --- /dev/null +++ b/netstack/tcpip/checker/checker.go @@ -0,0 +1,588 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checker provides helper functions to check networking packets for +// validity. 
+package checker + +import ( + "encoding/binary" + "reflect" + "testing" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" +) + +// NetworkChecker is a function to check a property of a network packet. +type NetworkChecker func(*testing.T, []header.Network) + +// TransportChecker is a function to check a property of a transport packet. +type TransportChecker func(*testing.T, header.Transport) + +// IPv4 checks the validity and properties of the given IPv4 packet. It is +// expected to be used in conjunction with other network checkers for specific +// properties. For example, to check the source and destination address, one +// would call: +// +// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y)) +func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv4 := header.IPv4(b) + + if !ipv4.IsValid(len(b)) { + t.Error("Not a valid IPv4 packet") + } + + xsum := ipv4.CalculateChecksum() + if xsum != 0 && xsum != 0xffff { + t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum()) + } + + for _, f := range checkers { + f(t, []header.Network{ipv4}) + } + if t.Failed() { + t.FailNow() + } +} + +// IPv6 checks the validity and properties of the given IPv6 packet. The usage +// is similar to IPv4. +func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) { + t.Helper() + + ipv6 := header.IPv6(b) + if !ipv6.IsValid(len(b)) { + t.Error("Not a valid IPv6 packet") + } + + for _, f := range checkers { + f(t, []header.Network{ipv6}) + } + if t.Failed() { + t.FailNow() + } +} + +// SrcAddr creates a checker that checks the source address. +func SrcAddr(addr tcpip.Address) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + if a := h[0].SourceAddress(); a != addr { + t.Errorf("Bad source address, got %v, want %v", a, addr) + } + } +} + +// DstAddr creates a checker that checks the destination address. 
+func DstAddr(addr tcpip.Address) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + if a := h[0].DestinationAddress(); a != addr { + t.Errorf("Bad destination address, got %v, want %v", a, addr) + } + } +} + +// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6). +func TTL(ttl uint8) NetworkChecker { + return func(t *testing.T, h []header.Network) { + var v uint8 + switch ip := h[0].(type) { + case header.IPv4: + v = ip.TTL() + case header.IPv6: + v = ip.HopLimit() + } + if v != ttl { + t.Fatalf("Bad TTL, got %v, want %v", v, ttl) + } + } +} + +// PayloadLen creates a checker that checks the payload length. +func PayloadLen(plen int) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + if l := len(h[0].Payload()); l != plen { + t.Errorf("Bad payload length, got %v, want %v", l, plen) + } + } +} + +// FragmentOffset creates a checker that checks the FragmentOffset field. +func FragmentOffset(offset uint16) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // We only do this of IPv4 for now. + switch ip := h[0].(type) { + case header.IPv4: + if v := ip.FragmentOffset(); v != offset { + t.Errorf("Bad fragment offset, got %v, want %v", v, offset) + } + } + } +} + +// FragmentFlags creates a checker that checks the fragment flags field. +func FragmentFlags(flags uint8) NetworkChecker { + return func(t *testing.T, h []header.Network) { + t.Helper() + + // We only do this of IPv4 for now. + switch ip := h[0].(type) { + case header.IPv4: + if v := ip.Flags(); v != flags { + t.Errorf("Bad fragment offset, got %v, want %v", v, flags) + } + } + } +} + +// TOS creates a checker that checks the TOS field. 
+func TOS(tos uint8, label uint32) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		// Both values returned by the header's TOS method must match.
+		if v, l := h[0].TOS(); v != tos || l != label {
+			t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+		}
+	}
+}
+
+// Raw creates a checker that checks the bytes of payload.
+// The checker always checks the payload of the last network header.
+// For instance, in case of IPv6 fragments, the payload that will be checked
+// is the one containing the actual data that the packet is carrying, without
+// the bytes added by the IPv6 fragmentation.
+func Raw(want []byte) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
+			t.Errorf("Wrong payload, got %v, want %v", got, want)
+		}
+	}
+}
+
+// IPv6Fragment creates a checker that validates an IPv6 fragment.
+//
+// The nested checkers receive the outer header followed by the fragment
+// header built from its payload. The test is stopped if anything failed.
+func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		// Report the value that is actually being compared against.
+		if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
+			t.Errorf("Bad protocol, got %v, want %v", p, header.IPv6FragmentHeader)
+		}
+
+		ipv6Frag := header.IPv6Fragment(h[0].Payload())
+		if !ipv6Frag.IsValid() {
+			t.Error("Not a valid IPv6 fragment")
+		}
+
+		for _, f := range checkers {
+			f(t, []header.Network{h[0], ipv6Frag})
+		}
+		if t.Failed() {
+			t.FailNow()
+		}
+	}
+}
+
+// TCP creates a checker that checks that the transport protocol is TCP and
+// potentially additional transport header fields.
+//
+// Addresses for the pseudo-header come from the first network header; the
+// protocol number and the segment come from the last one.
+func TCP(checkers ...TransportChecker) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		first := h[0]
+		last := h[len(h)-1]
+
+		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
+			t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
+		}
+
+		// Verify the checksum.
+		tcp := header.TCP(last.Payload())
+		l := uint16(len(tcp))
+
+		// Sum source, destination, protocol and length, then the segment.
+		xsum := header.Checksum([]byte(first.SourceAddress()), 0)
+		xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
+		xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
+		xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
+		xsum = header.Checksum(tcp, xsum)
+
+		// Either 0 or 0xffff is accepted as a good result.
+		if xsum != 0 && xsum != 0xffff {
+			t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
+		}
+
+		// Run the transport checkers.
+		for _, f := range checkers {
+			f(t, tcp)
+		}
+		if t.Failed() {
+			t.FailNow()
+		}
+	}
+}
+
+// UDP creates a checker that checks that the transport protocol is UDP and
+// potentially additional transport header fields.
+func UDP(checkers ...TransportChecker) NetworkChecker {
+	return func(t *testing.T, h []header.Network) {
+		t.Helper()
+
+		last := h[len(h)-1]
+
+		if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
+			t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+		}
+
+		udp := header.UDP(last.Payload())
+		for _, f := range checkers {
+			f(t, udp)
+		}
+		if t.Failed() {
+			t.FailNow()
+		}
+	}
+}
+
+// SrcPort creates a checker that checks the source port.
+func SrcPort(port uint16) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		if p := h.SourcePort(); p != port {
+			t.Errorf("Bad source port, got %v, want %v", p, port)
+		}
+	}
+}
+
+// DstPort creates a checker that checks the destination port.
+func DstPort(port uint16) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		// Report failures at the caller's line, like SrcPort does.
+		t.Helper()
+
+		if p := h.DestinationPort(); p != port {
+			t.Errorf("Bad destination port, got %v, want %v", p, port)
+		}
+	}
+}
+
+// SeqNum creates a checker that checks the sequence number.
+func SeqNum(seq uint32) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		// Non-TCP transport headers are accepted without any check.
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+
+		if s := tcp.SequenceNumber(); s != seq {
+			t.Errorf("Bad sequence number, got %v, want %v", s, seq)
+		}
+	}
+}
+
+// AckNum creates a checker that checks the ack number.
+func AckNum(seq uint32) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+
+		if s := tcp.AckNumber(); s != seq {
+			t.Errorf("Bad ack number, got %v, want %v", s, seq)
+		}
+	}
+}
+
+// Window creates a checker that checks the tcp window.
+func Window(window uint16) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		// Report failures at the caller's line, like the sibling checkers.
+		t.Helper()
+
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+
+		if w := tcp.WindowSize(); w != window {
+			t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
+		}
+	}
+}
+
+// TCPFlags creates a checker that checks the tcp flags.
+func TCPFlags(flags uint8) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+
+		if f := tcp.Flags(); f != flags {
+			t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
+		}
+	}
+}
+
+// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
+// given mask, match the supplied flags.
+func TCPFlagsMatch(flags, mask uint8) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		// Report failures at the caller's line, like the sibling checkers.
+		t.Helper()
+
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+
+		// Only the bits selected by mask take part in the comparison.
+		if f := tcp.Flags(); (f & mask) != (flags & mask) {
+			t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+		}
+	}
+}
+
+// TCPSynOptions creates a checker that checks the presence of TCP options in
+// SYN segments.
+//
+// If wantOpts.WS is negative, the window scale option must not be present.
+func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		// Non-TCP transport headers are accepted without any check.
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+		opts := tcp.Options()
+		limit := len(opts)
+		foundMSS := false
+		foundWS := false
+		foundTS := false
+		foundSACKPermitted := false
+		tsVal := uint32(0)
+		tsEcr := uint32(0)
+		// Truncated options stop the test (Fatalf) before any read past the
+		// end of opts can panic.
+		for i := 0; i < limit; {
+			switch opts[i] {
+			case header.TCPOptionEOL:
+				i = limit
+			case header.TCPOptionNOP:
+				i++
+			case header.TCPOptionMSS:
+				if i+4 > limit {
+					t.Fatalf("MSS option truncated, option is only %d bytes, want 4", limit-i)
+				}
+				v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+				if wantOpts.MSS != v {
+					t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+				}
+				foundMSS = true
+				i += 4
+			case header.TCPOptionWS:
+				if i+3 > limit {
+					t.Fatalf("WS option truncated, option is only %d bytes, want 3", limit-i)
+				}
+				if wantOpts.WS < 0 {
+					t.Error("WS present when it shouldn't be")
+				}
+				v := int(opts[i+2])
+				if v != wantOpts.WS {
+					t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
+				}
+				foundWS = true
+				i += 3
+			case header.TCPOptionTS:
+				if i+9 >= limit {
+					t.Fatalf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
+				}
+				if opts[i+1] != 10 {
+					t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
+				}
+				tsVal = binary.BigEndian.Uint32(opts[i+2:])
+				tsEcr = uint32(0)
+				if tcp.Flags()&header.TCPFlagAck != 0 {
+					// If the syn is an SYN-ACK then read
+					// the tsEcr value as well.
+					tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+				}
+				foundTS = true
+				i += 10
+			case header.TCPOptionSACKPermitted:
+				if i+1 >= limit {
+					t.Fatalf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
+				}
+				if opts[i+1] != 2 {
+					t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
+				}
+				foundSACKPermitted = true
+				i += 2
+
+			default:
+				// Unknown option: skip it using its length byte. A missing
+				// length byte would index out of range and a length below 2
+				// would not advance past the option, so both are rejected.
+				if i+2 > limit {
+					t.Fatalf("Option %d truncated, no length byte. Options: %x", opts[i], opts)
+				}
+				l := int(opts[i+1])
+				if l < 2 || i+l > limit {
+					t.Fatalf("Bad length %d for option %d. Options: %x", l, opts[i], opts)
+				}
+				i += l
+			}
+		}
+
+		if !foundMSS {
+			t.Errorf("MSS option not found. Options: %x", opts)
+		}
+
+		if !foundWS && wantOpts.WS >= 0 {
+			t.Errorf("WS option not found. Options: %x", opts)
+		}
+		if wantOpts.TS && !foundTS {
+			t.Errorf("TS option not found. Options: %x", opts)
+		}
+		if foundTS && tsVal == 0 {
+			t.Error("TS option specified but the timestamp value is zero")
+		}
+		if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
+			t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+		}
+		if wantOpts.SACKPermitted && !foundSACKPermitted {
+			t.Errorf("SACKPermitted option not found. Options: %x", opts)
+		}
+	}
+}
+
+// TCPTimestampChecker creates a checker that validates that a TCP segment has a
+// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
+// wantTSEcr values with those in the TCP segment (if present).
+//
+// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
+// skipped.
+func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+		opts := []byte(tcp.Options())
+		limit := len(opts)
+		foundTS := false
+		tsVal := uint32(0)
+		tsEcr := uint32(0)
+		for i := 0; i < limit; {
+			switch opts[i] {
+			case header.TCPOptionEOL:
+				i = limit
+			case header.TCPOptionNOP:
+				i++
+			case header.TCPOptionTS:
+				// Stop here: reading a truncated option would panic.
+				if i+9 >= limit {
+					t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
+				}
+				if opts[i+1] != 10 {
+					t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+				}
+				tsVal = binary.BigEndian.Uint32(opts[i+2:])
+				tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+				foundTS = true
+				i += 10
+			default:
+				// We don't recognize this option, just skip over it.
+				if i+2 > limit {
+					return
+				}
+				l := int(opts[i+1])
+				// The option length, not the offset, must be at least 2;
+				// a smaller length would not advance past the option.
+				if l < 2 || i+l > limit {
+					return
+				}
+				i += l
+			}
+		}
+
+		if wantTS != foundTS {
+			t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+		}
+		if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
+			t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+		}
+		if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
+			t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+		}
+	}
+}
+
+// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
+// contain any SACK blocks in the TCP options.
+func TCPNoSACKBlockChecker() TransportChecker {
+	return TCPSACKBlockChecker(nil)
+}
+
+// TCPSACKBlockChecker creates a checker that verifies that the segment does
+// contain the specified SACK blocks in the TCP options.
+func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		t.Helper()
+		tcp, ok := h.(header.TCP)
+		if !ok {
+			return
+		}
+		var gotSACKBlocks []header.SACKBlock
+
+		opts := []byte(tcp.Options())
+		limit := len(opts)
+	parse:
+		for i := 0; i < limit; {
+			switch opts[i] {
+			case header.TCPOptionEOL:
+				i = limit
+			case header.TCPOptionNOP:
+				i++
+			case header.TCPOptionSACK:
+				if i+2 > limit {
+					// Malformed SACK block: there is no length byte to read.
+					t.Fatalf("malformed SACK option in options: %v", opts)
+				}
+				sackOptionLen := int(opts[i+1])
+				if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+					// Malformed SACK block: going on would read past the
+					// options or never advance i.
+					t.Fatalf("malformed SACK option length in options: %v", opts)
+				}
+				numBlocks := sackOptionLen / 8
+				for j := 0; j < numBlocks; j++ {
+					start := binary.BigEndian.Uint32(opts[i+2+j*8:])
+					end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
+					gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
+						Start: seqnum.Value(start),
+						End:   seqnum.Value(end),
+					})
+				}
+				i += sackOptionLen
+			default:
+				// We don't recognize this option, just skip over it. A bare
+				// break would only leave the switch and the loop would spin
+				// forever on a malformed option, so leave the loop by label.
+				if i+2 > limit {
+					break parse
+				}
+				l := int(opts[i+1])
+				if l < 2 || i+l > limit {
+					break parse
+				}
+				i += l
+			}
+		}
+
+		if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
+			t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+		}
+	}
+}
+
+// Payload creates a checker that checks the payload.
+func Payload(want []byte) TransportChecker {
+	return func(t *testing.T, h header.Transport) {
+		// Report failures at the caller's line, like the sibling checkers.
+		t.Helper()
+
+		if got := h.Payload(); !reflect.DeepEqual(got, want) {
+			t.Errorf("Wrong payload, got %v, want %v", got, want)
+		}
+	}
+}
diff --git a/netstack/tcpip/header/arp.go b/netstack/tcpip/header/arp.go
new file mode 100644
index 0000000..5ca3dbe
--- /dev/null
+++ b/netstack/tcpip/header/arp.go
@@ -0,0 +1,107 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import "tcpip/netstack/tcpip"
+
+const (
+	// ARPProtocolNumber is the ARP protocol number, 0x0806.
+	ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
+
+	// ARPSize is the length of an ARP packet on an IPv4 network.
+	ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
+)
+
+// ARPOp is an ARP opcode.
+type ARPOp uint16
+
+// Opcodes defined in RFC 826.
+const (
+	// ARPRequest is the ARP request opcode.
+	ARPRequest ARPOp = 1
+	// ARPReply is the ARP reply opcode.
+	ARPReply ARPOp = 2
+)
+
+// ARP is an ARP packet stored in a byte slice.
+type ARP []byte
+
+// hardwareAddressSpace returns the hardware type read from bytes 0-1 of the packet.
+func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
+
+// protocolAddressSpace returns the protocol type read from bytes 2-3 of the packet.
+func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
+
+// hardwareAddressSize returns the hardware address length stored in byte 4.
+func (a ARP) hardwareAddressSize() int { return int(a[4]) }
+
+// protocolAddressSize returns the protocol address length stored in byte 5.
+func (a ARP) protocolAddressSize() int { return int(a[5]) }
+
+// Op returns the ARP opcode read from bytes 6-7 of the packet.
+func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
+
+// SetOp sets the ARP opcode.
+func (a ARP) SetOp(op ARPOp) {
+	a[6] = uint8(op >> 8)
+	a[7] = uint8(op)
+}
+
+// SetIPv4OverEthernet sets the hardware and protocol fields of the ARP packet for IPv4 over Ethernet.
+func (a ARP) SetIPv4OverEthernet() {
+	a[0], a[1] = 0, 1       // htypeEthernet
+	a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
+	a[4] = 6                // macSize
+	a[5] = uint8(IPv4AddressSize)
+}
+
+// HardwareAddressSender returns the sender hardware address from the packet.
+func (a ARP) HardwareAddressSender() []byte {
+	const s = 8
+	return a[s : s+6]
+}
+
+// ProtocolAddressSender returns the sender protocol address from the packet, an IPv4 address.
+func (a ARP) ProtocolAddressSender() []byte {
+	const s = 8 + 6
+	return a[s : s+4]
+}
+
+// HardwareAddressTarget returns the target hardware address from the packet.
+func (a ARP) HardwareAddressTarget() []byte {
+	const s = 8 + 6 + 4
+	return a[s : s+6]
+}
+
+// ProtocolAddressTarget returns the target protocol address from the packet, an IPv4 address.
+func (a ARP) ProtocolAddressTarget() []byte {
+	const s = 8 + 6 + 4 + 6
+	return a[s : s+4]
+}
+
+// IsValid reports whether the ARP packet is valid.
+func (a ARP) IsValid() bool {
+	// Anything shorter than ARPSize is invalid.
+	if len(a) < ARPSize {
+		return false
+	}
+	const htypeEthernet = 1
+	const macSize = 6
+	// Require Ethernet, IPv4, and the matching hardware and protocol address sizes.
+	return a.hardwareAddressSpace() == htypeEthernet &&
+		a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
+		a.hardwareAddressSize() == macSize &&
+		a.protocolAddressSize() == IPv4AddressSize
+}
diff --git a/netstack/tcpip/header/checksum.go b/netstack/tcpip/header/checksum.go
new file mode 100644
index 0000000..f17676b
--- /dev/null
+++
b/netstack/tcpip/header/checksum.go
@@ -0,0 +1,57 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header
+
+import (
+	"tcpip/netstack/tcpip"
+)
+
+// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
+// given byte array.
+// The bytes are added as big-endian 16-bit words on top of initial.
+func Checksum(buf []byte, initial uint16) uint16 {
+	v := uint32(initial)
+
+	l := len(buf)
+	// An odd trailing byte is added as the high byte of a final word.
+	if l&1 != 0 {
+		l--
+		v += uint32(buf[l]) << 8
+	}
+
+	for i := 0; i < l; i += 2 {
+		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
+	}
+
+	// Fold the upper 16 bits of the sum back into the lower 16 bits.
+	return ChecksumCombine(uint16(v), uint16(v>>16))
+}
+
+// ChecksumCombine combines the two uint16 to form their checksum. This is done
+// by adding them and the carry.
+func ChecksumCombine(a, b uint16) uint16 {
+	v := uint32(a) + uint32(b)
+	return uint16(v + v>>16)
+}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the
+// given destination protocol and network address, ignoring the length
+// field. Pseudo-headers are needed by transport layers when calculating
+// their own checksum.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 {
+	// Chain the running sum through source address, destination address
+	// and a two-byte word made of a zero byte and the protocol number.
+	xsum := Checksum([]byte(srcAddr), 0)
+	xsum = Checksum([]byte(dstAddr), xsum)
+	return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
diff --git a/netstack/tcpip/header/eth.go b/netstack/tcpip/header/eth.go
new file mode 100644
index 0000000..76e6026
--- /dev/null
+++ b/netstack/tcpip/header/eth.go
@@ -0,0 +1,72 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+)
+
+// Byte offsets of the fields inside an Ethernet header.
+const (
+	dstMAC  = 0
+	srcMAC  = 6
+	ethType = 12
+)
+
+// EthernetFields represents the header of a link-layer Ethernet frame.
+type EthernetFields struct {
+	// SrcAddr is the source address.
+	SrcAddr tcpip.LinkAddress
+
+	// DstAddr is the destination address.
+	DstAddr tcpip.LinkAddress
+
+	// Type is the protocol type.
+	Type tcpip.NetworkProtocolNumber
+}
+
+// Ethernet is an Ethernet frame stored in a byte slice.
+type Ethernet []byte
+
+const (
+	// EthernetMinimumSize is the minimum length of an Ethernet frame.
+	EthernetMinimumSize = 14
+
+	// EthernetAddressSize is the length of an Ethernet address.
+	EthernetAddressSize = 6
+)
+
+// SourceAddress returns the source address from the frame header.
+func (b Ethernet) SourceAddress() tcpip.LinkAddress {
+	return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
+}
+
+// DestinationAddress returns the destination address from the frame header.
+func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
+	return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
+}
+
+// Type returns the protocol type from the frame header.
+func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
+	return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
+}
+
+// Encode writes the given header fields into b in binary form. Note that b must already be allocated.
+func (b Ethernet) Encode(e *EthernetFields) {
+	binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
+	copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
+	copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
+}
diff --git a/netstack/tcpip/header/gue.go b/netstack/tcpip/header/gue.go
new file mode 100644
index 0000000..af9dfab
--- /dev/null
+++ b/netstack/tcpip/header/gue.go
@@ -0,0 +1,73 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+// Byte offsets of the fields inside a GUE header.
+const (
+	typeHLen   = 0
+	encapProto = 1
+)
+
+// GUEFields contains the fields of a GUE packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type GUEFields struct {
+	// Type is the "type" field of the GUE header.
+	Type uint8
+
+	// Control is the "control" field of the GUE header.
+	Control bool
+
+	// HeaderLength is the "header length" field of the GUE header. It must
+	// be at least 4 octets, and a multiple of 4 as well.
+	HeaderLength uint8
+
+	// Protocol is the "protocol" field of the GUE header. This is one of
+	// the IPPROTO_* values.
+	Protocol uint8
+}
+
+// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
+// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
+type GUE []byte
+
+const (
+	// GUEMinimumSize is the minimum size of a valid GUE packet.
+	GUEMinimumSize = 4
+)
+
+// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
+// which includes the control bit).
+func (b GUE) TypeAndControl() uint8 {
+	return b[typeHLen] >> 5
+}
+
+// HeaderLength returns the total length of the GUE header.
+func (b GUE) HeaderLength() uint8 {
+	// The low 5 bits of the first byte count 4-byte units beyond the first 4 bytes.
+	return 4 + 4*(b[typeHLen]&0x1f)
+}
+
+// Protocol returns the protocol field of the GUE header.
+func (b GUE) Protocol() uint8 {
+	return b[encapProto]
+}
+
+// Encode encodes all the fields of the GUE header.
+func (b GUE) Encode(i *GUEFields) {
+	ctl := uint8(0)
+	if i.Control {
+		ctl = 1 << 5
+	}
+	// First byte: type in the top two bits, control in bit 5, and the
+	// header length beyond 4 bytes, in 4-byte units, in the low bits.
+	b[typeHLen] = ctl | i.Type<<6 | (i.HeaderLength-4)/4
+	b[encapProto] = i.Protocol
+}
diff --git a/netstack/tcpip/header/icmpv4.go b/netstack/tcpip/header/icmpv4.go
new file mode 100644
index 0000000..cb1df11
--- /dev/null
+++ b/netstack/tcpip/header/icmpv4.go
@@ -0,0 +1,108 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+)
+
+// ICMPv4 represents an ICMPv4 header stored in a byte array.
+type ICMPv4 []byte
+
+const (
+	// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
+	ICMPv4MinimumSize = 4
+
+	// ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+	ICMPv4EchoMinimumSize = 6
+
+	// ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
+	// destination unreachable packet.
+	ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
+
+	// ICMPv4ProtocolNumber is the ICMP transport protocol number.
+	ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
+)
+
+// ICMPv4Type is the ICMP type field described in RFC 792.
+type ICMPv4Type byte
+
+// Typical values of ICMPv4Type defined in RFC 792.
+const (
+	ICMPv4EchoReply      ICMPv4Type = 0
+	ICMPv4DstUnreachable ICMPv4Type = 3
+	ICMPv4SrcQuench      ICMPv4Type = 4
+	ICMPv4Redirect       ICMPv4Type = 5
+	ICMPv4Echo           ICMPv4Type = 8
+	ICMPv4TimeExceeded   ICMPv4Type = 11
+	ICMPv4ParamProblem   ICMPv4Type = 12
+	ICMPv4Timestamp      ICMPv4Type = 13
+	ICMPv4TimestampReply ICMPv4Type = 14
+	ICMPv4InfoRequest    ICMPv4Type = 15
+	ICMPv4InfoReply      ICMPv4Type = 16
+)
+
+// Values for ICMP code as defined in RFC 792.
+const (
+	ICMPv4PortUnreachable     = 3
+	ICMPv4FragmentationNeeded = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv4) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv4) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field, read big-endian from bytes 2-3.
+func (b ICMPv4) Checksum() uint16 {
+	return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv4) SetChecksum(checksum uint16) {
+	binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort. It always returns 0.
+func (ICMPv4) SourcePort() uint16 {
+	return 0
+}
+
+// DestinationPort implements Transport.DestinationPort. It always returns 0.
+func (ICMPv4) DestinationPort() uint16 {
+	return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort. It does nothing.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort. It does nothing.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+	// Everything after the first ICMPv4MinimumSize bytes is returned.
+	return b[ICMPv4MinimumSize:]
+}
diff --git a/netstack/tcpip/header/icmpv6.go b/netstack/tcpip/header/icmpv6.go
new file mode 100644
index 0000000..00be954
--- /dev/null
+++ b/netstack/tcpip/header/icmpv6.go
@@ -0,0 +1,121 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+)
+
+// ICMPv6 represents an ICMPv6 header stored in a byte array.
+type ICMPv6 []byte
+
+const (
+	// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
+	ICMPv6MinimumSize = 4
+
+	// ICMPv6ProtocolNumber is the ICMP transport protocol number.
+	ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
+
+	// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
+	// neighbor solicitation packet.
+	ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
+
+	// ICMPv6NeighborAdvertSize is the size of a neighbor advertisement.
+	ICMPv6NeighborAdvertSize = 32
+
+	// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
+	ICMPv6EchoMinimumSize = 8
+
+	// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
+	// destination unreachable packet.
+	ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
+
+	// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
+	// packet-too-big packet.
+	ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
+)
+
+// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
+type ICMPv6Type byte
+
+// Typical values of ICMPv6Type defined in RFC 4443.
+const (
+	ICMPv6DstUnreachable ICMPv6Type = 1
+	ICMPv6PacketTooBig   ICMPv6Type = 2
+	ICMPv6TimeExceeded   ICMPv6Type = 3
+	ICMPv6ParamProblem   ICMPv6Type = 4
+	ICMPv6EchoRequest    ICMPv6Type = 128
+	ICMPv6EchoReply      ICMPv6Type = 129
+
+	// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
+
+	ICMPv6RouterSolicit   ICMPv6Type = 133
+	ICMPv6RouterAdvert    ICMPv6Type = 134
+	ICMPv6NeighborSolicit ICMPv6Type = 135
+	ICMPv6NeighborAdvert  ICMPv6Type = 136
+	ICMPv6RedirectMsg     ICMPv6Type = 137
+)
+
+// Values for ICMP code as defined in RFC 4443.
+const (
+	ICMPv6PortUnreachable = 4
+)
+
+// Type is the ICMP type field.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field, read big-endian from bytes 2-3.
+func (b ICMPv6) Checksum() uint16 {
+	return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum sets the ICMP checksum field to the given value; it does not
+// compute the checksum itself.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+	binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort. It always returns 0.
+func (ICMPv6) SourcePort() uint16 {
+	return 0
+}
+
+// DestinationPort implements Transport.DestinationPort. It always returns 0.
+func (ICMPv6) DestinationPort() uint16 {
+	return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort. It does nothing.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort. It does nothing.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+	// Everything after the first ICMPv6MinimumSize bytes is returned.
+	return b[ICMPv6MinimumSize:]
+}
diff --git a/netstack/tcpip/header/interfaces.go b/netstack/tcpip/header/interfaces.go
new file mode 100644
index 0000000..739a395
--- /dev/null
+++ b/netstack/tcpip/header/interfaces.go
@@ -0,0 +1,92 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"tcpip/netstack/tcpip"
+)
+
+const (
+	// MaxIPPacketSize is the maximum supported IP packet size, excluding
+	// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
+	// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
+	// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
+	// m is the minimum IPv6 header size; we leave room for some potential
+	// IP options.
+	MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
+)
+
+// Transport offers generic methods to query and/or update the fields of the
+// header of a transport protocol buffer.
+type Transport interface {
+	// SourcePort returns the value of the "source port" field.
+	SourcePort() uint16
+
+	// DestinationPort returns the value of the "destination port" field.
+	DestinationPort() uint16
+
+	// Checksum returns the value of the "checksum" field.
+	Checksum() uint16
+
+	// SetSourcePort sets the value of the "source port" field.
+	SetSourcePort(uint16)
+
+	// SetDestinationPort sets the value of the "destination port" field.
+	SetDestinationPort(uint16)
+
+	// SetChecksum sets the value of the "checksum" field.
+	SetChecksum(uint16)
+
+	// Payload returns the data carried in the transport buffer.
+	Payload() []byte
+}
+
+// Network offers generic methods to query and/or update the fields of the
+// header of a network protocol buffer.
+type Network interface {
+	// SourceAddress returns the value of the "source address" field.
+	SourceAddress() tcpip.Address
+
+	// DestinationAddress returns the value of the "destination address"
+	// field.
+	DestinationAddress() tcpip.Address
+
+	// Checksum returns the value of the "checksum" field.
+	Checksum() uint16
+
+	// SetSourceAddress sets the value of the "source address" field.
+	SetSourceAddress(tcpip.Address)
+
+	// SetDestinationAddress sets the value of the "destination address"
+	// field.
+	SetDestinationAddress(tcpip.Address)
+
+	// SetChecksum sets the value of the "checksum" field.
+	SetChecksum(uint16)
+
+	// TransportProtocol returns the number of the transport protocol
+	// stored in the payload.
+	TransportProtocol() tcpip.TransportProtocolNumber
+
+	// Payload returns a byte slice containing the payload of the network
+	// packet.
+	Payload() []byte
+
+	// TOS returns the values of the "type of service" and "flow label" fields.
+	TOS() (uint8, uint32)
+
+	// SetTOS sets the values of the "type of service" and "flow label" fields.
+	SetTOS(t uint8, l uint32)
+}
diff --git a/netstack/tcpip/header/ipv4.go b/netstack/tcpip/header/ipv4.go
new file mode 100644
index 0000000..c9605e8
--- /dev/null
+++ b/netstack/tcpip/header/ipv4.go
@@ -0,0 +1,289 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+)
+
+// Byte offsets of the fields inside an IPv4 header.
+const (
+	versIHL  = 0
+	tos      = 1
+	totalLen = 2
+	id       = 4
+	flagsFO  = 6
+	ttl      = 8
+	protocol = 9
+	checksum = 10
+	srcAddr  = 12
+	dstAddr  = 16
+)
+
+// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+// It is the struct that holds the IPv4 header information.
+type IPv4Fields struct {
+	// IHL is the "internet header length" field of an IPv4 packet.
+	// Header length.
+	IHL uint8
+
+	// TOS is the "type of service" field of an IPv4 packet.
+	// Service differentiation marking.
+	TOS uint8
+
+	// TotalLength is the "total length" field of an IPv4 packet.
+	// Total length of the datagram.
+	TotalLength uint16
+
+	// ID is the "identification" field of an IPv4 packet.
+	// Identifier.
+	ID uint16
+
+	// Flags is the "flags" field of an IPv4 packet.
+	// Flag bits.
+	Flags uint8
+
+	// FragmentOffset is the "fragment offset" field of an IPv4 packet.
+	// Fragment offset.
+	FragmentOffset uint16
+
+	// TTL is the "time to live" field of an IPv4 packet.
+	// Time to live.
+	TTL uint8
+
+	// Protocol is the "protocol" field of an IPv4 packet.
+	// Transport-layer protocol carried by the packet.
+	Protocol uint8
+
+	// Checksum is the "checksum" field of an IPv4 packet.
+	// Header checksum.
+	Checksum uint16
+
+	// SrcAddr is the "source ip address" of an IPv4 packet.
+	// Source IP address.
+	SrcAddr tcpip.Address
+
+	// DstAddr is the "destination ip address" of an IPv4 packet.
+	// Destination IP address.
+	DstAddr tcpip.Address
+}
+
+// IPv4 represents an ipv4 header stored in a byte array.
+// Most of the methods of IPv4 access the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv4 before using other methods.
+type IPv4 []byte
+
+const (
+	// IPv4MinimumSize is the minimum size of a valid IPv4 packet.
+	IPv4MinimumSize = 20
+
+	// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
+	// that there are only 4 bits to represents the header length in 32-bit
+	// units, the header cannot exceed 15*4 = 60 bytes.
+	IPv4MaximumHeaderSize = 60
+
+	// IPv4AddressSize is the size, in bytes, of an IPv4 address.
+	IPv4AddressSize = 4
+
+	// IPv4ProtocolNumber is IPv4's network protocol number.
+	IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
+
+	// IPv4Version is the version of the ipv4 protocol.
+	IPv4Version = 4
+
+	// IPv4Broadcast is the broadcast address of the IPv4 protocol.
+	IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
+
+	// IPv4Any is the non-routable IPv4 "any" meta address.
+	IPv4Any tcpip.Address = "\x00\x00\x00\x00"
+)
+
+// Flags that may be set in an IPv4 packet.
+const (
+	IPv4FlagMoreFragments = 1 << iota
+	IPv4FlagDontFragment
+)
+
+// IPVersion returns the version of IP used in the given packet. It returns -1
+// if the packet is not large enough to contain the version field.
+func IPVersion(b []byte) int {
+	// Length must be at least offset+length of version field.
+	if len(b) < versIHL+1 {
+		return -1
+	}
+	// The version is the high nibble of the first byte.
+	return int(b[versIHL] >> 4)
+}
+
+// HeaderLength returns the value of the "header length" field of the ipv4
+// header.
+func (b IPv4) HeaderLength() uint8 {
+	// The low nibble of the first byte is multiplied by 4.
+	return (b[versIHL] & 0xf) * 4
+}
+
+// ID returns the value of the identifier field of the ipv4 header.
+func (b IPv4) ID() uint16 {
+	return binary.BigEndian.Uint16(b[id:])
+}
+
+// Protocol returns the value of the protocol field of the ipv4 header.
+func (b IPv4) Protocol() uint8 {
+	return b[protocol]
+}
+
+// Flags returns the "flags" field of the ipv4 header, i.e. the top three bits
+// of the 16-bit flags/fragment-offset word.
+func (b IPv4) Flags() uint8 {
+	word := binary.BigEndian.Uint16(b[flagsFO:])
+	return uint8(word >> 13)
+}
+
+// TTL returns the "TTL" field of the ipv4 header.
+func (b IPv4) TTL() uint8 {
+	return b[ttl]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv4 header.
+func (b IPv4) FragmentOffset() uint16 {
+	word := binary.BigEndian.Uint16(b[flagsFO:])
+	// Shifting left by 3 pushes the three flag bits out of the 16-bit value.
+	return word << 3
+}
+
+// TotalLength returns the "total length" field of the ipv4 header.
+func (b IPv4) TotalLength() uint16 {
+	field := b[totalLen:]
+	return binary.BigEndian.Uint16(field)
+}
+
+// Checksum returns the checksum field of the ipv4 header.
+func (b IPv4) Checksum() uint16 {
+	field := b[checksum:]
+	return binary.BigEndian.Uint16(field)
+}
+
+// SourceAddress returns the "source address" field of the ipv4 header.
+func (b IPv4) SourceAddress() tcpip.Address {
+	raw := b[srcAddr : srcAddr+IPv4AddressSize]
+	return tcpip.Address(raw)
+}
+
+// DestinationAddress returns the "destination address" field of the ipv4
+// header.
+func (b IPv4) DestinationAddress() tcpip.Address {
+	raw := b[dstAddr : dstAddr+IPv4AddressSize]
+	return tcpip.Address(raw)
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
+	return tcpip.TransportProtocolNumber(b[protocol])
+}
+
+// Payload implements Network.Payload. It returns the PayloadLength() bytes
+// that follow the header.
+func (b IPv4) Payload() []byte {
+	afterHeader := b[b.HeaderLength():]
+	return afterHeader[:b.PayloadLength()]
+}
+
+// PayloadLength returns the length of the payload portion of the ipv4 packet,
+// computed as the total length minus the header length.
+func (b IPv4) PayloadLength() uint16 {
+	total := b.TotalLength()
+	hdr := uint16(b.HeaderLength())
+	return total - hdr
+}
+
+// TOS returns the "type of service" field of the ipv4 header. The second
+// result is always 0.
+func (b IPv4) TOS() (uint8, uint32) {
+	return b[tos], 0
+}
+
+// SetTOS sets the "type of service" field of the ipv4 header. The second
+// argument is ignored.
+func (b IPv4) SetTOS(v uint8, _ uint32) {
+	b[tos] = v
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+	field := b[totalLen:]
+	binary.BigEndian.PutUint16(field, totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+	field := b[checksum:]
+	binary.BigEndian.PutUint16(field, v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header. The flags occupy the top three bits of the 16-bit word and the
+// offset, divided by 8, the remaining bits.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+	word := uint16(flags) << 13
+	word |= offset >> 3
+	binary.BigEndian.PutUint16(b[flagsFO:], word)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr tcpip.Address) {
+	dst := b[srcAddr : srcAddr+IPv4AddressSize]
+	copy(dst, addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
+	dst := b[dstAddr : dstAddr+IPv4AddressSize]
+	copy(dst, addr)
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header, covering the
+// first HeaderLength() bytes.
+func (b IPv4) CalculateChecksum() uint16 {
+	hdr := b[:b.HeaderLength()]
+	return Checksum(hdr, 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+	// Version in the high nibble, header length in 32-bit units in the low
+	// nibble.
+	b[versIHL] = (IPv4Version << 4) | ((i.IHL / 4) & 0xf)
+	b[tos] = i.TOS
+	b.SetTotalLength(i.TotalLength)
+	binary.BigEndian.PutUint16(b[id:], i.ID)
+	b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+	b[ttl] = i.TTL
+	b[protocol] = i.Protocol
+	b.SetChecksum(i.Checksum)
+	b.SetSourceAddress(i.SrcAddr)
+	b.SetDestinationAddress(i.DstAddr)
+}
+
+// EncodePartial updates the total length and checksum fields of ipv4 header,
+// taking in the partial checksum, which is the checksum of the header without
+// the total length and checksum fields. It is useful in cases when similar
+// packets are produced.
+func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
+	b.SetTotalLength(totalLength)
+	// Fold the freshly written total length into the partial checksum. The
+	// local is not called "checksum" so that it does not shadow the
+	// package-level field offset of the same name.
+	xsum := Checksum(b[totalLen:totalLen+2], partialChecksum)
+	b.SetChecksum(^xsum)
+}
+
+// IsValid performs basic validation on the packet. pktSize is the size of the
+// packet the header was taken from.
+//
+// It returns false when the slice is shorter than IPv4MinimumSize, when the
+// header length field encodes fewer than IPv4MinimumSize bytes, when the
+// header length exceeds the total length, or when the total length exceeds
+// pktSize.
+func (b IPv4) IsValid(pktSize int) bool {
+	if len(b) < IPv4MinimumSize {
+		return false
+	}
+
+	hlen := int(b.HeaderLength())
+	tlen := int(b.TotalLength())
+	// A header length below the fixed 20-byte header would make Payload()
+	// and CalculateChecksum() treat fixed header bytes as payload, so it is
+	// rejected here as well.
+	if hlen < IPv4MinimumSize || hlen > tlen || tlen > pktSize {
+		return false
+	}
+
+	return true
+}
+
+// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
+// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
+// will be 1110 = 0xe0.
+func IsV4MulticastAddress(addr tcpip.Address) bool {
+	if len(addr) != IPv4AddressSize {
+		return false
+	}
+	return (addr[0] & 0xf0) == 0xe0
+}
diff --git a/netstack/tcpip/header/ipv6.go b/netstack/tcpip/header/ipv6.go
new file mode 100644
index 0000000..592d28f
--- /dev/null
+++ b/netstack/tcpip/header/ipv6.go
@@ -0,0 +1,236 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+	"strings"
+
+	"tcpip/netstack/tcpip"
+)
+
+const (
+	versTCFL   = 0
+	payloadLen = 4
+	nextHdr    = 6
+	hopLimit   = 7
+	v6SrcAddr  = 8
+	v6DstAddr  = 24
+)
+
+// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type IPv6Fields struct {
+	// TrafficClass is the "traffic class" field of an IPv6 packet.
+	TrafficClass uint8
+
+	// FlowLabel is the "flow label" field of an IPv6 packet. IPv6.SetTOS
+	// stores only its low 20 bits.
+	FlowLabel uint32
+
+	// PayloadLength is the "payload length" field of an IPv6 packet.
+	PayloadLength uint16
+
+	// NextHeader is the "next header" field of an IPv6 packet.
+	NextHeader uint8
+
+	// HopLimit is the "hop limit" field of an IPv6 packet.
+	HopLimit uint8
+
+	// SrcAddr is the "source ip address" of an IPv6 packet.
+	SrcAddr tcpip.Address
+
+	// DstAddr is the "destination ip address" of an IPv6 packet.
+	DstAddr tcpip.Address
+}
+
+// IPv6 represents an ipv6 header stored in a byte array.
+// Most of the methods of IPv6 access the underlying slice without
+// checking the boundaries and could panic because of 'index out of range'.
+// Always call IsValid() to validate an instance of IPv6 before using other methods.
+type IPv6 []byte
+
+const (
+	// IPv6MinimumSize is the minimum size of a valid IPv6 packet. The IPv6
+	// methods also use it as the offset at which the payload starts.
+	IPv6MinimumSize = 40
+
+	// IPv6AddressSize is the size, in bytes, of an IPv6 address.
+	IPv6AddressSize = 16
+
+	// IPv6ProtocolNumber is IPv6's network protocol number.
+	IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
+
+	// IPv6Version is the version of the ipv6 protocol.
+	IPv6Version = 6
+
+	// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
+	// section 5.
+	IPv6MinimumMTU = 1280
+)
+
+// PayloadLength returns the value of the "payload length" field of the ipv6
+// header, read as a big-endian 16-bit value.
+func (b IPv6) PayloadLength() uint16 {
+	return binary.BigEndian.Uint16(b[payloadLen:])
+}
+
+// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+func (b IPv6) HopLimit() uint8 {
+	return b[hopLimit]
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 header.
+func (b IPv6) NextHeader() uint8 {
+	return b[nextHdr]
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
+	return tcpip.TransportProtocolNumber(b[nextHdr])
+}
+
+// Payload implements Network.Payload. It returns the PayloadLength() bytes
+// that follow the fixed-size header.
+func (b IPv6) Payload() []byte {
+	afterHeader := b[IPv6MinimumSize:]
+	return afterHeader[:b.PayloadLength()]
+}
+
+// SourceAddress returns the "source address" field of the ipv6 header.
+func (b IPv6) SourceAddress() tcpip.Address {
+	raw := b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]
+	return tcpip.Address(raw)
+}
+
+// DestinationAddress returns the "destination address" field of the ipv6
+// header.
+func (b IPv6) DestinationAddress() tcpip.Address {
+	raw := b[v6DstAddr : v6DstAddr+IPv6AddressSize]
+	return tcpip.Address(raw)
+}
+
+// Checksum implements Network.Checksum. IPv6 has no header checksum, so the
+// result is always 0.
+func (IPv6) Checksum() uint16 {
+	return 0
+}
+
+// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) TOS() (uint8, uint32) {
+	word := binary.BigEndian.Uint32(b[versTCFL:])
+	// The conversion to uint8 drops the version bits that sit above the
+	// traffic class.
+	trafficClass := uint8(word >> 20)
+	flowLabel := word & 0xfffff
+	return trafficClass, flowLabel
+}
+
+// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
+// The version bits sharing the same 32-bit word are written as 6.
+func (b IPv6) SetTOS(t uint8, l uint32) {
+	word := uint32(6) << 28
+	word |= uint32(t) << 20
+	word |= l & 0xfffff
+	binary.BigEndian.PutUint32(b[versTCFL:], word)
+}
+
+// SetPayloadLength sets the "payload length" field of the ipv6 header.
+func (b IPv6) SetPayloadLength(payloadLength uint16) {
+	field := b[payloadLen:]
+	binary.BigEndian.PutUint16(field, payloadLength)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv6 header.
+func (b IPv6) SetSourceAddress(addr tcpip.Address) {
+	dst := b[v6SrcAddr : v6SrcAddr+IPv6AddressSize]
+	copy(dst, addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv6
+// header.
+func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
+	dst := b[v6DstAddr : v6DstAddr+IPv6AddressSize]
+	copy(dst, addr)
+}
+
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+	b[nextHdr] = v
+}
+
+// SetChecksum implements Network.SetChecksum. IPv6 has no header checksum, so
+// this does nothing.
+func (IPv6) SetChecksum(uint16) {
+}
+
+// Encode encodes all the fields of the ipv6 header.
+func (b IPv6) Encode(i *IPv6Fields) {
+	b.SetTOS(i.TrafficClass, i.FlowLabel)
+	b.SetPayloadLength(i.PayloadLength)
+	b.SetNextHeader(i.NextHeader)
+	b[hopLimit] = i.HopLimit
+	b.SetSourceAddress(i.SrcAddr)
+	b.SetDestinationAddress(i.DstAddr)
+}
+
+// IsValid performs basic validation on the packet: the slice must hold at
+// least the fixed header, and the payload length must fit in what is left of
+// pktSize after that header.
+func (b IPv6) IsValid(pktSize int) bool {
+	if len(b) < IPv6MinimumSize {
+		return false
+	}
+
+	return int(b.PayloadLength()) <= pktSize-IPv6MinimumSize
+}
+
+// IsV4MappedAddress determines if the provided address is an IPv4 mapped
+// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
+func IsV4MappedAddress(addr tcpip.Address) bool {
+	const v4MappedPrefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
+
+	return len(addr) == IPv6AddressSize && strings.HasPrefix(string(addr), v4MappedPrefix)
+}
+
+// IsV6MulticastAddress determines if the provided address is an IPv6
+// multicast address (anything starting with FF).
+func IsV6MulticastAddress(addr tcpip.Address) bool {
+	return len(addr) == IPv6AddressSize && addr[0] == 0xff
+}
+
+// SolicitedNodeAddr computes the solicited-node multicast address. This is
+// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
+// address.
+func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
+	const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
+	// Only the last three bytes of the address are carried over.
+	suffix := addr[len(addr)-3:]
+	return solicitedNodeMulticastPrefix + suffix
+}
+
+// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
+// (MAC) address.
+func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
+	// Convert the 48-bit MAC to an EUI-64 and place it behind the
+	// link-local prefix FE80::.
+	//
+	// The conversion is very nearly:
+	//	aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
+	// The capital A marks the first MAC byte, in which one bit is flipped.
+	var lladdrb [16]byte
+	lladdrb[0], lladdrb[1] = 0xFE, 0x80
+	lladdrb[8] = linkAddr[0] ^ 2
+	lladdrb[9] = linkAddr[1]
+	lladdrb[10] = linkAddr[2]
+	lladdrb[11], lladdrb[12] = 0xFF, 0xFE
+	lladdrb[13] = linkAddr[3]
+	lladdrb[14] = linkAddr[4]
+	lladdrb[15] = linkAddr[5]
+	return tcpip.Address(lladdrb[:])
+}
diff --git a/netstack/tcpip/header/ipv6_fragment.go b/netstack/tcpip/header/ipv6_fragment.go
new file mode 100644
index 0000000..2caeaa3
--- /dev/null
+++ b/netstack/tcpip/header/ipv6_fragment.go
@@ -0,0 +1,146 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+)
+
+// Byte offsets of the fields inside an IPv6 fragment header.
+const (
+	nextHdrFrag = 0
+	fragOff     = 2
+	more        = 3
+	idV6        = 4
+)
+
+// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to
+// describe the fields of a packet that needs to be encoded.
+type IPv6FragmentFields struct {
+	// NextHeader is the "next header" field of an IPv6 fragment.
+	NextHeader uint8
+
+	// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
+	FragmentOffset uint16
+
+	// M is the "more" field of an IPv6 fragment.
+	M bool
+
+	// Identification is the "identification" field of an IPv6 fragment.
+	Identification uint32
+}
+
+// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
+// Most of its methods index the underlying slice directly, without checking
+// its bounds, and can therefore panic with 'index out of range'.
+// Always call IsValid() on an IPv6Fragment before using its other methods.
+type IPv6Fragment []byte
+
+const (
+	// IPv6FragmentHeader header is the number used to specify that the next
+	// header is a fragment header, per RFC 2460.
+	IPv6FragmentHeader = 44
+
+	// IPv6FragmentHeaderSize is the size of the fragment header.
+	IPv6FragmentHeaderSize = 8
+)
+
+// Encode encodes all the fields of the ipv6 fragment.
+func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
+	b[nextHdrFrag] = i.NextHeader
+
+	// The offset sits above the low three bits of the 16-bit word; the
+	// "more" flag is the lowest bit of that same word.
+	offsetAndMore := i.FragmentOffset << 3
+	if i.M {
+		offsetAndMore |= 1
+	}
+	binary.BigEndian.PutUint16(b[fragOff:], offsetAndMore)
+
+	binary.BigEndian.PutUint32(b[idV6:], i.Identification)
+}
+
+// IsValid performs basic validation on the fragment header: the slice must be
+// at least IPv6FragmentHeaderSize bytes long.
+func (b IPv6Fragment) IsValid() bool {
+	return len(b) >= IPv6FragmentHeaderSize
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 fragment.
+func (b IPv6Fragment) NextHeader() uint8 {
+	return b[nextHdrFrag]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
+func (b IPv6Fragment) FragmentOffset() uint16 {
+	word := binary.BigEndian.Uint16(b[fragOff:])
+	return word >> 3
+}
+
+// More returns the "more" field of the ipv6 fragment.
+func (b IPv6Fragment) More() bool {
+	return b[more]&1 != 0
+}
+
+// Payload implements Network.Payload. It returns everything that follows the
+// fragment header.
+func (b IPv6Fragment) Payload() []byte {
+	return b[IPv6FragmentHeaderSize:]
+}
+
+// ID returns the value of the identifier field of the ipv6 fragment.
+func (b IPv6Fragment) ID() uint32 {
+	field := b[idV6:]
+	return binary.BigEndian.Uint32(field)
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
+	return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// The functions below have been added only to satisfy the Network interface.
+// Each of them panics with "not supported" when called.
+
+// Checksum is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) Checksum() uint16 {
+	panic("not supported")
+}
+
+// SourceAddress is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) SourceAddress() tcpip.Address {
+	panic("not supported")
+}
+
+// DestinationAddress is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) DestinationAddress() tcpip.Address {
+	panic("not supported")
+}
+
+// SetSourceAddress is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
+	panic("not supported")
+}
+
+// SetDestinationAddress is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
+	panic("not supported")
+}
+
+// SetChecksum is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) SetChecksum(uint16) {
+	panic("not supported")
+}
+
+// TOS is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) TOS() (uint8, uint32) {
+	panic("not supported")
+}
+
+// SetTOS is not supported by IPv6Fragment; it always panics.
+func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
+	panic("not supported")
+}
diff --git a/netstack/tcpip/header/ipversion_test.go b/netstack/tcpip/header/ipversion_test.go
new file mode 100644
index 0000000..79d7088
--- /dev/null
+++ b/netstack/tcpip/header/ipversion_test.go
@@ -0,0 +1,67 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header_test
+
+import (
+	"testing"
+
+	"tcpip/netstack/tcpip/header"
+)
+
+// TestIPv4 checks that a freshly encoded IPv4 header reports version 4.
+func TestIPv4(t *testing.T) {
+	hdr := make(header.IPv4, header.IPv4MinimumSize)
+	hdr.Encode(&header.IPv4Fields{})
+
+	const want = header.IPv4Version
+	got := header.IPVersion(hdr)
+	if got != want {
+		t.Fatalf("Bad version, want %v, got %v", want, got)
+	}
+}
+
+// TestIPv6 checks that a freshly encoded IPv6 header reports version 6.
+func TestIPv6(t *testing.T) {
+	hdr := make(header.IPv6, header.IPv6MinimumSize)
+	hdr.Encode(&header.IPv6Fields{})
+
+	const want = header.IPv6Version
+	got := header.IPVersion(hdr)
+	if got != want {
+		t.Fatalf("Bad version, want %v, got %v", want, got)
+	}
+}
+
+// TestOtherVersion checks that a version number that is neither 4 nor 6 is
+// reported as stored.
+func TestOtherVersion(t *testing.T) {
+	const want = header.IPv4Version + header.IPv6Version
+	pkt := []byte{want << 4}
+
+	got := header.IPVersion(pkt)
+	if got != want {
+		t.Fatalf("Bad version, want %v, got %v", want, got)
+	}
+}
+
+// TestTooShort checks that -1 is reported for inputs that cannot hold the
+// version field.
+func TestTooShort(t *testing.T) {
+	const want = -1
+	pkt := []byte{(header.IPv4Version + header.IPv6Version) << 4}
+
+	// First a zero-length slice, then a nil slice.
+	for _, in := range [][]byte{pkt[:0], nil} {
+		if v := header.IPVersion(in); v != want {
+			t.Fatalf("Bad version, want %v, got %v", want, v)
+		}
+	}
+}
diff --git a/netstack/tcpip/header/tcp.go b/netstack/tcpip/header/tcp.go
new file mode 100644
index 0000000..0fd7d07
--- /dev/null
+++ b/netstack/tcpip/header/tcp.go
@@ -0,0 +1,536 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package header
+
+import (
+	"encoding/binary"
+
+	"tcpip/netstack/tcpip"
+	"tcpip/netstack/tcpip/seqnum"
+)
+
+// Byte offsets of the individual fields inside a TCP header. The TCP accessor
+// methods index the header slice with these values.
+const (
+	srcPort     = 0
+	dstPort     = 2
+	seqNum      = 4
+	ackNum      = 8
+	dataOffset  = 12
+	tcpFlags    = 13
+	winSize     = 14
+	tcpChecksum = 16
+	urgentPtr   = 18
+)
+
+const (
+	// MaxWndScale is maximum allowed window scaling, as described in
+	// RFC 1323, section 2.3, page 11.
+	MaxWndScale = 14
+
+	// TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+	// be encoded in a TCP option field.
+	TCPMaxSACKBlocks = 4
+)
+
+// Flags that may be set in a TCP segment.
+const (
+	TCPFlagFin = 1 << iota
+	TCPFlagSyn
+	TCPFlagRst
+	TCPFlagPsh
+	TCPFlagAck
+	TCPFlagUrg
+)
+
+// Options that may be present in a TCP segment. The values are the option
+// kind bytes the option parsers switch on.
+const (
+	TCPOptionEOL           = 0
+	TCPOptionNOP           = 1
+	TCPOptionMSS           = 2
+	TCPOptionWS            = 3
+	TCPOptionTS            = 8
+	TCPOptionSACKPermitted = 4
+	TCPOptionSACK          = 5
+)
+
+// TCPFields contains the fields of a TCP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+// These are the fields of the TCP header.
+type TCPFields struct {
+	// SrcPort is the "source port" field of a TCP packet.
+	SrcPort uint16
+
+	// DstPort is the "destination port" field of a TCP packet.
+	DstPort uint16
+
+	// SeqNum is the "sequence number" field of a TCP packet.
+	SeqNum uint32
+
+	// AckNum is the "acknowledgement number" field of a TCP packet.
+	AckNum uint32
+
+	// DataOffset is the "data offset" field of a TCP packet. TCP.Encode
+	// divides it by 4 before storing it in the header.
+	DataOffset uint8
+
+	// Flags is the "flags" field of a TCP packet.
+	Flags uint8
+
+	// WindowSize is the "window size" field of a TCP packet.
+	WindowSize uint16
+
+	// Checksum is the "checksum" field of a TCP packet.
+	Checksum uint16
+
+	// UrgentPointer is the "urgent pointer" field of a TCP packet.
+	UrgentPointer uint16
+}
+
+// TCPSynOptions is used to return the parsed TCP Options in a syn
+// segment.
+// These are the options of a syn segment.
+type TCPSynOptions struct {
+	// MSS is the maximum segment size provided by the peer in the SYN.
+	MSS uint16
+
+	// WS is the window scale option provided by the peer in the SYN.
+	//
+	// Set to -1 if no window scale option was provided.
+	WS int
+
+	// TS is true if the timestamp option was provided in the syn/syn-ack.
+	TS bool
+
+	// TSVal is the value of the TSVal field in the timestamp option.
+	TSVal uint32
+
+	// TSEcr is the value of the TSEcr field in the timestamp option.
+	TSEcr uint32
+
+	// SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
+	SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+// It is the struct form of one sack block.
+//
+// +stateify savable
+type SACKBlock struct {
+	// Start indicates the lowest sequence number in the block.
+	Start seqnum.Value
+
+	// End indicates the sequence number immediately following the last
+	// sequence number of this block.
+	End seqnum.Value
+}
+
+// TCPOptions are used to parse and cache the TCP segment options for a non
+// syn/syn-ack segment.
+// This structure does not describe syn/syn-ack segments.
+//
+// +stateify savable
+type TCPOptions struct {
+	// TS is true if the TimeStamp option is enabled.
+	TS bool
+
+	// TSVal is the value in the TSVal field of the segment.
+	TSVal uint32
+
+	// TSEcr is the value in the TSEcr field of the segment.
+	TSEcr uint32
+
+	// SACKBlocks are the SACK blocks specified in the segment.
+	SACKBlocks []SACKBlock
+}
+
+// TCP represents a TCP header stored in a byte array. Its methods index the
+// underlying slice directly, without checking its bounds.
+type TCP []byte
+
+const (
+	// TCPMinimumSize is the minimum size of a valid TCP packet. TCP.Options
+	// also uses it as the offset at which the options start.
+	TCPMinimumSize = 20
+
+	// TCPProtocolNumber is TCP's transport protocol number.
+	TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+)
+
+// SourcePort returns the "source port" field of the tcp header.
+func (b TCP) SourcePort() uint16 {
+	return binary.BigEndian.Uint16(b[srcPort:])
+}
+
+// DestinationPort returns the "destination port" field of the tcp header.
+func (b TCP) DestinationPort() uint16 {
+	return binary.BigEndian.Uint16(b[dstPort:])
+}
+
+// SequenceNumber returns the "sequence number" field of the tcp header.
+func (b TCP) SequenceNumber() uint32 {
+	return binary.BigEndian.Uint32(b[seqNum:])
+}
+
+// AckNumber returns the "ack number" field of the tcp header.
+func (b TCP) AckNumber() uint32 {
+	return binary.BigEndian.Uint32(b[ackNum:])
+}
+
+// DataOffset returns the "data offset" field of the tcp header, multiplied
+// by 4. The field is the high nibble of the byte at offset dataOffset.
+func (b TCP) DataOffset() uint8 {
+	return (b[dataOffset] >> 4) * 4
+}
+
+// Payload returns the data in the tcp packet, i.e. everything from
+// DataOffset() onwards.
+func (b TCP) Payload() []byte {
+	return b[b.DataOffset():]
+}
+
+// Flags returns the flags field of the tcp header.
+func (b TCP) Flags() uint8 {
+	return b[tcpFlags]
+}
+
+// WindowSize returns the "window size" field of the tcp header.
+func (b TCP) WindowSize() uint16 {
+	return binary.BigEndian.Uint16(b[winSize:])
+}
+
+// Checksum returns the "checksum" field of the tcp header.
+func (b TCP) Checksum() uint16 {
+	return binary.BigEndian.Uint16(b[tcpChecksum:])
+}
+
+// SetSourcePort sets the "source port" field of the tcp header.
+func (b TCP) SetSourcePort(port uint16) {
+	binary.BigEndian.PutUint16(b[srcPort:], port)
+}
+
+// SetDestinationPort sets the "destination port" field of the tcp header.
+func (b TCP) SetDestinationPort(port uint16) {
+	binary.BigEndian.PutUint16(b[dstPort:], port)
+}
+
+// SetChecksum sets the checksum field of the tcp header.
+func (b TCP) SetChecksum(checksum uint16) {
+	field := b[tcpChecksum:]
+	binary.BigEndian.PutUint16(field, checksum)
+}
+
+// CalculateChecksum calculates the checksum of the tcp segment from
+// partialChecksum and totalLen:
+// totalLen is the total length of the segment;
+// partialChecksum is the checksum of the network-layer pseudo-header
+// (excluding the total length) and the checksum of the segment data.
+func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+	// Fold the big-endian length into the pseudo-checksum first.
+	lengthBytes := []byte{byte(totalLen >> 8), byte(totalLen)}
+	withLength := Checksum(lengthBytes, partialChecksum)
+
+	// Then add the first DataOffset() bytes, i.e. the TCP header.
+	hdr := b[:b.DataOffset()]
+	return Checksum(hdr, withLength)
+}
+
+// Options returns a slice that holds the unparsed TCP options in the segment:
+// the bytes between the fixed header and DataOffset().
+func (b TCP) Options() []byte {
+	end := b.DataOffset()
+	return b[TCPMinimumSize:end]
+}
+
+// ParsedOptions returns a TCPOptions structure holding the parsed option
+// values of the TCP segment. NOTE: the options are parsed again on every
+// call, so invoking this function repeatedly is expensive.
+func (b TCP) ParsedOptions() TCPOptions {
+	raw := b.Options()
+	return ParseTCPOptions(raw)
+}
+
+// encodeSubset writes the sequence number, ack number, flags and window size
+// fields of the tcp header.
+func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+	binary.BigEndian.PutUint32(b[seqNum:], seq)
+	binary.BigEndian.PutUint32(b[ackNum:], ack)
+	b[tcpFlags] = flags
+	binary.BigEndian.PutUint16(b[winSize:], rcvwnd)
+}
+
+// Encode encodes all the fields of the tcp header.
+func (b TCP) Encode(t *TCPFields) {
+	b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
+	b.SetSourcePort(t.SrcPort)
+	b.SetDestinationPort(t.DstPort)
+	// The data offset is stored in 4-byte units in the high nibble.
+	words := t.DataOffset / 4
+	b[dataOffset] = words << 4
+	b.SetChecksum(t.Checksum)
+	binary.BigEndian.PutUint16(b[urgentPtr:], t.UrgentPointer)
+}
+
+// EncodePartial updates a subset of the fields of the tcp header.
+// It is useful in cases when similar segments are produced.
+//
+// partialChecksum is folded together with length, flags and the other
+// passed-in fields, and the complement of the result is stored as the header
+// checksum.
+func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
+	// Add the total length and "flags" field contributions to the checksum.
+	// We don't use the flags field directly from the header because it's a
+	// one-byte field with an odd offset, so it would be accounted for
+	// incorrectly by the Checksum routine.
+	tmp := make([]byte, 4)
+	binary.BigEndian.PutUint16(tmp, length)
+	binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
+	checksum := Checksum(tmp, partialChecksum)
+
+	// Encode the passed-in fields.
+	b.encodeSubset(seqnum, acknum, flags, rcvwnd)
+
+	// Add the contributions of the passed-in fields to the checksum: the
+	// 8 bytes of sequence and ack number, then the 2 bytes of window size.
+	checksum = Checksum(b[seqNum:seqNum+8], checksum)
+	checksum = Checksum(b[winSize:winSize+2], checksum)
+
+	// Encode the checksum.
+	b.SetChecksum(^checksum)
+}
+
+// ParseSynOptions parses the options received in a SYN segment and returns the
+// relevant ones. opts should point to the option part of the TCP Header.
+// isAck selects whether the timestamp echo reply is stored as well.
+//
+// Parsing stops at the first malformed option; whatever was parsed up to that
+// point is returned.
+func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
+	limit := len(opts)
+
+	synOpts := TCPSynOptions{
+		// Per RFC 1122, page 85: "If an MSS option is not received at
+		// connection setup, TCP MUST assume a default send MSS of 536."
+		MSS: 536,
+		// If no window scale option is specified, WS in options is
+		// returned as -1; this is because the absence of the option
+		// indicates that we cannot use window scaling on the
+		// receive end either.
+		WS: -1,
+	}
+
+	for i := 0; i < limit; {
+		switch opts[i] {
+		case TCPOptionEOL:
+			// End of option list: stop the loop.
+			i = limit
+		case TCPOptionNOP:
+			i++
+		case TCPOptionMSS:
+			if i+4 > limit || opts[i+1] != 4 {
+				return synOpts
+			}
+			mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+			// An MSS of zero is treated as malformed.
+			if mss == 0 {
+				return synOpts
+			}
+			synOpts.MSS = mss
+			i += 4
+
+		case TCPOptionWS:
+			if i+3 > limit || opts[i+1] != 3 {
+				return synOpts
+			}
+			ws := int(opts[i+2])
+			// Values above the maximum are clamped, not rejected.
+			if ws > MaxWndScale {
+				ws = MaxWndScale
+			}
+			synOpts.WS = ws
+			i += 3
+
+		case TCPOptionTS:
+			if i+10 > limit || opts[i+1] != 10 {
+				return synOpts
+			}
+			synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
+			if isAck {
+				// If the segment is a SYN-ACK then store the Timestamp Echo Reply
+				// in the segment.
+				synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
+			}
+			synOpts.TS = true
+			i += 10
+		case TCPOptionSACKPermitted:
+			if i+2 > limit || opts[i+1] != 2 {
+				return synOpts
+			}
+			synOpts.SACKPermitted = true
+			i += 2
+
+		default:
+			// We don't recognize this option, just skip over it.
+			if i+2 > limit {
+				return synOpts
+			}
+			l := int(opts[i+1])
+			// If the length is incorrect or if l+i overflows the
+			// total options length then stop parsing and return
+			// what has been parsed so far.
+			if l < 2 || i+l > limit {
+				return synOpts
+			}
+			i += l
+		}
+	}
+
+	return synOpts
+}
+
+// ParseTCPOptions extracts and stores all known options in the provided byte
+// slice in a TCPOptions structure. Parsing stops at the first malformed
+// option; whatever was parsed up to that point is returned.
+func ParseTCPOptions(b []byte) TCPOptions {
+	opts := TCPOptions{}
+	limit := len(b)
+	for i := 0; i < limit; {
+		switch b[i] {
+		case TCPOptionEOL:
+			// End of option list: stop the loop.
+			i = limit
+		case TCPOptionNOP:
+			i++
+		case TCPOptionTS:
+			if i+10 > limit || (b[i+1] != 10) {
+				return opts
+			}
+			opts.TS = true
+			opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
+			opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
+			i += 10
+		case TCPOptionSACK:
+			if i+2 > limit {
+				// Malformed SACK block, just return and stop parsing.
+				return opts
+			}
+			sackOptionLen := int(b[i+1])
+			// The option is a 2-byte kind/length prefix followed by
+			// 8-byte blocks, so anything else is malformed.
+			if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+				// Malformed SACK block, just return and stop parsing.
+				return opts
+			}
+			numBlocks := (sackOptionLen - 2) / 8
+			// Blocks collected from an earlier SACK option in the same
+			// segment are discarded here.
+			opts.SACKBlocks = []SACKBlock{}
+			for j := 0; j < numBlocks; j++ {
+				start := binary.BigEndian.Uint32(b[i+2+j*8:])
+				end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
+				opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
+					Start: seqnum.Value(start),
+					End:   seqnum.Value(end),
+				})
+			}
+			i += sackOptionLen
+		default:
+			// We don't recognize this option, just skip over it.
+			if i+2 > limit {
+				return opts
+			}
+			l := int(b[i+1])
+			// If the length is incorrect or if l+i overflows the
+			// total options length then stop parsing and return
+			// what has been parsed so far.
+			if l < 2 || i+l > limit {
+				return opts
+			}
+			i += l
+		}
+	}
+	return opts
+}
+
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
+// the supplied buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer. Only the low 16 bits of mss are written.
+func EncodeMSSOption(mss uint32, b []byte) int {
+	// mssOptionSize is the number of bytes in a valid MSS option.
+	const mssOptionSize = 4
+
+	if len(b) < mssOptionSize {
+		return 0
+	}
+	b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
+	return mssOptionSize
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer. ws is truncated to a single byte.
+func EncodeWSOption(ws int, b []byte) int {
+	if len(b) < 3 {
+		return 0
+	}
+	b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+	// b[1] holds the option length that was just written.
+	return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer.
If the buffer is smaller than expected it +// just returns without encoding anything. It returns the number of bytes +// written to the provided buffer. +func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { + if len(b) < 10 { + return 0 + } + b[0], b[1] = TCPOptionTS, 10 + binary.BigEndian.PutUint32(b[2:], tsVal) + binary.BigEndian.PutUint32(b[6:], tsEcr) + return int(b[1]) +} + +// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided +// buffer. If the buffer is smaller than required it just returns without +// encoding anything. It returns the number of bytes written to the provided +// buffer. +func EncodeSACKPermittedOption(b []byte) int { + if len(b) < 2 { + return 0 + } + + b[0], b[1] = TCPOptionSACKPermitted, 2 + return int(b[1]) +} + +// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block +// in the provided slice. It tries to fit in as many blocks as possible based on +// number of bytes available in the provided buffer. It returns the number of +// bytes written to the provided buffer. +func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int { + if len(sackBlocks) == 0 { + return 0 + } + l := len(sackBlocks) + if l > TCPMaxSACKBlocks { + l = TCPMaxSACKBlocks + } + if ll := (len(b) - 2) / 8; ll < l { + l = ll + } + if l == 0 { + // There is not enough space in the provided buffer to add + // any SACK blocks. + return 0 + } + b[0] = TCPOptionSACK + b[1] = byte(l*8 + 2) + for i := 0; i < l; i++ { + binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start)) + binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End)) + } + return int(b[1]) +} + +// EncodeNOP adds an explicit NOP to the option list. +func EncodeNOP(b []byte) int { + if len(b) == 0 { + return 0 + } + b[0] = TCPOptionNOP + return 1 +} + +// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align +// the option buffer. 
It adds padding bytes after the offset specified and +// returns the number of padding bytes added. The passed in options slice +// must have space for the padding bytes. +func AddTCPOptionPadding(options []byte, offset int) int { + paddingToAdd := -offset & 3 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + options[i] = TCPOptionNOP + } + return paddingToAdd +} diff --git a/netstack/tcpip/header/tcp_test.go b/netstack/tcpip/header/tcp_test.go new file mode 100644 index 0000000..8d74434 --- /dev/null +++ b/netstack/tcpip/header/tcp_test.go @@ -0,0 +1,148 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package header_test + +import ( + "reflect" + "testing" + + "tcpip/netstack/tcpip/header" +) + +func TestEncodeSACKBlocks(t *testing.T) { + testCases := []struct { + sackBlocks []header.SACKBlock + want []header.SACKBlock + bufSize int + }{ + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, + 40, + }, + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, + 30, + }, + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + []header.SACKBlock{{10, 20}, {22, 30}}, + 20, + }, + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + []header.SACKBlock{{10, 20}}, + 10, + }, + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + nil, + 8, + }, + { + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}, + []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}}, + 60, + }, + } + for _, tc := range testCases { + b := make([]byte, tc.bufSize) + t.Logf("testing: %v", tc) + header.EncodeSACKBlocks(tc.sackBlocks, b) + opts := header.ParseTCPOptions(b) + if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want) + } + } +} + +func TestTCPParseOptions(t *testing.T) { + type tsOption struct { + tsVal uint32 + tsEcr uint32 + } + + generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte { + l := 0 + if tsOpt != nil { + l += 10 + } + if len(sackBlocks) != 0 { + l += len(sackBlocks)*8 + 2 + } + b := make([]byte, l) + offset := 0 + if tsOpt != nil { + offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b) + } + header.EncodeSACKBlocks(sackBlocks, b[offset:]) + return b + } + + testCases := []struct { + b []byte + want header.TCPOptions 
+ }{ + // Trivial cases. + {nil, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, + + // Test timestamp parsing. + {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, + {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, + + // Test malformed timestamp option. + {[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, + + // Test SACKBlock parsing. + {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}}, + {[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}}, + + // Test malformed SACK option. 
+ {[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}}, + + // Test Timestamp + SACK block parsing. + {generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}}, + {generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}}, + {generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}}, + + // Test valid timestamp + malformed SACK block parsing. 
+ {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}}, + {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}}, + {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}}, + {[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}}, + {[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, + {[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}}, + {[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}}, + {[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}}, + } + for _, tc := range testCases { + if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) { + t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want) + } + } +} diff --git a/netstack/tcpip/header/udp.go b/netstack/tcpip/header/udp.go new file mode 100644 index 0000000..c67f035 --- /dev/null +++ b/netstack/tcpip/header/udp.go @@ -0,0 +1,117 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package header + +import ( + "encoding/binary" + + "tcpip/netstack/tcpip" +) + +const ( + udpSrcPort = 0 + udpDstPort = 2 + udpLength = 4 + udpChecksum = 6 +) + +// UDPFields contains the fields of a UDP packet. It is used to describe the +// fields of a packet that needs to be encoded. +// udp 首部字段 +type UDPFields struct { + // SrcPort is the "source port" field of a UDP packet. + SrcPort uint16 + + // DstPort is the "destination port" field of a UDP packet. + DstPort uint16 + + // Length is the "length" field of a UDP packet. + Length uint16 + + // Checksum is the "checksum" field of a UDP packet. + Checksum uint16 +} + +// UDP represents a UDP header stored in a byte array. +type UDP []byte + +const ( + // UDPMinimumSize is the minimum size of a valid UDP packet. + UDPMinimumSize = 8 + + // UDPProtocolNumber is UDP's transport protocol number. + UDPProtocolNumber tcpip.TransportProtocolNumber = 17 +) + +// SourcePort returns the "source port" field of the udp header. +func (b UDP) SourcePort() uint16 { + return binary.BigEndian.Uint16(b[udpSrcPort:]) +} + +// DestinationPort returns the "destination port" field of the udp header. +func (b UDP) DestinationPort() uint16 { + return binary.BigEndian.Uint16(b[udpDstPort:]) +} + +// Length returns the "length" field of the udp header. +func (b UDP) Length() uint16 { + return binary.BigEndian.Uint16(b[udpLength:]) +} + +// Payload returns the data contained in the UDP datagram. +func (b UDP) Payload() []byte { + return b[UDPMinimumSize:] +} + +// Checksum returns the "checksum" field of the udp header. +func (b UDP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[udpChecksum:]) +} + +// SetSourcePort sets the "source port" field of the udp header. 
+func (b UDP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[udpSrcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the udp header. +func (b UDP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[udpDstPort:], port) +} + +// SetChecksum sets the "checksum" field of the udp header. +func (b UDP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[udpChecksum:], checksum) +} + +// CalculateChecksum calculates the checksum of the udp packet, given the total +// length of the packet and the checksum of the network-layer pseudo-header +// (excluding the total length) and the checksum of the payload. +func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { + // Add the length portion of the checksum to the pseudo-checksum. + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + checksum := Checksum(tmp, partialChecksum) + + // Calculate the rest of the checksum. + return Checksum(b[:UDPMinimumSize], checksum) +} + +// Encode encodes all the fields of the udp header. +func (b UDP) Encode(u *UDPFields) { + binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) + binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) + binary.BigEndian.PutUint16(b[udpLength:], u.Length) + binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) +} diff --git a/netstack/tcpip/link/channel/channel.go b/netstack/tcpip/link/channel/channel.go new file mode 100644 index 0000000..b5e95b2 --- /dev/null +++ b/netstack/tcpip/link/channel/channel.go @@ -0,0 +1,125 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package channel provides the implementation of channel-based data-link layer +// endpoints. Such endpoints allow injection of inbound packets and store +// outbound packets in a channel. +package channel + +import ( + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/stack" +) + +// PacketInfo holds all the information about an outbound packet. +type PacketInfo struct { + Header buffer.View + Payload buffer.View + Proto tcpip.NetworkProtocolNumber +} + +// Endpoint is a link layer endpoint that stores outbound packets in a channel +// and allows injection of inbound packets. +type Endpoint struct { + dispatcher stack.NetworkDispatcher + mtu uint32 + linkAddr tcpip.LinkAddress + + // C is where outbound packets are queued. + C chan PacketInfo +} + +// New creates a new channel endpoint. +func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) { + e := &Endpoint{ + C: make(chan PacketInfo, size), + mtu: mtu, + linkAddr: linkAddr, + } + + return stack.RegisterLinkEndpoint(e), e +} + +// Drain removes all outbound packets from the channel and counts them. +func (e *Endpoint) Drain() int { + c := 0 + for { + select { + case <-e.C: + c++ + default: + return c + } + } +} + +// Inject injects an inbound packet. +func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + e.InjectLinkAddr(protocol, "", vv) +} + +// InjectLinkAddr injects an inbound packet with a remote link address.
+func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) { + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil)) +} + +// Attach saves the stack network-layer dispatcher for use later when packets +// are injected. +func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *Endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +func (e *Endpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength returns the maximum size of the link layer header. Given it +// doesn't have a header, it just returns 0. +func (*Endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (e *Endpoint) LinkAddress() tcpip.LinkAddress { + return e.linkAddr +} + +// WritePacket stores outbound packets into the channel. +func (e *Endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + p := PacketInfo{ + Header: hdr.View(), + Proto: protocol, + Payload: payload.ToView(), + } + + select { + case e.C <- p: + default: + } + + return nil +} diff --git a/netstack/tcpip/link/fdbased/endpoint.go b/netstack/tcpip/link/fdbased/endpoint.go new file mode 100644 index 0000000..d20a165 --- /dev/null +++ b/netstack/tcpip/link/fdbased/endpoint.go @@ -0,0 +1,305 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +// Package fdbased provides the implementation of data-link layer endpoints +// backed by boundary-preserving file descriptors (e.g., TUN devices, +// seqpacket/datagram sockets). +// +// FD based endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package fdbased + +import ( + "syscall" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/link/rawfile" + "tcpip/netstack/tcpip/stack" +) + +// BufConfig defines the shape of the vectorised view used to read packets from the NIC. +// 从NIC读取数据的多级缓存配置 +var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768} + +// 负责底层网卡的io读写以及数据分发 +type endpoint struct { + // fd is the file descriptor used to send and receive packets. + // 发送和接收数据的文件描述符 + fd int + + // mtu (maximum transmission unit) is the maximum size of a packet. + // 单个帧的最大长度 + mtu uint32 + + // hdrSize specifies the link-layer header size. If set to 0, no header + // is added/removed; otherwise an ethernet header is used. + // 以太网头部长度 + hdrSize int + + // addr is the address of the endpoint. + // 网卡地址 + addr tcpip.LinkAddress + + // caps holds the endpoint capabilities. + // 网卡的能力 + caps stack.LinkEndpointCapabilities + + // closed is a function to be called when the FD's peer (if any) closes + // its end of the communication pipe.
+ closed func(*tcpip.Error) + + iovecs []syscall.Iovec + views []buffer.View + dispatcher stack.NetworkDispatcher + + // handleLocal indicates whether packets destined to itself should be + // handled by the netstack internally (true) or be forwarded to the FD + // endpoint (false). + // handleLocal指示发往自身的数据包是由内部netstack处理(true)还是转发到FD端点(false)。 + // Resend packets back to netstack if destined to itself + // Add option to redirect packet back to netstack if it's destined to itself. + // This fixes the problem where connecting to the local NIC address would + // not work, e.g.: + // echo bar | nc -l -p 8080 & + // echo foo | nc 192.168.0.2 8080 + handleLocal bool +} + +// Options specify the details about the fd-based endpoint to be created. +// 创建fdbase端的一些选项参数 +type Options struct { + FD int + MTU uint32 + ClosedFunc func(*tcpip.Error) + Address tcpip.LinkAddress + ResolutionRequired bool + SaveRestore bool + ChecksumOffload bool + DisconnectOk bool + HandleLocal bool + TestLossPacket func(data []byte) bool +} + +// New creates a new fd-based endpoint. +// +// Makes fd non-blocking, but does not take ownership of fd, which must remain +// open for the lifetime of the returned endpoint. 
+// 根据选项参数创建一个链路层的endpoint,并返回该endpoint的id +func New(opts *Options) tcpip.LinkEndpointID { + syscall.SetNonblock(opts.FD, true) + + caps := stack.LinkEndpointCapabilities(0) + if opts.ResolutionRequired { + caps |= stack.CapabilityResolutionRequired + } + if opts.ChecksumOffload { + caps |= stack.CapabilityChecksumOffload + } + if opts.SaveRestore { + caps |= stack.CapabilitySaveRestore + } + if opts.DisconnectOk { + caps |= stack.CapabilityDisconnectOk + } + + e := &endpoint{ + fd: opts.FD, + mtu: opts.MTU, + caps: caps, + closed: opts.ClosedFunc, + addr: opts.Address, + hdrSize: header.EthernetMinimumSize, + views: make([]buffer.View, len(BufConfig)), + iovecs: make([]syscall.Iovec, len(BufConfig)), + handleLocal: opts.HandleLocal, + } + // 全局注册链路层设备 + return stack.RegisterLinkEndpoint(e) +} + +// Attach launches the goroutine that reads packets from the file descriptor and +// dispatches them via the provided dispatcher. +// Attach 启动从文件描述符中读取数据包的goroutine,并通过提供的分发函数来分发数据报。 +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + // Link endpoints are not savable. When transportation endpoints are + // saved, they stop sending outgoing packets and all incoming packets + // are rejected. + // 链路层端点不可保存。保存传输端点后,它们将停止发送传出数据包,并拒绝所有传入数据包。 + go e.dispatchLoop() +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +// 判断是否Attach了 +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized +// during construction. +// 返回当前MTU值 +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +// 返回当前网卡的能力 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.caps +} + +// MaxHeaderLength returns the maximum size of the link-layer header.
+// 返回当前以太网头部信息长度 +func (e *endpoint) MaxHeaderLength() uint16 { + return uint16(e.hdrSize) +} + +// LinkAddress returns the link address of this endpoint. +// 返回当前MAC地址 +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.addr +} + +// WritePacket writes outbound packets to the file descriptor. If it is not +// currently writable, the packet is dropped. +// 将上层的报文经过链路层封装,写入网卡中,如果写入失败则丢弃该报文 +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + // 如果目标地址就是设备本身自己,那么将报文重新返回给协议栈 + if e.handleLocal && r.LocalAddress != "" && r.LocalAddress == r.RemoteAddress { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, r.LocalLinkAddress, protocol, vv) + return nil + } + + // 封装增加以太网头部 + eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize)) + ethHdr := &header.EthernetFields{ + DstAddr: r.RemoteLinkAddress, + Type: protocol, + } + + // Preserve the src address if it's set in the route. 
+ // 如果路由信息中有配置源MAC地址,那么使用该地址, + // 如果没有,则使用本网卡的地址 + if r.LocalLinkAddress != "" { + ethHdr.SrcAddr = r.LocalLinkAddress + } else { + ethHdr.SrcAddr = e.addr + } + eth.Encode(ethHdr) + + // 写入网卡中 + // log.Printf("write to nic %d bytes", hdr.UsedLength()+payload.Size()) + if payload.Size() == 0 { + return rawfile.NonBlockingWrite(e.fd, hdr.View()) + } + + return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView()) +} + +func (e *endpoint) capViews(n int, buffers []int) int { + c := 0 + for i, s := range buffers { + c += s + if c >= n { + e.views[i].CapLength(s - (c - n)) + return i + 1 + } + } + return len(buffers) +} + +// 按bufConfig的长度分配内存大小 +// 注意e.views和e.iovecs共用相同的内存块 +func (e *endpoint) allocateViews(bufConfig []int) { + for i, v := range e.views { + if v != nil { + break + } + b := buffer.NewView(bufConfig[i]) + e.views[i] = b + e.iovecs[i] = syscall.Iovec{ + Base: &b[0], + Len: uint64(len(b)), + } + } +} + +// dispatch reads one packet from the file descriptor and dispatches it. +// 从网卡中读取一个数据报, +func (e *endpoint) dispatch() (bool, *tcpip.Error) { + // 读取数据缓存的分配 + e.allocateViews(BufConfig) + + // 从网卡中读取数据 + n, err := rawfile.BlockingReadv(e.fd, e.iovecs) + if err != nil { + return false, err + } + + // 如果比头部长度还小,直接丢弃 + if n <= e.hdrSize { + return false, nil + } + + var ( + p tcpip.NetworkProtocolNumber + remoteLinkAddr, localLinkAddr tcpip.LinkAddress + ) + // 获取以太网头部信息 + eth := header.Ethernet(e.views[0]) + p = eth.Type() + remoteLinkAddr = eth.SourceAddress() + localLinkAddr = eth.DestinationAddress() + + used := e.capViews(n, BufConfig) + vv := buffer.NewVectorisedView(n, e.views[:used]) + // 将数据内容删除以太网头部信息,也就是将数据指针指向网络层的第一个字节 + vv.TrimFront(e.hdrSize) + + // 调用nic.DeliverNetworkPacket来分发网络层数据 + // log.Printf("read from nic %d bytes", e.hdrSize+vv.Size()) + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv) + + // Prepare e.views for another packet: release used views. 
+ for i := 0; i < used; i++ { + e.views[i] = nil + } + + return true, nil +} + +// dispatchLoop reads packets from the file descriptor in a loop and dispatches +// them to the network stack. +// 循环的从fd读取数据,然后将数据包分发给协议栈 +func (e *endpoint) dispatchLoop() *tcpip.Error { + for { + cont, err := e.dispatch() + if err != nil || !cont { + if e.closed != nil { + e.closed(err) + } + return err + } + } +} diff --git a/netstack/tcpip/link/fdbased/endpoint_test.go b/netstack/tcpip/link/fdbased/endpoint_test.go new file mode 100644 index 0000000..fa12168 --- /dev/null +++ b/netstack/tcpip/link/fdbased/endpoint_test.go @@ -0,0 +1,341 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// +build linux + +package fdbased + +import ( + "bytes" + "fmt" + "math/rand" + "reflect" + "syscall" + "testing" + "time" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" +) + +const ( + mtu = 1500 + laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66") + raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc") + proto = 10 +) + +type packetInfo struct { + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber + contents buffer.View +} + +type context struct { + t *testing.T + fds [2]int + ep stack.LinkEndpoint + ch chan packetInfo + done chan struct{} +} + +func newContext(t *testing.T, opt *Options) *context { + fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0) + if err != nil { + t.Fatalf("Socketpair failed: %v", err) + } + + done := make(chan struct{}, 1) + opt.ClosedFunc = func(*tcpip.Error) { + done <- struct{}{} + } + + opt.FD = fds[1] + ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) + + c := &context{ + t: t, + fds: fds, + ep: ep, + ch: make(chan packetInfo, 100), + done: done, + } + + ep.Attach(c) + + return c +} + +func (c *context) cleanup() { + syscall.Close(c.fds[0]) + <-c.done + syscall.Close(c.fds[1]) +} + +func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()} +} + +func TestEthernetProperties(t *testing.T) { + c := newContext(t, &Options{MTU: mtu}) + defer c.cleanup() + + if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v { + t.Fatalf("MaxHeaderLength() = %v, want %v", v, want) + } + + if want, v := uint32(mtu), c.ep.MTU(); want != v { + t.Fatalf("MTU() = %v, want %v", v, want) + } +} + +func TestAddress(t *testing.T) { + addrs := []tcpip.LinkAddress{"", "abc", "def"} + for _, a := range addrs { + 
t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) { + c := newContext(t, &Options{Address: a, MTU: mtu}) + defer c.cleanup() + + if want, v := a, c.ep.LinkAddress(); want != v { + t.Fatalf("LinkAddress() = %v, want %v", v, want) + } + }) + } +} + +func TestWritePacket(t *testing.T) { + lengths := []int{0, 100, 1000} + for _, plen := range lengths { + t.Run(fmt.Sprintf("PayloadLen=%v", plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu}) + defer c.cleanup() + + r := &stack.Route{ + RemoteLinkAddress: raddr, + } + + // Build header. + hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) + b := hdr.Prepend(100) + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + // Build payload and write. + payload := make(buffer.View, plen) + for i := range payload { + payload[i] = uint8(rand.Intn(256)) + } + want := append(hdr.View(), payload...) + if err := c.ep.WritePacket(r, hdr, payload.ToVectorisedView(), proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + + // Read from fd, then compare with what we wrote. 
+ b = make([]byte, mtu) + n, err := syscall.Read(c.fds[0], b) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + b = b[:n] + h := header.Ethernet(b) + b = b[header.EthernetMinimumSize:] + + if a := h.SourceAddress(); a != laddr { + t.Fatalf("SourceAddress() = %v, want %v", a, laddr) + } + + if a := h.DestinationAddress(); a != raddr { + t.Fatalf("DestinationAddress() = %v, want %v", a, raddr) + } + + if et := h.Type(); et != proto { + t.Fatalf("Type() = %v, want %v", et, proto) + } + if len(b) != len(want) { + t.Fatalf("Read returned %v bytes, want %v", len(b), len(want)) + } + if !bytes.Equal(b, want) { + t.Fatalf("Read returned %x, want %x", b, want) + } + }) + } +} + +func TestPreserveSrcAddress(t *testing.T) { + baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99") + + c := newContext(t, &Options{Address: laddr, MTU: mtu}) + defer c.cleanup() + + // Set LocalLinkAddress in route to the value of the bridged address. + r := &stack.Route{ + RemoteLinkAddress: raddr, + LocalLinkAddress: baddr, + } + + // WritePacket panics given a prependable with anything less than + // the minimum size of the ethernet header. + hdr := buffer.NewPrependable(header.EthernetMinimumSize) + if err := c.ep.WritePacket(r, hdr, buffer.VectorisedView{}, proto); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } + + // Read from the FD, then compare with what we wrote. + b := make([]byte, mtu) + n, err := syscall.Read(c.fds[0], b) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + b = b[:n] + h := header.Ethernet(b) + + if a := h.SourceAddress(); a != baddr { + t.Fatalf("SourceAddress() = %v, want %v", a, baddr) + } +} + +func TestDeliverPacket(t *testing.T) { + lengths := []int{100, 1000} + for _, plen := range lengths { + t.Run(fmt.Sprintf("PayloadLen=%v", plen), func(t *testing.T) { + c := newContext(t, &Options{Address: laddr, MTU: mtu}) + defer c.cleanup() + + // Build packet. 
+ b := make([]byte, plen) + all := b + for i := range b { + b[i] = uint8(rand.Intn(256)) + } + + hdr := make(header.Ethernet, header.EthernetMinimumSize) + hdr.Encode(&header.EthernetFields{ + SrcAddr: raddr, + DstAddr: laddr, + Type: proto, + }) + all = append(hdr, b...) + + // Write packet via the file descriptor. + if _, err := syscall.Write(c.fds[0], all); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Receive packet through the endpoint. + select { + case pi := <-c.ch: + want := packetInfo{ + raddr: raddr, + proto: proto, + contents: b, + } + + if !reflect.DeepEqual(want, pi) { + t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want) + } + case <-time.After(10 * time.Second): + t.Fatalf("Timed out waiting for packet") + } + }) + } +} + +func TestBufConfigMaxLength(t *testing.T) { + got := 0 + for _, i := range BufConfig { + got += i + } + want := header.MaxIPPacketSize // maximum TCP packet size + if got < want { + t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want) + } +} + +func TestBufConfigFirst(t *testing.T) { + // The stack assumes that the TCP/IP header is entirely contained in the first view. + // Therefore, the first view needs to be large enough to contain the maximum TCP/IP + // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
+ want := 120 + got := BufConfig[0] + if got < want { + t.Errorf("first view has an invalid size: got %d, want >= %d", got, want) + } +} + +func build(bufConfig []int) *endpoint { + e := &endpoint{ + views: make([]buffer.View, len(bufConfig)), + iovecs: make([]syscall.Iovec, len(bufConfig)), + } + e.allocateViews(bufConfig) + return e +} + +var capLengthTestCases = []struct { + comment string + config []int + n int + wantUsed int + wantLengths []int +}{ + { + comment: "Single slice", + config: []int{2}, + n: 1, + wantUsed: 1, + wantLengths: []int{1}, + }, + { + comment: "Multiple slices", + config: []int{1, 2}, + n: 2, + wantUsed: 2, + wantLengths: []int{1, 1}, + }, + { + comment: "Entire buffer", + config: []int{1, 2}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2}, + }, + { + comment: "Entire buffer but not on the last slice", + config: []int{1, 2, 3}, + n: 3, + wantUsed: 2, + wantLengths: []int{1, 2, 3}, + }, +} + +func TestCapLength(t *testing.T) { + for _, c := range capLengthTestCases { + e := build(c.config) + used := e.capViews(c.n, c.config) + if used != c.wantUsed { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed) + } + lengths := make([]int, len(e.views)) + for i, v := range e.views { + lengths[i] = len(v) + } + if !reflect.DeepEqual(lengths, c.wantLengths) { + t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths) + } + + } +} diff --git a/netstack/tcpip/link/loopback/loopback.go b/netstack/tcpip/link/loopback/loopback.go new file mode 100644 index 0000000..9240d9f --- /dev/null +++ b/netstack/tcpip/link/loopback/loopback.go @@ -0,0 +1,87 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package loopback provides the implemention of loopback data-link layer +// endpoints. Such endpoints just turn outbound packets into inbound ones. +// +// Loopback endpoints can be used in the networking stack by calling New() to +// create a new endpoint, and then passing it as an argument to +// Stack.CreateNIC(). +package loopback + +import ( + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/stack" +) + +type endpoint struct { + dispatcher stack.NetworkDispatcher +} + +// New creates a new loopback endpoint. This link-layer endpoint just turns +// outbound packets into inbound packets. +func New() tcpip.LinkEndpointID { + return stack.RegisterLinkEndpoint(&endpoint{}) +} + +// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network- +// layer dispatcher for later use when packets need to be dispatched. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the +// linux loopback interface. +func (*endpoint) MTU() uint32 { + return 65536 +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises +// itself as supporting checksum offload, but in reality it's just omitted. 
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback +} + +// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the +// loopback interface doesn't have a header, it just returns 0. +func (*endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound +// packets to the network-layer dispatcher. +func (e *endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + + // Because we're immediately turning around and writing the packet back to the + // rx path, we intentionally don't preserve the remote and local link + // addresses from the stack.Route we're passed. + e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv) + + return nil +} diff --git a/netstack/tcpip/link/rawfile/blockingpoll_unsafe.go b/netstack/tcpip/link/rawfile/blockingpoll_unsafe.go new file mode 100644 index 0000000..a2dfc20 --- /dev/null +++ b/netstack/tcpip/link/rawfile/blockingpoll_unsafe.go @@ -0,0 +1,27 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "syscall" + "unsafe" +) + +func blockingPoll(fds *pollEvent, nfds int, timeout int64) (int, syscall.Errno) { + n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout)) + return int(n), e +} diff --git a/netstack/tcpip/link/rawfile/errors.go b/netstack/tcpip/link/rawfile/errors.go new file mode 100644 index 0000000..0a0b56d --- /dev/null +++ b/netstack/tcpip/link/rawfile/errors.go @@ -0,0 +1,70 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build linux + +package rawfile + +import ( + "fmt" + "syscall" + + "tcpip/netstack/tcpip" +) + +const maxErrno = 134 + +var translations [maxErrno]*tcpip.Error + +// TranslateErrno translate an errno from the syscall package into a +// *tcpip.Error. +// +// Valid, but unreconigized errnos will be translated to +// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos. 
+func TranslateErrno(e syscall.Errno) *tcpip.Error { + if err := translations[e]; err != nil { + return err + } + return tcpip.ErrInvalidEndpointState +} + +func addTranslation(host syscall.Errno, trans *tcpip.Error) { + if translations[host] != nil { + panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host)) + } + translations[host] = trans +} + +func init() { + addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress) + addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute) + addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState) + addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting) + addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected) + addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse) + addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress) + addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend) + addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock) + addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused) + addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout) + addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted) + addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired) + addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported) + addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported) + addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected) + addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset) + addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted) + addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong) + addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace) +} diff --git a/netstack/tcpip/link/rawfile/rawfile_unsafe.go b/netstack/tcpip/link/rawfile/rawfile_unsafe.go new file mode 100644 index 0000000..1dc974a --- /dev/null +++ b/netstack/tcpip/link/rawfile/rawfile_unsafe.go @@ -0,0 +1,145 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in 
// GetMTU determines the MTU of a network interface device.
//
// name must leave room for a terminating NUL byte in the 16-byte name field
// of the request; longer names are rejected with syscall.EINVAL rather than
// being silently truncated into the name of a different interface. Any error
// from creating the helper socket or from the ioctl is returned as is.
func GetMTU(name string) (uint32, error) {
	// 40-byte request: interface name, the MTU slot filled in by the
	// kernel, and padding.
	var ifreq struct {
		name [16]byte
		mtu  int32
		_    [20]byte
	}

	if len(name) >= len(ifreq.name) {
		return 0, syscall.EINVAL
	}

	// The socket is only a handle for issuing the ioctl.
	fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
	if err != nil {
		return 0, err
	}
	defer syscall.Close(fd)

	copy(ifreq.name[:], name)
	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
	if errno != 0 {
		return 0, errno
	}

	return uint32(ifreq.mtu), nil
}
Build the iovec that represents them and issue + // a writev syscall. + iovec := [...]syscall.Iovec{ + { + Base: &b1[0], + Len: uint64(len(b1)), + }, + { + Base: &b2[0], + Len: uint64(len(b2)), + }, + } + + _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec))) + if e != 0 { + return TranslateErrno(e) + } + + return nil +} + +type pollEvent struct { + fd int32 + events int16 + revents int16 +} + +// BlockingRead reads from a file descriptor that is set up as non-blocking. If +// no data is available, it will block in a poll() syscall until the file +// descirptor becomes readable. +func BlockingRead(fd int, b []byte) (int, *tcpip.Error) { + for { + n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b))) + if e == 0 { + return int(n), nil + } + + event := pollEvent{ + fd: int32(fd), + events: 1, // POLLIN + } + + _, e = blockingPoll(&event, 1, -1) + if e != 0 && e != syscall.EINTR { + return 0, TranslateErrno(e) + } + } +} + +// BlockingReadv reads from a file descriptor that is set up as non-blocking and +// stores the data in a list of iovecs buffers. If no data is available, it will +// block in a poll() syscall until the file descirptor becomes readable. 
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) { + for { + n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) + if e == 0 { + return int(n), nil + } + + event := pollEvent{ + fd: int32(fd), + events: 1, // POLLIN + } + + _, e = blockingPoll(&event, 1, -1) + if e != 0 && e != syscall.EINTR { + return 0, TranslateErrno(e) + } + } +} diff --git a/netstack/tcpip/link/sniffer/pcap.go b/netstack/tcpip/link/sniffer/pcap.go new file mode 100644 index 0000000..7a827a7 --- /dev/null +++ b/netstack/tcpip/link/sniffer/pcap.go @@ -0,0 +1,66 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sniffer + +import "time" + +type pcapHeader struct { + // MagicNumber is the file magic number. + MagicNumber uint32 + + // VersionMajor is the major version number. + VersionMajor uint16 + + // VersionMinor is the minor version number. + VersionMinor uint16 + + // Thiszone is the GMT to local correction. + Thiszone int32 + + // Sigfigs is the accuracy of timestamps. + Sigfigs uint32 + + // Snaplen is the max length of captured packets, in octets. + Snaplen uint32 + + // Network is the data link type. + Network uint32 +} + +const pcapPacketHeaderLen = 16 + +type pcapPacketHeader struct { + // Seconds is the timestamp seconds. + Seconds uint32 + + // Microseconds is the timestamp microseconds. 
+ Microseconds uint32 + + // IncludedLength is the number of octets of packet saved in file. + IncludedLength uint32 + + // OriginalLength is the actual length of packet. + OriginalLength uint32 +} + +func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader { + now := time.Now() + return pcapPacketHeader{ + Seconds: uint32(now.Unix()), + Microseconds: uint32(now.Nanosecond() / 1000), + IncludedLength: incLen, + OriginalLength: orgLen, + } +} diff --git a/netstack/tcpip/link/sniffer/sniffer.go b/netstack/tcpip/link/sniffer/sniffer.go new file mode 100644 index 0000000..9dba4fa --- /dev/null +++ b/netstack/tcpip/link/sniffer/sniffer.go @@ -0,0 +1,390 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sniffer provides the implementation of data-link layer endpoints that +// wrap another endpoint and logs inbound and outbound packets. +// +// Sniffer endpoints can be used in the networking stack by calling New(eID) to +// create a new endpoint, where eID is the ID of the endpoint being wrapped, +// and then passing it as an argument to Stack.CreateNIC(). +package sniffer + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "os" + "sync/atomic" + "time" + + "log" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" +) + +// LogPackets is a flag used to enable or disable packet logging via the log +// package. Valid values are 0 or 1. 
// zoneOffset returns the offset of the local time zone from UTC, in seconds
// east of UTC, as it applies right now.
//
// It returns an error only if the local time zone cannot be loaded.
func zoneOffset() (int32, error) {
	loc, err := time.LoadLocation("Local")
	if err != nil {
		return 0, err
	}
	// Evaluate the zone at the present instant. A fixed date in the distant
	// past resolves to the zone's historical rule (typically local mean
	// time), not the offset currently in force.
	_, offset := time.Now().In(loc).Zone()
	return int32(offset), nil
}
+func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) { + if err := writePCAPHeader(file, snapLen); err != nil { + return 0, err + } + return stack.RegisterLinkEndpoint(&endpoint{ + lower: stack.FindLinkEndpoint(lower), + file: file, + maxPCAPLen: snapLen, + }), nil +} + +// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is +// called by the link-layer endpoint being wrapped when a packet arrives, and +// logs the packet before forwarding to the actual dispatcher. +func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("recv", protocol, vv.First()) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + vs := vv.Views() + length := vv.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil { + panic(err) + } + for _, v := range vs { + if length == 0 { + break + } + if len(v) > length { + v = v[:length] + } + if _, err := buf.Write([]byte(v)); err != nil { + panic(err) + } + length -= len(v) + } + if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) + } + } + e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, protocol, vv) +} + +// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher +// and registers with the lower endpoint as its dispatcher so that "e" is called +// for inbound packets. +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher + e.lower.Attach(e) +} + +// IsAttached implements stack.LinkEndpoint.IsAttached. 
+func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the +// lower endpoint. +func (e *endpoint) MTU() uint32 { + return e.lower.MTU() +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the +// request to the lower endpoint. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.lower.Capabilities() +} + +// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards +// the request to the lower endpoint. +func (e *endpoint) MaxHeaderLength() uint16 { + return e.lower.MaxHeaderLength() +} + +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return e.lower.LinkAddress() +} + +// WritePacket implements the stack.LinkEndpoint interface. It is called by +// higher-level protocols to write packets; it just logs the packet and forwards +// the request to the lower endpoint. +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil { + logPacket("send", protocol, hdr.View()) + } + if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 { + hdrBuf := hdr.View() + length := len(hdrBuf) + payload.Size() + if length > int(e.maxPCAPLen) { + length = int(e.maxPCAPLen) + } + + buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length)) + if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil { + panic(err) + } + if len(hdrBuf) > length { + hdrBuf = hdrBuf[:length] + } + if _, err := buf.Write(hdrBuf); err != nil { + panic(err) + } + length -= len(hdrBuf) + if length > 0 { + for _, v := range payload.Views() { + if len(v) > length { + v = v[:length] + } + n, err := buf.Write(v) + if err != nil { + panic(err) + } + length -= n + if length == 0 { + break + } + } + } + 
if _, err := e.file.Write(buf.Bytes()); err != nil { + panic(err) + } + } + return e.lower.WritePacket(r, hdr, payload, protocol) +} + +func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) { + // Figure out the network layer info. + var transProto uint8 + src := tcpip.Address("unknown") + dst := tcpip.Address("unknown") + id := 0 + size := uint16(0) + switch protocol { + case header.IPv4ProtocolNumber: + ipv4 := header.IPv4(b) + src = ipv4.SourceAddress() + dst = ipv4.DestinationAddress() + transProto = ipv4.Protocol() + size = ipv4.TotalLength() - uint16(ipv4.HeaderLength()) + b = b[ipv4.HeaderLength():] + id = int(ipv4.ID()) + + case header.IPv6ProtocolNumber: + ipv6 := header.IPv6(b) + src = ipv6.SourceAddress() + dst = ipv6.DestinationAddress() + transProto = ipv6.NextHeader() + size = ipv6.PayloadLength() + b = b[header.IPv6MinimumSize:] + + case header.ARPProtocolNumber: + arp := header.ARP(b) + log.Printf( + "%s arp %v (%v) -> %v (%v) valid:%v", + prefix, + tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()), + tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()), + arp.IsValid(), + ) + return + default: + log.Printf("%s unknown network protocol", prefix) + return + } + + // Figure out the transport layer info. 
+ transName := "unknown" + srcPort := uint16(0) + dstPort := uint16(0) + details := "" + switch tcpip.TransportProtocolNumber(transProto) { + case header.ICMPv4ProtocolNumber: + transName = "icmp" + icmp := header.ICMPv4(b) + icmpType := "unknown" + switch icmp.Type() { + case header.ICMPv4EchoReply: + icmpType = "echo reply" + case header.ICMPv4DstUnreachable: + icmpType = "destination unreachable" + case header.ICMPv4SrcQuench: + icmpType = "source quench" + case header.ICMPv4Redirect: + icmpType = "redirect" + case header.ICMPv4Echo: + icmpType = "echo" + case header.ICMPv4TimeExceeded: + icmpType = "time exceeded" + case header.ICMPv4ParamProblem: + icmpType = "param problem" + case header.ICMPv4Timestamp: + icmpType = "timestamp" + case header.ICMPv4TimestampReply: + icmpType = "timestamp reply" + case header.ICMPv4InfoRequest: + icmpType = "info request" + case header.ICMPv4InfoReply: + icmpType = "info reply" + } + log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code()) + return + + case header.ICMPv6ProtocolNumber: + transName = "icmp" + icmp := header.ICMPv6(b) + icmpType := "unknown" + switch icmp.Type() { + case header.ICMPv6DstUnreachable: + icmpType = "destination unreachable" + case header.ICMPv6PacketTooBig: + icmpType = "packet too big" + case header.ICMPv6TimeExceeded: + icmpType = "time exceeded" + case header.ICMPv6ParamProblem: + icmpType = "param problem" + case header.ICMPv6EchoRequest: + icmpType = "echo request" + case header.ICMPv6EchoReply: + icmpType = "echo reply" + case header.ICMPv6RouterSolicit: + icmpType = "router solicit" + case header.ICMPv6RouterAdvert: + icmpType = "router advert" + case header.ICMPv6NeighborSolicit: + icmpType = "neighbor solicit" + case header.ICMPv6NeighborAdvert: + icmpType = "neighbor advert" + case header.ICMPv6RedirectMsg: + icmpType = "redirect message" + } + log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, 
icmpType, size, id, icmp.Code()) + return + + case header.UDPProtocolNumber: + transName = "udp" + udp := header.UDP(b) + srcPort = udp.SourcePort() + dstPort = udp.DestinationPort() + size -= header.UDPMinimumSize + + details = fmt.Sprintf("xsum: 0x%x", udp.Checksum()) + + case header.TCPProtocolNumber: + transName = "tcp" + tcp := header.TCP(b) + offset := int(tcp.DataOffset()) + if offset < header.TCPMinimumSize { + details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset) + break + } + if offset > len(tcp) { + details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp)) + break + } + + srcPort = tcp.SourcePort() + dstPort = tcp.DestinationPort() + size -= uint16(offset) + + // Initialize the TCP flags. + flags := tcp.Flags() + flagsStr := []byte("FSRPAU") + for i := range flagsStr { + if flags&(1< %v unknown transport protocol: %d", prefix, src, dst, transProto) + return + } + + log.Printf("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details) +} diff --git a/netstack/tcpip/link/tuntap/tuntap.go b/netstack/tcpip/link/tuntap/tuntap.go new file mode 100644 index 0000000..7d8c39d --- /dev/null +++ b/netstack/tcpip/link/tuntap/tuntap.go @@ -0,0 +1,141 @@ +// +build linux + +package tuntap + +import ( + "errors" + "fmt" + "os/exec" + "syscall" + "unsafe" +) + +const ( + TUN = 1 + TAP = 2 +) + +var ( + ErrDeviceMode = errors.New("unsupport device mode") +) + +type rawSockaddr struct { + Family uint16 + Data [14]byte +} + +// 虚拟网卡设置的配置 +type Config struct { + Name string // 网卡名 + Mode int // 网卡模式,TUN or TAP +} + +// NewNetDev根据配置返回虚拟网卡的文件描述符 +func NewNetDev(c *Config) (fd int, err error) { + switch c.Mode { + case TUN: + fd, err = newTun(c.Name) + case TAP: + fd, err = newTAP(c.Name) + default: + err = ErrDeviceMode + return + } + if err != nil { + return + } + return +} + +// SetLinkUp 让系统启动该网卡 +func SetLinkUp(name string) (err error) { + // 
ip link set up + out, cmdErr := exec.Command("ip", "link", "set", name, "up").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + return + } + return +} + +// SetRoute 通过ip命令添加路由 +func SetRoute(name, cidr string) (err error) { + // ip route add 192.168.1.0/24 dev tap0 + out, cmdErr := exec.Command("ip", "route", "add", cidr, "dev", name).CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + return + } + return +} + +// AddIP 通过ip命令添加IP地址 +func AddIP(name, ip string) (err error) { + // ip addr add 192.168.1.1 dev tap0 + out, cmdErr := exec.Command("ip", "addr", "add", ip, "dev", name).CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + return + } + return +} + +func GetHardwareAddr(name string) (string, error) { + fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0) + if err != nil { + return "", err + } + + defer syscall.Close(fd) + + var ifreq struct { + name [16]byte + addr rawSockaddr + _ [8]byte + } + + copy(ifreq.name[:], name) + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFHWADDR, uintptr(unsafe.Pointer(&ifreq))) + if errno != 0 { + return "", errno + } + + mac := ifreq.addr.Data[:6] + return string(mac[:]), nil +} + +// newTun新建一个tun模式的虚拟网卡,然后返回该网卡的文件描述符 +// IFF_NO_PI表示不需要包信息 +func newTun(name string) (int, error) { + return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI) +} + +// newTAP新建一个tap模式的虚拟网卡,然后返回该网卡的文件描述符 +func newTAP(name string) (int, error) { + return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) +} + +// 先打开一个字符串设备,通过系统调用将虚拟网卡和字符串设备fd绑定在一起 +func open(name string, flags uint16) (int, error) { + // 打开tuntap的字符设备,得到字符设备的文件描述符 + fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0) + if err != nil { + return -1, err + } + + var ifr struct { + name [16]byte + flags uint16 + _ [22]byte + } + + copy(ifr.name[:], name) + ifr.flags = flags + // 通过ioctl系统调用,将fd和虚拟网卡驱动绑定在一起 + _, _, errno := 
syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) + if errno != 0 { + syscall.Close(fd) + return -1, errno + } + return fd, nil +} diff --git a/netstack/tcpip/network/arp/arp.go b/netstack/tcpip/network/arp/arp.go new file mode 100644 index 0000000..5e92bc4 --- /dev/null +++ b/netstack/tcpip/network/arp/arp.go @@ -0,0 +1,196 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package arp implements the ARP network protocol. It is used to resolve +// IPv4 addresses into link-local MAC addresses, and advertises IPv4 +// addresses of its stack with the local network. +// +// To use it in the networking stack, pass arp.ProtocolName as one of the +// network protocols when calling stack.New. Then add an "arp" address to +// every NIC on the stack that should respond to ARP requests. That is: +// +// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil { +// // handle err +// } +package arp + +import ( + "log" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ARP protocol name. + ProtocolName = "arp" + + // ProtocolNumber is the ARP protocol number. + ProtocolNumber = header.ARPProtocolNumber + + // ProtocolAddress is the address expected by the ARP endpoint. 
+	ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+// It represents the ARP endpoint of a NIC.
+type endpoint struct {
+	nicid         tcpip.NICID
+	addr          tcpip.Address
+	linkEP        stack.LinkEndpoint
+	linkAddrCache stack.LinkAddressCache
+}
+
+// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
+func (e *endpoint) DefaultTTL() uint8 {
+	return 0
+}
+
+// MTU returns the MTU of the underlying link endpoint minus this endpoint's
+// maximum header length.
+func (e *endpoint) MTU() uint32 {
+	lmtu := e.linkEP.MTU()
+	return lmtu - uint32(e.MaxHeaderLength())
+}
+
+// NICID returns the id of the NIC this endpoint was created for.
+func (e *endpoint) NICID() tcpip.NICID {
+	return e.nicid
+}
+
+// Capabilities forwards the capabilities of the underlying link endpoint.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return e.linkEP.Capabilities()
+}
+
+// ID returns a network endpoint id built from the fixed ARP ProtocolAddress.
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+	return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+// MaxHeaderLength returns the link endpoint's maximum header length plus the
+// size of an ARP packet.
+func (e *endpoint) MaxHeaderLength() uint16 {
+	return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+// Close is a no-op for ARP endpoints.
+func (e *endpoint) Close() {}
+
+// WritePacket always fails with tcpip.ErrNotSupported: nothing is sent
+// through this method.
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
+	return tcpip.ErrNotSupported
+}
+
+// HandlePacket processes an incoming ARP packet, handling both requests and
+// replies. Invalid packets are dropped. A request whose target protocol
+// address is not a local IPv4 address is ignored; otherwise a reply is sent
+// and, as for replies, the sender's address pair is added to the link
+// address cache.
+func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
+	v := vv.First()
+	h := header.ARP(v)
+	if !h.IsValid() {
+		return
+	}
+
+	// Dispatch on the operation code.
+	switch h.Op() {
+	case header.ARPRequest:
+		log.Printf("recv arp request")
+		localAddr := tcpip.Address(h.ProtocolAddressTarget())
+		if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+			return // we have no useful answer, ignore the request
+		}
+		hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+		pkt := header.ARP(hdr.Prepend(header.ARPSize))
+		pkt.SetIPv4OverEthernet()
+		pkt.SetOp(header.ARPReply)
+		copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
+		copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+		// RFC 826: the reply carries the requester's hardware and protocol
+		// addresses in its target fields.
+		copy(pkt.HardwareAddressTarget(), h.HardwareAddressSender())
+		copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
+		log.Printf("send arp reply")
+		// HandlePacket has no way to return an error, so a failed send is
+		// logged instead of being silently dropped.
+		if err := e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber); err != nil {
+			log.Printf("send arp reply failed: %v", err)
+		}
+		// The fallthrough continues into the reply branch below: the sender
+		// of a request is also recorded in the link address cache.
+		fallthrough // also fill the cache from requests
+	case header.ARPReply:
+		// Record the IP-to-MAC mapping, i.e. the ARP table entry.
+		addr := tcpip.Address(h.ProtocolAddressSender())
+		linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+		e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+	}
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+// Number returns the ARP protocol number.
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+
+// MinimumPacketSize returns the size of an ARP packet.
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+
+// ParseAddresses returns the sender protocol address of the ARP packet in v
+// as src and the fixed ProtocolAddress as dst.
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+	h := header.ARP(v)
+	return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+// NewEndpoint creates an ARP endpoint for the given NIC. It fails with
+// tcpip.ErrBadLocalAddress unless addr is ProtocolAddress. The dispatcher
+// argument is not used.
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+	if addr != ProtocolAddress {
+		return nil, tcpip.ErrBadLocalAddress
+	}
+	return &endpoint{
+		nicid:         nicid,
+		addr:          addr,
+		linkEP:        sender,
+		linkAddrCache: linkAddrCache,
+	}, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+	return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+	// The request is sent to the link-layer broadcast address.
+	route := &stack.Route{RemoteLinkAddress: broadcastMAC}
+
+	// The buffer is sized for the link endpoint's maximum header plus one
+	// ARP packet.
+	prep := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+	req := header.ARP(prep.Prepend(header.ARPSize))
+	req.SetIPv4OverEthernet()
+	req.SetOp(header.ARPRequest)
+	copy(req.HardwareAddressSender(), linkEP.LinkAddress())
+	copy(req.ProtocolAddressSender(), localAddr)
+	copy(req.ProtocolAddressTarget(), addr)
+
+	return linkEP.WritePacket(route, prep, buffer.VectorisedView{}, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver. Only the IPv4
+// broadcast address resolves statically, to the broadcast MAC.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+	if addr != "\xff\xff\xff\xff" {
+		return "", false
+	}
+	return broadcastMAC, true
+}
+
+// SetOption implements NetworkProtocol. ARP has no options, so every call
+// fails with tcpip.ErrUnknownProtocolOption.
+func (*protocol) SetOption(option interface{}) *tcpip.Error {
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol. ARP has no options, so every call
+// fails with tcpip.ErrUnknownProtocolOption.
+func (*protocol) Option(option interface{}) *tcpip.Error {
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// broadcastMAC is the all-ones link-layer address.
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+// init registers the ARP protocol factory under ProtocolName.
+func init() {
+	newProtocol := func() stack.NetworkProtocol { return &protocol{} }
+	stack.RegisterNetworkProtocolFactory(ProtocolName, newProtocol)
+}
diff --git a/netstack/tcpip/network/arp/arp_test.go b/netstack/tcpip/network/arp/arp_test.go
new file mode 100644
index 0000000..7871db8
--- /dev/null
+++ b/netstack/tcpip/network/arp/arp_test.go
@@ -0,0 +1,149 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package arp_test + +import ( + "testing" + "time" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/link/channel" + "tcpip/netstack/tcpip/link/sniffer" + "tcpip/netstack/tcpip/network/arp" + "tcpip/netstack/tcpip/network/ipv4" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/tcpip/transport/ping" +) + +const ( + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ping.ProtocolName4}, stack.Options{}) + + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("AddAddress for arp failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", 
+ Gateway: "", + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + close(c.linkEP.C) +} + +func TestDirectRequest(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), senderMAC) + copy(h.ProtocolAddressSender(), senderIPv4) + + inject := func(addr tcpip.Address) { + copy(h.ProtocolAddressTarget(), addr) + c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) + } + + inject(stackAddr1) + { + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { + t.Errorf("stackAddr1: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) + } + } + + inject(stackAddr2) + { + pkt := <-c.linkEP.C + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { + t.Errorf("stackAddr2: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) + } + } + + inject(stackAddrBad) + select { + case pkt := <-c.linkEP.C: + t.Errorf("stackAddrBad: 
unexpected packet sent, Proto=%v", pkt.Proto) + case <-time.After(100 * time.Millisecond): + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + } +} diff --git a/netstack/tcpip/network/fragmentation/frag_heap.go b/netstack/tcpip/network/fragmentation/frag_heap.go new file mode 100644 index 0000000..43ce2d4 --- /dev/null +++ b/netstack/tcpip/network/fragmentation/frag_heap.go @@ -0,0 +1,77 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "container/heap" + "fmt" + + "tcpip/netstack/tcpip/buffer" +) + +type fragment struct { + offset uint16 + vv buffer.VectorisedView +} + +type fragHeap []fragment + +func (h *fragHeap) Len() int { + return len(*h) +} + +func (h *fragHeap) Less(i, j int) bool { + return (*h)[i].offset < (*h)[j].offset +} + +func (h *fragHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +func (h *fragHeap) Push(x interface{}) { + *h = append(*h, x.(fragment)) +} + +func (h *fragHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} + +// reassamble empties the heap and returns a VectorisedView +// containing a reassambled version of the fragments inside the heap. 
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+	// heap.Pop on an empty heap panics with an index error; report the
+	// condition as an ordinary error instead.
+	if h.Len() == 0 {
+		return buffer.VectorisedView{}, fmt.Errorf("no fragments to reassemble")
+	}
+
+	// The heap is ordered by offset, so the first fragment popped must start
+	// at offset 0.
+	first := heap.Pop(h).(fragment)
+	if first.offset != 0 {
+		return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", first.offset)
+	}
+	views := first.vv.Views()
+	size := first.vv.Size()
+
+	for h.Len() > 0 {
+		next := heap.Pop(h).(fragment)
+		if int(next.offset) < size {
+			// Overlap with data already collected: drop the duplicated
+			// prefix of this fragment.
+			next.vv.TrimFront(size - int(next.offset))
+		} else if int(next.offset) > size {
+			return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, next.offset)
+		}
+		size += next.vv.Size()
+		views = append(views, next.vv.Views()...)
+	}
+	return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/netstack/tcpip/network/fragmentation/frag_heap_test.go b/netstack/tcpip/network/fragmentation/frag_heap_test.go
new file mode 100644
index 0000000..1d0ce51
--- /dev/null
+++ b/netstack/tcpip/network/fragmentation/frag_heap_test.go
@@ -0,0 +1,126 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package fragmentation + +import ( + "container/heap" + "reflect" + "testing" + + "tcpip/netstack/tcpip/buffer" +) + +var reassambleTestCases = []struct { + comment string + in []fragment + want buffer.VectorisedView +}{ + { + comment: "Non-overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Non-overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(2, "0", "1"), + }, + { + comment: "Duplicated packets", + in: []fragment{ + {offset: 0, vv: vv(1, "0")}, + {offset: 0, vv: vv(1, "0")}, + }, + want: vv(1, "0"), + }, + { + comment: "Overlapping in-order", + in: []fragment{ + {offset: 0, vv: vv(2, "01")}, + {offset: 1, vv: vv(2, "12")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping out-of-order", + in: []fragment{ + {offset: 1, vv: vv(2, "12")}, + {offset: 0, vv: vv(2, "01")}, + }, + want: vv(3, "01", "2"), + }, + { + comment: "Overlapping subset in-order", + in: []fragment{ + {offset: 0, vv: vv(3, "012")}, + {offset: 1, vv: vv(1, "1")}, + }, + want: vv(3, "012"), + }, + { + comment: "Overlapping subset out-of-order", + in: []fragment{ + {offset: 1, vv: vv(1, "1")}, + {offset: 0, vv: vv(3, "012")}, + }, + want: vv(3, "012"), + }, +} + +func TestReassamble(t *testing.T) { + for _, c := range reassambleTestCases { + t.Run(c.comment, func(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + for _, f := range c.in { + heap.Push(&h, f) + } + got, err := h.reassemble() + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(got, c.want) { + t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want) + } + }) + } +} + +func TestReassambleFailsForNonZeroOffset(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when the first 
packet had offset != 0") + } +} + +func TestReassambleFailsForHoles(t *testing.T) { + h := make(fragHeap, 0, 8) + heap.Init(&h) + heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")}) + heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")}) + _, err := h.reassemble() + if err == nil { + t.Errorf("reassemble() did not fail when there was a hole in the packet") + } +} diff --git a/netstack/tcpip/network/fragmentation/fragmentation.go b/netstack/tcpip/network/fragmentation/fragmentation.go new file mode 100644 index 0000000..ea17564 --- /dev/null +++ b/netstack/tcpip/network/fragmentation/fragmentation.go @@ -0,0 +1,134 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package fragmentation contains the implementation of IP fragmentation. +// It is based on RFC 791 and RFC 815. +package fragmentation + +import ( + "log" + "sync" + "time" + + "tcpip/netstack/tcpip/buffer" +) + +// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time. +const DefaultReassembleTimeout = 30 * time.Second + +// HighFragThreshold is the threshold at which we start trimming old +// fragmented packets. Linux uses a default value of 4 MB. See +// net.ipv4.ipfrag_high_thresh for more information. +const HighFragThreshold = 4 << 20 // 4MB + +// LowFragThreshold is the threshold we reach to when we start dropping +// older fragmented packets. It's important that we keep enough room for newer +// packets to be re-assembled. 
Hence, this needs to be lower than
+// HighFragThreshold enough. Linux uses a default value of 3 MB. See
+// net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+	mu           sync.Mutex
+	highLimit    int
+	lowLimit     int
+	reassemblers map[uint32]*reassembler
+	rList        reassemblerList
+	size         int
+	timeout      time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit which we will reach by dropping
+// fragments after reaching highMemoryLimit. It is clamped to the range
+// [0, highMemoryLimit].
+//
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+	if lowMemoryLimit >= highMemoryLimit {
+		lowMemoryLimit = highMemoryLimit
+	}
+
+	if lowMemoryLimit < 0 {
+		lowMemoryLimit = 0
+	}
+
+	return &Fragmentation{
+		reassemblers: make(map[uint32]*reassembler),
+		highLimit:    highMemoryLimit,
+		lowLimit:     lowMemoryLimit,
+		timeout:      reassemblingTimeout,
+	}
+}
+
+// Process processes an incoming fragment belonging to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+// The boolean result reports whether the returned view is a complete packet.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
+	f.mu.Lock()
+	r, ok := f.reassemblers[id]
+	if ok && r.tooOld(f.timeout) {
+		// This is very likely to be an id-collision or someone performing a slow-rate attack.
+		f.release(r)
+		ok = false
+	}
+	if !ok {
+		r = newReassembler(id)
+		f.reassemblers[id] = r
+		f.rList.PushFront(r)
+	}
+	f.mu.Unlock()
+
+	// The fragment is handed to the reassembler without holding f.mu.
+	res, done, consumed := r.process(first, last, more, vv)
+
+	f.mu.Lock()
+	f.size += consumed
+	if done {
+		f.release(r)
+	}
+	// Evict reassemblers if we are consuming more memory than highLimit until
+	// we reach lowLimit.
+	if f.size > f.highLimit {
+		tail := f.rList.Back()
+		for f.size > f.lowLimit && tail != nil {
+			// Capture the predecessor first: release unlinks tail from
+			// rList, so its links should not be relied on afterwards.
+			prev := tail.Prev()
+			f.release(tail)
+			tail = prev
+		}
+	}
+	f.mu.Unlock()
+	return res, done
+}
+
+// release removes r from the map and the list and subtracts its size from
+// the memory counter. It does nothing if r was already marked as done.
+func (f *Fragmentation) release(r *reassembler) {
+	// Before releasing a fragment we need to check if r is already marked as done.
+	// Otherwise, we would delete it twice.
+	if r.checkDoneOrMark() {
+		return
+	}
+
+	delete(f.reassemblers, r.id)
+	f.rList.Remove(r)
+	f.size -= r.size
+	if f.size < 0 {
+		log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+		f.size = 0
+	}
+}
diff --git a/netstack/tcpip/network/fragmentation/fragmentation_test.go b/netstack/tcpip/network/fragmentation/fragmentation_test.go
new file mode 100644
index 0000000..7dcdeb0
--- /dev/null
+++ b/netstack/tcpip/network/fragmentation/fragmentation_test.go
@@ -0,0 +1,159 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package fragmentation + +import ( + "reflect" + "testing" + "time" + + "tcpip/netstack/tcpip/buffer" +) + +// vv is a helper to build VectorisedView from different strings. +func vv(size int, pieces ...string) buffer.VectorisedView { + views := make([]buffer.View, len(pieces)) + for i, p := range pieces { + views[i] = []byte(p) + } + + return buffer.NewVectorisedView(size, views) +} + +type processInput struct { + id uint32 + first uint16 + last uint16 + more bool + vv buffer.VectorisedView +} + +type processOutput struct { + vv buffer.VectorisedView + done bool +} + +var processTestCases = []struct { + comment string + in []processInput + out []processOutput +}{ + { + comment: "One ID", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, + { + comment: "Two IDs", + in: []processInput{ + {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")}, + {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")}, + {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")}, + {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")}, + }, + out: []processOutput{ + {vv: buffer.VectorisedView{}, done: false}, + {vv: buffer.VectorisedView{}, done: false}, + {vv: vv(4, "ab", "cd"), done: true}, + {vv: vv(4, "01", "23"), done: true}, + }, + }, +} + +func TestFragmentationProcess(t *testing.T) { + for _, c := range processTestCases { + t.Run(c.comment, func(t *testing.T) { + f := NewFragmentation(1024, 512, DefaultReassembleTimeout) + for i, in := range c.in { + vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv) + if !reflect.DeepEqual(vv, c.out[i].vv) { + t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv) + } + if done != c.out[i].done { + t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done) + } + if c.out[i].done { + if _, ok 
:= f.reassemblers[in.id]; ok { + t.Errorf("Process(%d) did not remove buffer from reassemblers", i) + } + for n := f.rList.Front(); n != nil; n = n.Next() { + if n.id == in.id { + t.Errorf("Process(%d) did not remove buffer from rList", i) + } + } + } + } + }) + } +} + +func TestReassemblingTimeout(t *testing.T) { + timeout := time.Millisecond + f := NewFragmentation(1024, 512, timeout) + // Send first fragment with id = 0, first = 0, last = 0, and more = true. + f.Process(0, 0, 0, true, vv(1, "0")) + // Sleep more than the timeout. + time.Sleep(2 * timeout) + // Send another fragment that completes a packet. + // However, no packet should be reassembled because the fragment arrived after the timeout. + _, done := f.Process(0, 1, 1, false, vv(1, "1")) + if done { + t.Errorf("Fragmentation does not respect the reassembling timeout.") + } +} + +func TestMemoryLimits(t *testing.T) { + f := NewFragmentation(3, 1, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send first fragment with id = 1. + f.Process(1, 0, 0, true, vv(1, "1")) + // Send first fragment with id = 2. + f.Process(2, 0, 0, true, vv(1, "2")) + + // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be + // evicted. + f.Process(3, 0, 0, true, vv(1, "3")) + + if _, ok := f.reassemblers[0]; ok { + t.Errorf("Memory limits are not respected: id=0 has not been evicted.") + } + if _, ok := f.reassemblers[1]; ok { + t.Errorf("Memory limits are not respected: id=1 has not been evicted.") + } + if _, ok := f.reassemblers[3]; !ok { + t.Errorf("Implementation of memory limits is wrong: id=3 is not present.") + } +} + +func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { + f := NewFragmentation(1, 0, DefaultReassembleTimeout) + // Send first fragment with id = 0. + f.Process(0, 0, 0, true, vv(1, "0")) + // Send the same packet again. 
+ f.Process(0, 0, 0, true, vv(1, "0")) + + got := f.size + want := 1 + if got != want { + t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want) + } +} diff --git a/netstack/tcpip/network/fragmentation/reassembler.go b/netstack/tcpip/network/fragmentation/reassembler.go new file mode 100644 index 0000000..2457d16 --- /dev/null +++ b/netstack/tcpip/network/fragmentation/reassembler.go @@ -0,0 +1,118 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "container/heap" + "fmt" + "math" + "sync" + "time" + + "tcpip/netstack/tcpip/buffer" +) + +type hole struct { + first uint16 + last uint16 + deleted bool +} + +type reassembler struct { + reassemblerEntry + id uint32 + size int + mu sync.Mutex + holes []hole + deleted int + heap fragHeap + done bool + creationTime time.Time +} + +func newReassembler(id uint32) *reassembler { + r := &reassembler{ + id: id, + holes: make([]hole, 0, 16), + deleted: 0, + heap: make(fragHeap, 0, 8), + creationTime: time.Now(), + } + r.holes = append(r.holes, hole{ + first: 0, + last: math.MaxUint16, + deleted: false}) + return r +} + +// updateHoles updates the list of holes for an incoming fragment and +// returns true iff the fragment filled at least part of an existing hole. 
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+	used := false
+	// range evaluates len(r.holes) once, so holes appended below are not
+	// visited during this call.
+	for i := range r.holes {
+		if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+			continue
+		}
+		used = true
+		r.deleted++
+		r.holes[i].deleted = true
+		// Keep whatever part of the hole lies before the fragment.
+		if first > r.holes[i].first {
+			r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+		}
+		// Keep the part after the fragment, unless this is the last
+		// fragment of the packet.
+		if last < r.holes[i].last && more {
+			r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+		}
+	}
+	return used
+}
+
+// process records the fragment covering [first, last] and, once every hole
+// has been deleted, reassembles the packet.
+//
+// It returns the reassembled packet (empty until done), whether reassembly
+// completed, and the number of bytes newly stored for this fragment (0 when
+// the fragment filled no hole or the reassembler is already done).
+func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	consumed := 0
+	if r.done {
+		// A concurrent goroutine might have already reassembled
+		// the packet and emptied the heap while this goroutine
+		// was waiting on the mutex. We don't have to do anything in this case.
+		return buffer.VectorisedView{}, false, consumed
+	}
+	if r.updateHoles(first, last, more) {
+		// We store the incoming packet only if it filled some holes.
+		heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
+		consumed = vv.Size()
+		r.size += consumed
+	}
+	// Check if all the holes have been deleted and we are ready to reassemble.
+	if r.deleted < len(r.holes) {
+		return buffer.VectorisedView{}, false, consumed
+	}
+	res, err := r.heap.reassemble()
+	if err != nil {
+		panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+	}
+	return res, true, consumed
+}
+
+// tooOld reports whether more than timeout has elapsed since r was created.
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+	return time.Since(r.creationTime) > timeout
+}
+
+// checkDoneOrMark marks r as done and returns whether it was already done
+// before the call.
+func (r *reassembler) checkDoneOrMark() bool {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	prev := r.done
+	r.done = true
+	return prev
+}
diff --git a/netstack/tcpip/network/fragmentation/reassembler_list.go b/netstack/tcpip/network/fragmentation/reassembler_list.go
new file mode 100644
index 0000000..a659ed1
--- /dev/null
+++ b/netstack/tcpip/network/fragmentation/reassembler_list.go
@@ -0,0 +1,173 @@
+package fragmentation
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type reassemblerElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+//	for e := l.Front(); e != nil; e = e.Next() {
+//		// do something with e.
+//	}
+//
+// +stateify savable
+type reassemblerList struct {
+	head *reassembler
+	tail *reassembler
+}
+
+// Reset resets list l to the empty state.
+func (l *reassemblerList) Reset() {
+	l.head = nil
+	l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *reassemblerList) Empty() bool {
+	return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *reassemblerList) Front() *reassembler { return l.head }
+
+// Back returns the last element of list l or nil.
+func (l *reassemblerList) Back() *reassembler { return l.tail }
+
+// PushFront inserts the element e at the front of list l.
+func (l *reassemblerList) PushFront(e *reassembler) {
+	// reassemblerElementMapper.linkerFor is the identity mapping, so the
+	// linker methods are called on the elements directly.
+	e.SetNext(l.head)
+	e.SetPrev(nil)
+
+	if l.head == nil {
+		// First element: it is also the tail.
+		l.tail = e
+	} else {
+		l.head.SetPrev(e)
+	}
+
+	l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *reassemblerList) PushBack(e *reassembler) {
+	e.SetNext(nil)
+	e.SetPrev(l.tail)
+
+	if l.tail == nil {
+		// First element: it is also the head.
+		l.head = e
+	} else {
+		l.tail.SetNext(e)
+	}
+
+	l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *reassemblerList) PushBackList(m *reassemblerList) {
+	switch {
+	case l.head == nil:
+		// l is empty: take over m's ends as they are.
+		l.head = m.head
+		l.tail = m.tail
+	case m.head != nil:
+		// Both lists have elements: splice m after l's tail.
+		l.tail.SetNext(m.head)
+		m.head.SetPrev(l.tail)
+
+		l.tail = m.tail
+	}
+
+	m.head = nil
+	m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *reassemblerList) InsertAfter(b, e *reassembler) {
+	after := b.Next()
+	e.SetNext(after)
+	e.SetPrev(b)
+	b.SetNext(e)
+
+	if after == nil {
+		// b was the last element, so e becomes the new tail.
+		l.tail = e
+		return
+	}
+	after.SetPrev(e)
+}
+
+// InsertBefore inserts e before a.
+func (l *reassemblerList) InsertBefore(a, e *reassembler) {
+	// reassemblerElementMapper.linkerFor is the identity mapping, so the
+	// linker methods are called on the elements directly.
+	before := a.Prev()
+	e.SetNext(a)
+	e.SetPrev(before)
+	a.SetPrev(e)
+
+	if before == nil {
+		// a was the first element, so e becomes the new head.
+		l.head = e
+		return
+	}
+	before.SetNext(e)
+}
+
+// Remove removes e from l. The next and prev pointers stored in e itself are
+// left unchanged.
+func (l *reassemblerList) Remove(e *reassembler) {
+	before := e.Prev()
+	after := e.Next()
+
+	if before == nil {
+		l.head = after
+	} else {
+		before.SetNext(after)
+	}
+
+	if after == nil {
+		l.tail = before
+	} else {
+		after.SetPrev(before)
+	}
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type reassemblerEntry struct {
+	next *reassembler
+	prev *reassembler
+}
+
+// Next returns the entry that follows e in the list.
+func (e *reassemblerEntry) Next() *reassembler { return e.next }
+
+// Prev returns the entry that precedes e in the list.
+func (e *reassemblerEntry) Prev() *reassembler { return e.prev }
+
+// SetNext assigns 'entry' as the entry that follows e in the list.
+func (e *reassemblerEntry) SetNext(elem *reassembler) { e.next = elem }
+
+// SetPrev assigns 'entry' as the entry that precedes e in the list.
+func (e *reassemblerEntry) SetPrev(elem *reassembler) { + e.prev = elem +} diff --git a/netstack/tcpip/network/fragmentation/reassembler_test.go b/netstack/tcpip/network/fragmentation/reassembler_test.go new file mode 100644 index 0000000..e7449de --- /dev/null +++ b/netstack/tcpip/network/fragmentation/reassembler_test.go @@ -0,0 +1,105 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fragmentation + +import ( + "math" + "reflect" + "testing" +) + +type updateHolesInput struct { + first uint16 + last uint16 + more bool +} + +var holesTestCases = []struct { + comment string + in []updateHolesInput + want []hole +}{ + { + comment: "No fragments. Expected holes: {[0 -> inf]}.", + in: []updateHolesInput{}, + want: []hole{{first: 0, last: math.MaxUint16, deleted: false}}, + }, + { + comment: "One fragment at beginning. Expected holes: {[2, inf]}.", + in: []updateHolesInput{{first: 0, last: 1, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.", + in: []updateHolesInput{{first: 1, last: 2, more: true}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + {first: 3, last: math.MaxUint16, deleted: false}, + }, + }, + { + comment: "One fragment at the end. 
Expected holes: {[0, 0]}.", + in: []updateHolesInput{{first: 1, last: 2, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 0, last: 0, deleted: false}, + }, + }, + { + comment: "One fragment completing a packet. Expected holes: {}.", + in: []updateHolesInput{{first: 0, last: 1, more: false}}, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 1, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 2, last: math.MaxUint16, deleted: true}, + }, + }, + { + comment: "Two overlapping fragments completing a packet. Expected holes: {}.", + in: []updateHolesInput{ + {first: 0, last: 2, more: true}, + {first: 2, last: 3, more: false}, + }, + want: []hole{ + {first: 0, last: math.MaxUint16, deleted: true}, + {first: 3, last: math.MaxUint16, deleted: true}, + }, + }, +} + +func TestUpdateHoles(t *testing.T) { + for _, c := range holesTestCases { + r := newReassembler(0) + for _, i := range c.in { + r.updateHoles(i.first, i.last, i.more) + } + if !reflect.DeepEqual(r.holes, c.want) { + t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want) + } + } +} diff --git a/netstack/tcpip/network/hash/hash.go b/netstack/tcpip/network/hash/hash.go new file mode 100644 index 0000000..4e33646 --- /dev/null +++ b/netstack/tcpip/network/hash/hash.go @@ -0,0 +1,94 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package hash contains utility functions for hashing. +package hash + +import ( + "encoding/binary" + + "tcpip/netstack/rand" + "tcpip/netstack/tcpip/header" +) + +var hashIV = RandN32(1)[0] + +// RandN32 generates a slice of n cryptographic random 32-bit numbers. +func RandN32(n int) []uint32 { + b := make([]byte, 4*n) + if _, err := rand.Read(b); err != nil { + panic("unable to get random numbers: " + err.Error()) + } + r := make([]uint32, n) + for i := range r { + r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)]) + } + return r +} + +// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted +// from linux. +func Hash3Words(a, b, c, initval uint32) uint32 { + const iv = 0xdeadbeef + (3 << 2) + initval += iv + + a += initval + b += initval + c += initval + + c ^= b + c -= rol32(b, 14) + a ^= c + a -= rol32(c, 11) + b ^= a + b -= rol32(a, 25) + c ^= b + c -= rol32(b, 16) + a ^= c + a -= rol32(c, 4) + b ^= a + b -= rol32(a, 14) + c ^= b + c -= rol32(b, 24) + + return c +} + +// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791. +// 根据id,源ip,目的ip和协议类型得到hash值 +func IPv4FragmentHash(h header.IPv4) uint32 { + x := uint32(h.ID())<<16 | uint32(h.Protocol()) + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(x, y, z, hashIV) +} + +// IPv6FragmentHash computes the hash of the ipv6 fragment. 
+// Unlike IPv4, the protocol is not used to compute the hash. +// RFC 2640 (sec 4.5) is not very sharp on this aspect. +// As a reference, also Linux ignores the protocol to compute +// the hash (inet6_hash_frag). +func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 { + t := h.SourceAddress() + y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = h.DestinationAddress() + z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return Hash3Words(f.ID(), y, z, hashIV) +} + +func rol32(v, shift uint32) uint32 { + return (v << shift) | (v >> ((-shift) & 31)) +} diff --git a/netstack/tcpip/network/ip_test.go b/netstack/tcpip/network/ip_test.go new file mode 100644 index 0000000..8320bc3 --- /dev/null +++ b/netstack/tcpip/network/ip_test.go @@ -0,0 +1,604 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

package ip_test

import (
	"testing"

	"tcpip/netstack/tcpip"
	"tcpip/netstack/tcpip/buffer"
	"tcpip/netstack/tcpip/header"
	"tcpip/netstack/tcpip/link/loopback"
	"tcpip/netstack/tcpip/network/ipv4"
	"tcpip/netstack/tcpip/network/ipv6"
	"tcpip/netstack/tcpip/stack"
	"tcpip/netstack/tcpip/transport/tcp"
	"tcpip/netstack/tcpip/transport/udp"
)

// Addresses, subnets and gateways used by the tests, written as raw address
// bytes (4 bytes for IPv4, 16 bytes for IPv6).
const (
	localIpv4Addr  = "\x0a\x00\x00\x01"
	remoteIpv4Addr = "\x0a\x00\x00\x02"
	ipv4SubnetAddr = "\x0a\x00\x00\x00"
	ipv4SubnetMask = "\xff\xff\xff\x00"
	ipv4Gateway    = "\x0a\x00\x00\x03"
	localIpv6Addr  = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
	remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
	ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
	ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
	ipv6Gateway    = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
)

// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
// pretend that it's the network stack so that it can inspect incoming packets
// that have been handled by the network endpoints.
//
// Packets are checked by comparing their fields/values against the expected
// values stored in the test object itself.
type testObject struct {
	t        *testing.T                    // test to report mismatches on
	protocol tcpip.TransportProtocolNumber // expected transport protocol
	contents []byte                        // expected payload bytes
	srcAddr  tcpip.Address                 // expected source address
	dstAddr  tcpip.Address                 // expected destination address
	v4       bool                          // WritePacket parses the header as IPv4 when true, IPv6 otherwise
	typ      stack.ControlType             // expected control type for control packets
	extra    uint32                        // expected "extra" value for control packets

	dataCalls    int // number of DeliverTransportPacket calls seen
	controlCalls int // number of DeliverTransportControlPacket calls seen
}

// checkValues verifies that the transport protocol, data contents, src & dst
// addresses of a packet match what's expected. If any field doesn't match, the
// test fails.
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) { + v := vv.ToView() + if protocol != t.protocol { + t.t.Errorf("protocol = %v, want %v", protocol, t.protocol) + } + + if srcAddr != t.srcAddr { + t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr) + } + + if dstAddr != t.dstAddr { + t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr) + } + + if len(v) != len(t.contents) { + t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents)) + } + + for i := range t.contents { + if t.contents[i] != v[i] { + t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i]) + } + } +} + +// DeliverTransportPacket is called by network endpoints after parsing incoming +// packets. This is used by the test object to verify that the results of the +// parsing are expected. +func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { + t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress) + t.dataCalls++ +} + +// DeliverTransportControlPacket is called by network endpoints after parsing +// incoming control (ICMP) packets. This is used by the test object to verify +// that the results of the parsing are expected. +func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + t.checkValues(trans, vv, remote, local) + if typ != t.typ { + t.t.Errorf("typ = %v, want %v", typ, t.typ) + } + if extra != t.extra { + t.t.Errorf("extra = %v, want %v", extra, t.extra) + } + t.controlCalls++ +} + +// Attach is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) Attach(stack.NetworkDispatcher) {} + +// IsAttached implements stack.LinkEndpoint.IsAttached. 
+func (*testObject) IsAttached() bool { + return true +} + +// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that +// matches the linux loopback MTU. +func (*testObject) MTU() uint32 { + return 65536 +} + +// Capabilities implements stack.LinkEndpoint.Capabilities. +func (*testObject) Capabilities() stack.LinkEndpointCapabilities { + return 0 +} + +// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface. +func (*testObject) MaxHeaderLength() uint16 { + return 0 +} + +// LinkAddress returns the link address of this endpoint. +func (*testObject) LinkAddress() tcpip.LinkAddress { + return "" +} + +// WritePacket is called by network endpoints after producing a packet and +// writing it to the link endpoint. This is used by the test object to verify +// that the produced packet is as expected. +func (t *testObject) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + var prot tcpip.TransportProtocolNumber + var srcAddr tcpip.Address + var dstAddr tcpip.Address + + if t.v4 { + h := header.IPv4(hdr.View()) + prot = tcpip.TransportProtocolNumber(h.Protocol()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + + } else { + h := header.IPv6(hdr.View()) + prot = tcpip.TransportProtocolNumber(h.NextHeader()) + srcAddr = h.SourceAddress() + dstAddr = h.DestinationAddress() + } + t.checkValues(prot, payload, srcAddr, dstAddr) + return nil +} + +func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv4.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: ipv4SubnetAddr, + Mask: ipv4SubnetMask, + Gateway: ipv4Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv4.ProtocolNumber) +} + +func buildIPv6Route(local, remote 
tcpip.Address) (stack.Route, *tcpip.Error) { + s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{}) + s.CreateNIC(1, loopback.New()) + s.AddAddress(1, ipv6.ProtocolNumber, local) + s.SetRouteTable([]tcpip.Route{{ + Destination: ipv6SubnetAddr, + Mask: ipv6SubnetMask, + Gateway: ipv6Gateway, + NIC: 1, + }}) + + return s.FindRoute(1, local, remote, ipv6.ProtocolNumber) +} + +func TestIPv4Send(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = localIpv4Addr + o.dstAddr = remoteIpv4Addr + o.contents = payload + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv4Receive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. 
+ o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = view[header.IPv4MinimumSize:totalLen] + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + ep.HandlePacket(&r, view.ToVectorisedView()) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv4ReceiveControl(t *testing.T) { + const mtu = 0xbeef - header.IPv4MinimumSize + cases := []struct { + name string + expectedCount int + fragmentOffset uint16 + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, + {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, + {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, + } + r := stack.Route{ + LocalAddress: localIpv4Addr, + RemoteAddress: "\x0a\x00\x00\xbb", + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + o := testObject{t: t} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + if err != nil { 
+ t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv4 header. + ip := header.IPv4(view) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(view) - c.trunc), + TTL: 20, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: "\x0a\x00\x00\xbb", + DstAddr: localIpv4Addr, + }) + + // Create the ICMP header. + icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) + icmp.SetType(header.ICMPv4DstUnreachable) + icmp.SetCode(c.code) + copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + + // Create the inner IPv4 header. + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: 100, + TTL: 20, + Protocol: 10, + FragmentOffset: c.fragmentOffset, + SrcAddr: localIpv4Addr, + DstAddr: remoteIpv4Addr, + }) + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv4 endpoint, dispatcher will validate that + // it's ok. 
+ o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + vv := view[:len(view)-c.trunc].ToVectorisedView() + ep.HandlePacket(&r, vv) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} + +func TestIPv4FragmentationReceive(t *testing.T) { + o := testObject{t: t, v4: true} + proto := ipv4.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv4MinimumSize + 24 + + frag1 := buffer.NewView(totalLen) + ip1 := header.IPv4(frag1) + ip1.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 0, + Flags: header.IPv4FlagMoreFragments, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag1[i] = uint8(i) + } + + frag2 := buffer.NewView(totalLen) + ip2 := header.IPv4(frag2) + ip2.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(totalLen), + TTL: 20, + Protocol: 10, + FragmentOffset: 24, + SrcAddr: remoteIpv4Addr, + DstAddr: localIpv4Addr, + }) + // Make payload be non-zero. + for i := header.IPv4MinimumSize; i < totalLen; i++ { + frag2[i] = uint8(i) + } + + // Give packet to ipv4 endpoint, dispatcher will validate that it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv4Addr + o.dstAddr = localIpv4Addr + o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...) + + r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + + // Send first segment. 
+ ep.HandlePacket(&r, frag1.ToVectorisedView()) + if o.dataCalls != 0 { + t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls) + } + + // Send second segment. + ep.HandlePacket(&r, frag2.ToVectorisedView()) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6Send(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Allocate and initialize the payload view. + payload := buffer.NewView(100) + for i := 0; i < len(payload); i++ { + payload[i] = uint8(i) + } + + // Allocate the header buffer. + hdr := buffer.NewPrependable(int(ep.MaxHeaderLength())) + + // Issue the write. + o.protocol = 123 + o.srcAddr = localIpv6Addr + o.dstAddr = remoteIpv6Addr + o.contents = payload + + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil { + t.Fatalf("WritePacket failed: %v", err) + } +} + +func TestIPv6Receive(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + totalLen := header.IPv6MinimumSize + 30 + view := buffer.NewView(totalLen) + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(totalLen - header.IPv6MinimumSize), + NextHeader: 10, + HopLimit: 20, + SrcAddr: remoteIpv6Addr, + DstAddr: localIpv6Addr, + }) + + // Make payload be non-zero. + for i := header.IPv6MinimumSize; i < totalLen; i++ { + view[i] = uint8(i) + } + + // Give packet to ipv6 endpoint, dispatcher will validate that it's ok. 
+ o.protocol = 10 + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr + o.contents = view[header.IPv6MinimumSize:totalLen] + + r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr) + if err != nil { + t.Fatalf("could not find route: %v", err) + } + + ep.HandlePacket(&r, view.ToVectorisedView()) + if o.dataCalls != 1 { + t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls) + } +} + +func TestIPv6ReceiveControl(t *testing.T) { + newUint16 := func(v uint16) *uint16 { return &v } + + const mtu = 0xffff + cases := []struct { + name string + expectedCount int + fragmentOffset *uint16 + typ header.ICMPv6Type + code uint8 + expectedTyp stack.ControlType + expectedExtra uint32 + trunc int + }{ + {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0}, + {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10}, + {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8}, + {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8}, + {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8}, + {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0}, + 
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8}, + } + r := stack.Route{ + LocalAddress: localIpv6Addr, + RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + o := testObject{t: t} + proto := ipv6.NewProtocol() + ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + defer ep.Close() + + dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4 + if c.fragmentOffset != nil { + dataOffset += header.IPv6FragmentHeaderSize + } + view := buffer.NewView(dataOffset + 8) + + // Create the outer IPv6 header. + ip := header.IPv6(view) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc), + NextHeader: uint8(header.ICMPv6ProtocolNumber), + HopLimit: 20, + SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa", + DstAddr: localIpv6Addr, + }) + + // Create the ICMP header. + icmp := header.ICMPv6(view[header.IPv6MinimumSize:]) + icmp.SetType(c.typ) + icmp.SetCode(c.code) + copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + + // Create the inner IPv6 header. + ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + ip.Encode(&header.IPv6Fields{ + PayloadLength: 100, + NextHeader: 10, + HopLimit: 20, + SrcAddr: localIpv6Addr, + DstAddr: remoteIpv6Addr, + }) + + // Build the fragmentation header if needed. 
+ if c.fragmentOffset != nil { + ip.SetNextHeader(header.IPv6FragmentHeader) + frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:]) + frag.Encode(&header.IPv6FragmentFields{ + NextHeader: 10, + FragmentOffset: *c.fragmentOffset, + M: true, + Identification: 0x12345678, + }) + } + + // Make payload be non-zero. + for i := dataOffset; i < len(view); i++ { + view[i] = uint8(i) + } + + // Give packet to IPv6 endpoint, dispatcher will validate that + // it's ok. + o.protocol = 10 + o.srcAddr = remoteIpv6Addr + o.dstAddr = localIpv6Addr + o.contents = view[dataOffset:] + o.typ = c.expectedTyp + o.extra = c.expectedExtra + + vv := view[:len(view)-c.trunc].ToVectorisedView() + ep.HandlePacket(&r, vv) + if want := c.expectedCount; o.controlCalls != want { + t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want) + } + }) + } +} diff --git a/netstack/tcpip/network/ipv4/icmp.go b/netstack/tcpip/network/ipv4/icmp.go new file mode 100644 index 0000000..37ff525 --- /dev/null +++ b/netstack/tcpip/network/ipv4/icmp.go @@ -0,0 +1,134 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ipv4 + +import ( + "encoding/binary" + "log" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" +) + +// handleControl handles the case when an ICMP packet contains the headers of +// the original packet that caused the ICMP one to be sent. This information is +// used to find out which transport endpoint must be notified about the ICMP +// packet. +// handleControl处理ICMP数据包包含导致ICMP发送的原始数据包的标头的情况。 +// 此信息用于确定必须通知哪个传输端点有关ICMP数据包。 +func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + h := header.IPv4(vv.First()) + + // We don't use IsValid() here because ICMP only requires that the IP + // header plus 8 bytes of the transport header be included. So it's + // likely that it is truncated, which would cause IsValid to return + // false. + // + // Drop packet if it doesn't have the basic IPv4 header or if the + // original source address doesn't match the endpoint's address. + if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress { + return + } + + hlen := int(h.HeaderLength()) + if vv.Size() < hlen || h.FragmentOffset() != 0 { + // We won't be able to handle this if it doesn't contain the + // full IPv4 header, or if it's a fragment not at offset 0 + // (because it won't have the transport header). + return + } + + // Skip the ip header, then deliver control message. 
+ vv.TrimFront(hlen) + p := h.TransportProtocol() + e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv) +} + +// 处理ICMP报文 +func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) { + v := vv.First() + if len(v) < header.ICMPv4MinimumSize { + return + } + h := header.ICMPv4(v) + + // 更具icmp的类型来进行相应的处理 + switch h.Type() { + case header.ICMPv4Echo: // icmp echo请求 + if len(v) < header.ICMPv4EchoMinimumSize { + return + } + log.Printf("icmp echo") + vv.TrimFront(header.ICMPv4MinimumSize) + req := echoRequest{r: r.Clone(), v: vv.ToView()} + select { + case e.echoRequests <- req: // 发送给echoReplier处理 + default: + req.r.Release() + } + + case header.ICMPv4EchoReply: // icmp echo响应 + if len(v) < header.ICMPv4EchoMinimumSize { + return + } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv) + + case header.ICMPv4DstUnreachable: // 目标不可达 + if len(v) < header.ICMPv4DstUnreachableMinimumSize { + return + } + vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + switch h.Code() { + case header.ICMPv4PortUnreachable: // 端口不可达 + e.handleControl(stack.ControlPortUnreachable, 0, vv) + + case header.ICMPv4FragmentationNeeded: // 需要进行分片但设置不分片标志 + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) + } + } + // TODO: Handle other ICMP types. 
+} + +type echoRequest struct { + r stack.Route + v buffer.View +} + +// 处理icmp echo请求的goroutine +func (e *endpoint) echoReplier() { + for req := range e.echoRequests { + sendPing4(&req.r, 0, req.v) + req.r.Release() + } +} + +// 根据icmp echo请求,封装icmp echo响应报文,并传给ip层处理 +func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error { + hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + icmpv4.SetType(header.ICMPv4EchoReply) + icmpv4.SetCode(code) + copy(icmpv4[header.ICMPv4MinimumSize:], data) + data = data[header.ICMPv4EchoMinimumSize-header.ICMPv4MinimumSize:] + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + log.Printf("icmp reply") + // 传给ip层处理 + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) +} diff --git a/netstack/tcpip/network/ipv4/ipv4.go b/netstack/tcpip/network/ipv4/ipv4.go new file mode 100644 index 0000000..a4dc2a6 --- /dev/null +++ b/netstack/tcpip/network/ipv4/ipv4.go @@ -0,0 +1,280 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ipv4 contains the implementation of the ipv4 network protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the +// network protocols when calling stack.New(). 
// Then endpoints can be created
// by passing ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4

import (
	"log"
	"sync/atomic"

	"tcpip/netstack/tcpip"
	"tcpip/netstack/tcpip/buffer"
	"tcpip/netstack/tcpip/header"
	"tcpip/netstack/tcpip/network/fragmentation"
	"tcpip/netstack/tcpip/network/hash"
	"tcpip/netstack/tcpip/stack"
)

const (
	// ProtocolName is the string representation of the ipv4 protocol name.
	ProtocolName = "ipv4"

	// ProtocolNumber is the ipv4 protocol number.
	ProtocolNumber = header.IPv4ProtocolNumber

	// maxTotalSize is maximum size that can be encoded in the 16-bit
	// TotalLength field of the ipv4 header.
	maxTotalSize = 0xffff

	// buckets is the number of identifier buckets.
	buckets = 2048
)

// endpoint is the IPv4 network-layer endpoint.
type endpoint struct {
	// nicid is the ID of the NIC this endpoint belongs to.
	nicid tcpip.NICID
	// id identifies this endpoint; it carries the endpoint's local IP
	// address.
	id stack.NetworkEndpointID
	// linkEP is the link-layer endpoint packets are written to.
	linkEP stack.LinkEndpoint
	// dispatcher receives the packets this endpoint has parsed.
	dispatcher stack.TransportDispatcher
	// echoRequests queues received ICMP echo (ping) requests for
	// echoReplier.
	echoRequests chan echoRequest
	// fragmentation reassembles fragmented IP packets.
	fragmentation *fragmentation.Fragmentation
}

// NewEndpoint creates a new ipv4 endpoint from the given parameters and starts
// the goroutine that answers ICMP echo requests for it.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
	dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
	e := &endpoint{
		nicid:        nicid,
		id:           stack.NetworkEndpointID{LocalAddress: addr},
		linkEP:       linkEP,
		dispatcher:   dispatcher,
		echoRequests: make(chan echoRequest, 10),
		fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold,
			fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
	}

	// Start a goroutine that serves ICMP echo requests.
	go e.echoReplier()

	return e, nil
}

// DefaultTTL is the default time-to-live value for this endpoint.
+// 默认的TTL值,TTL每经过路由转发一次就会减1 +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +// 获取去除ipv4头部后的最大报文长度 +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv4 endpoint ID. +// 获取该网络层端的id,也就是ip地址 +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// MaxHeaderLength returns the maximum length needed by ipv4 headers (and +// underlying protocols). +// 链路层和网络层的头部长度 +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +// 将传输层的数据封装加上IP头,并调用网卡的写入接口,写入IP报文 +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + // 预留ip报文的空间 + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + length := uint16(hdr.UsedLength() + payload.Size()) + id := uint32(0) + // 如果报文长度大于68 + if length > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. 
+ id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1) + } + // ip首部编码 + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: length, + ID: uint16(id), + TTL: ttl, + Protocol: uint8(protocol), + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + // 计算校验和和设置校验和 + ip.SetChecksum(^ip.CalculateChecksum()) + r.Stats().IP.PacketsSent.Increment() + + // 写入网卡接口 + log.Printf("send ipv4 packet %d bytes, proto: 0x%x", hdr.UsedLength()+payload.Size(), protocol) + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv4 packets arrive for +// this endpoint. +// 收到ip包的处理 +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { + // 得到ip报文 + h := header.IPv4(vv.First()) + // 检查报文是否有效 + if !h.IsValid(vv.Size()) { + return + } + + hlen := int(h.HeaderLength()) + tlen := int(h.TotalLength()) + vv.TrimFront(hlen) + vv.CapLength(tlen - hlen) + + // 检查ip报文是否有更多的分片 + more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 + // 是否需要ip重组 + if more || h.FragmentOffset() != 0 { + // The packet is a fragment, let's try to reassemble it. + last := h.FragmentOffset() + uint16(vv.Size()) - 1 + var ready bool + // ip分片重组 + vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv) + if !ready { + return + } + } + // 得到传输层的协议 + p := h.TransportProtocol() + // 如果时ICMP协议,则进入ICMP处理函数 + if p == header.ICMPv4ProtocolNumber { + e.handleICMP(r, vv) + return + } + r.Stats().IP.PacketsDelivered.Increment() + // 根据协议分发到不通处理函数,比如协议时TCP,会进入tcp.HandlePacket + log.Printf("recv ipv4 packet %d bytes, proto: 0x%x", tlen, p) + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (e *endpoint) Close() { + close(e.echoRequests) +} + +// 实现NetworkProtocol接口 +type protocol struct{} + +// NewProtocol creates a new protocol ipv4 protocol descriptor. 
This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv4 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv4 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv4MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv4(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + if mtu > maxTotalSize { + mtu = maxTotalSize + } + return mtu - header.IPv4MinimumSize +} + +// hashRoute calculates a hash value for the given route. It uses the source & +// destination address, the transport protocol number, and a random initial +// value (generated once on initialization) to generate the hash. 
+func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 { + t := r.LocalAddress + a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + t = r.RemoteAddress + b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24 + return hash.Hash3Words(a, b, uint32(protocol), hashIV) +} + +var ( + ids []uint32 + hashIV uint32 +) + +// 初始化的时候初始化buckets,用于hash计算 +// 并在网络层协议中注册ipv4协议 +func init() { + ids = make([]uint32, buckets) + + // Randomly initialize hashIV and the ids. + r := hash.RandN32(1 + buckets) + for i := range ids { + ids[i] = r[i] + } + hashIV = r[buckets] + + // 注册网络层协议 + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/netstack/tcpip/network/ipv4/ipv4_test.go b/netstack/tcpip/network/ipv4/ipv4_test.go new file mode 100644 index 0000000..18d7b05 --- /dev/null +++ b/netstack/tcpip/network/ipv4/ipv4_test.go @@ -0,0 +1,92 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// TestExcludeBroadcast checks how a UDP endpoint behaves on a NIC whose
// addresses are the IPv4 broadcast and "any" addresses.
//
// The two subtests share one stack and run in order: the second one adds a
// regular address to the NIC, so it must run after the first.
func TestExcludeBroadcast(t *testing.T) {
	s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})

	const defaultMTU = 65536
	id, _ := channel.New(256, defaultMTU, "")
	if testing.Verbose() {
		// Wrap the link endpoint in a sniffer when running verbosely.
		id = sniffer.New(id)
	}
	if err := s.CreateNIC(1, id); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}

	if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}
	if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// Single route with an all-zero destination and mask, via NIC 1.
	s.SetRouteTable([]tcpip.Route{{
		Destination: "\x00\x00\x00\x00",
		Mask:        "\x00\x00\x00\x00",
		Gateway:     "",
		NIC:         1,
	}})

	randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}

	var wq waiter.Queue
	t.Run("WithoutPrimaryAddress", func(t *testing.T) {
		ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
		if err != nil {
			t.Fatal(err)
		}
		defer ep.Close()

		// Cannot connect using a broadcast address as the source.
		if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
			t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
		}

		// However, we can bind to a broadcast address to listen.
		if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil); err != nil {
			t.Errorf("Bind failed: %v", err)
		}
	})

	t.Run("WithPrimaryAddress", func(t *testing.T) {
		ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
		if err != nil {
			t.Fatal(err)
		}
		defer ep.Close()

		// Add a valid primary endpoint address, now we can connect.
		if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
			t.Fatalf("AddAddress failed: %v", err)
		}
		if err := ep.Connect(randomAddr); err != nil {
			t.Errorf("Connect failed: %v", err)
		}
	})
}
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
//
// typ and extra are forwarded unchanged to the transport dispatcher; vv must
// start at the embedded IPv6 header of the original packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
	h := header.IPv6(vv.First())

	// We don't use IsValid() here because ICMP only requires that up to
	// 1280 bytes of the original packet be included. So it's likely that it
	// is truncated, which would cause IsValid to return false.
	//
	// Drop packet if it doesn't have the basic IPv6 header or if the
	// original source address doesn't match the endpoint's address.
	if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
		return
	}

	// Skip the IP header, then handle the fragmentation header if there
	// is one.
	vv.TrimFront(header.IPv6MinimumSize)
	p := h.TransportProtocol()
	if p == header.IPv6FragmentHeader {
		f := header.IPv6Fragment(vv.First())
		if !f.IsValid() || f.FragmentOffset() != 0 {
			// We can't handle fragments that aren't at offset 0
			// because they don't have the transport headers.
			return
		}

		// Skip fragmentation header and find out the actual protocol
		// number.
		vv.TrimFront(header.IPv6FragmentHeaderSize)
		p = f.TransportProtocol()
	}

	// Deliver the control packet to the transport endpoint.
	e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}

// handleICMP processes an incoming ICMPv6 message whose ICMPv6 header starts
// at the front of vv. Messages shorter than the minimum size for their type
// are dropped, and message types not listed in the switch are ignored.
func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
	v := vv.First()
	if len(v) < header.ICMPv6MinimumSize {
		return
	}
	h := header.ICMPv6(v)

	switch h.Type() {
	case header.ICMPv6PacketTooBig:
		if len(v) < header.ICMPv6PacketTooBigMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
		// The reported MTU is read as a big-endian word right after the
		// minimum ICMPv6 header and converted to a payload MTU.
		mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
		e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)

	case header.ICMPv6DstUnreachable:
		if len(v) < header.ICMPv6DstUnreachableMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
		// Only the port-unreachable code is forwarded; other codes are
		// ignored.
		switch h.Code() {
		case header.ICMPv6PortUnreachable:
			e.handleControl(stack.ControlPortUnreachable, 0, vv)
		}

	case header.ICMPv6NeighborSolicit:
		if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
			return
		}
		// The solicited target address is the 16 bytes at offset 8.
		targetAddr := tcpip.Address(v[8 : 8+16])
		if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
			// We don't have a useful answer; the best we can do is ignore the request.
			return
		}
		// Build the neighbor advertisement: flags, the target address
		// placed directly before the option, then a link-address option
		// carrying the route's local link address.
		hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
		pkt.SetType(header.ICMPv6NeighborAdvert)
		pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
		copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
		pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
		pkt[icmpV6LengthOffset] = 1
		copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])

		// ICMPv6 Neighbor Solicit messages are always sent to
		// specially crafted IPv6 multicast addresses. As a result, the
		// route we end up with here has as its LocalAddress such a
		// multicast address. It would be nonsense to claim that our
		// source address is a multicast address, so we manually set
		// the source address to the target address requested in the
		// solicit message. Since that requires mutating the route, we
		// must first clone it.
		r := r.Clone()
		defer r.Release()
		r.LocalAddress = targetAddr
		pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
		// NOTE(review): the result of WritePacket is discarded here.
		r.WritePacket(hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL())

		// Record the solicitor's link address as well.
		e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)

	case header.ICMPv6NeighborAdvert:
		if len(v) < header.ICMPv6NeighborAdvertSize {
			return
		}
		// Cache the sender's link address for the advertised target and,
		// when it differs, for the sender's own address too.
		targetAddr := tcpip.Address(v[8 : 8+16])
		e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
		if targetAddr != r.RemoteAddress {
			e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
		}

	case header.ICMPv6EchoRequest:
		if len(v) < header.ICMPv6EchoMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv6EchoMinimumSize)
		hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
		pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
		// Start from the request's header bytes, then change the type
		// and recompute the checksum; the payload is echoed as is.
		copy(pkt, h)
		pkt.SetType(header.ICMPv6EchoReply)
		pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
		// NOTE(review): the result of WritePacket is discarded here.
		r.WritePacket(hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())

	case header.ICMPv6EchoReply:
		if len(v) < header.ICMPv6EchoMinimumSize {
			return
		}
		e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv)

	}
}

const (
	// Flag bits written at icmpV6FlagOffset of a neighbor advertisement.
	ndpSolicitedFlag = 1 << 6
	ndpOverrideFlag  = 1 << 5

	// Option type codes written at icmpV6OptOffset.
	ndpOptSrcLinkAddr = 1
	ndpOptDstLinkAddr = 2

	// Byte offsets into the ICMPv6 message used when building neighbor
	// solicit/advertisement packets.
	icmpV6FlagOffset   = 4
	icmpV6OptOffset    = 24
	icmpV6LengthOffset = 25
)

// broadcastMAC is the all-ones link address used as the remote link address
// of link address requests.
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})

// Compile-time check that protocol implements stack.LinkAddressResolver.
var _ stack.LinkAddressResolver = (*protocol)(nil)

// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
	return header.IPv6ProtocolNumber
}

// LinkAddressRequest implements stack.LinkAddressResolver. It builds a
// neighbor solicitation for addr, addressed to addr's solicited-node address
// and broadcastMAC, and writes it directly to linkEP.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
	snaddr := header.SolicitedNodeAddr(addr)
	r := &stack.Route{
		LocalAddress:      localAddr,
		RemoteAddress:     snaddr,
		RemoteLinkAddress: broadcastMAC,
	}
	hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
	pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
	pkt.SetType(header.ICMPv6NeighborSolicit)
	// Target address directly before the option, then a link-address
	// option carrying this link endpoint's address.
	copy(pkt[icmpV6OptOffset-len(addr):], addr)
	pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
	pkt[icmpV6LengthOffset] = 1
	copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
	pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))

	// The IPv6 header is encoded by hand because the packet is written to
	// the link endpoint rather than through a network endpoint.
	length := uint16(hdr.UsedLength())
	ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
	ip.Encode(&header.IPv6Fields{
		PayloadLength: length,
		NextHeader:    uint8(header.ICMPv6ProtocolNumber),
		HopLimit:      defaultIPv6HopLimit,
		SrcAddr:       r.LocalAddress,
		DstAddr:       r.RemoteAddress,
	})

	return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}

// ResolveStaticAddress implements stack.LinkAddressResolver. No address is
// resolved statically, so it always reports false.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
	return "", false
}

// icmpChecksum returns the complemented checksum of the ICMPv6 message h with
// payload vv, computed over the IPv6 pseudo-header (src, dst, upper-layer
// length, next header), the payload views and h itself. h's own checksum
// field is treated as zero during the computation and restored afterwards.
func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
	// Calculate the IPv6 pseudo-header upper-layer checksum.
	xsum := header.Checksum([]byte(src), 0)
	xsum = header.Checksum([]byte(dst), xsum)
	var upperLayerLength [4]byte
	binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
	xsum = header.Checksum(upperLayerLength[:], xsum)
	xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
	for _, v := range vv.Views() {
		xsum = header.Checksum(v, xsum)
	}

	// h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
	h2, h3 := h[2], h[3]
	h[2], h[3] = 0, 0
	xsum = ^header.Checksum(h, xsum)
	h[2], h[3] = h2, h3

	return xsum
}
// Link addresses of the two test NICs.
const (
	linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
	linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)

// Link-local IPv6 addresses derived from the link addresses above.
var (
	lladdr0 = header.LinkLocalAddr(linkAddr0)
	lladdr1 = header.LinkLocalAddr(linkAddr1)
)

// icmpInfo records the type and IPv6 source address of one observed ICMPv6
// packet.
type icmpInfo struct {
	typ header.ICMPv6Type
	src tcpip.Address
}

// testContext holds two stacks whose channel link endpoints are wired to each
// other by routePackets.
type testContext struct {
	t  *testing.T
	s0 *stack.Stack
	s1 *stack.Stack

	linkEP0 *channel.Endpoint
	linkEP1 *channel.Endpoint

	// icmpCh receives one icmpInfo per ICMPv6 packet seen by countPacket.
	icmpCh chan icmpInfo
}

// endpointWithResolutionCapability wraps a link endpoint and additionally
// advertises stack.CapabilityResolutionRequired.
type endpointWithResolutionCapability struct {
	stack.LinkEndpoint
}

// Capabilities returns the wrapped endpoint's capabilities with
// stack.CapabilityResolutionRequired set.
func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
	return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
}

// newTestContext creates two IPv6 stacks, gives each a NIC with a link-local
// address and the matching solicited-node address, installs a route to the
// other stack and starts goroutines that forward packets between the two
// link endpoints.
func newTestContext(t *testing.T) *testContext {
	c := &testContext{
		t:      t,
		s0:     stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
		s1:     stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
		icmpCh: make(chan icmpInfo, 10),
	}

	const defaultMTU = 65536
	_, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
	c.linkEP0 = linkEP0
	wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
	id0 := stack.RegisterLinkEndpoint(wrappedEP0)
	if testing.Verbose() {
		// Only stack 0's endpoint is wrapped in a sniffer.
		id0 = sniffer.New(id0)
	}
	if err := c.s0.CreateNIC(1, id0); err != nil {
		t.Fatalf("CreateNIC s0: %v", err)
	}
	if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
		t.Fatalf("AddAddress lladdr0: %v", err)
	}
	if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
		t.Fatalf("AddAddress sn lladdr0: %v", err)
	}

	_, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
	c.linkEP1 = linkEP1
	wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
	id1 := stack.RegisterLinkEndpoint(wrappedEP1)
	if err := c.s1.CreateNIC(1, id1); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}
	if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
		t.Fatalf("AddAddress lladdr1: %v", err)
	}
	if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
		t.Fatalf("AddAddress sn lladdr1: %v", err)
	}

	// Each stack gets a single host route (16 bytes of 0xff as mask) to the
	// other stack's link-local address.
	c.s0.SetRouteTable(
		[]tcpip.Route{{
			Destination: lladdr1,
			Mask:        tcpip.AddressMask(strings.Repeat("\xff", 16)),
			NIC:         1,
		}},
	)
	c.s1.SetRouteTable(
		[]tcpip.Route{{
			Destination: lladdr0,
			Mask:        tcpip.AddressMask(strings.Repeat("\xff", 16)),
			NIC:         1,
		}},
	)

	// Forward everything written on one link endpoint into the other.
	go c.routePackets(linkEP0.C, linkEP1)
	go c.routePackets(linkEP1.C, linkEP0)

	return c
}

// countPacket reports pkt on c.icmpCh if it is an IPv6 packet whose next
// header is ICMPv6; any other packet is ignored.
func (c *testContext) countPacket(pkt channel.PacketInfo) {
	if pkt.Proto != ProtocolNumber {
		return
	}
	ipv6 := header.IPv6(pkt.Header)
	transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
	if transProto != header.ICMPv6ProtocolNumber {
		return
	}
	// The ICMPv6 header is read from pkt.Header right after the IPv6 header.
	b := pkt.Header[header.IPv6MinimumSize:]
	icmp := header.ICMPv6(b)
	c.icmpCh <- icmpInfo{
		typ: icmp.Type(),
		src: ipv6.SourceAddress(),
	}
}

// routePackets counts every packet received on ch and injects it into ep. It
// returns when ch is closed.
func (c *testContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
	for pkt := range ch {
		c.countPacket(pkt)
		views := []buffer.View{pkt.Header, pkt.Payload}
		size := len(pkt.Header) + len(pkt.Payload)
		vv := buffer.NewVectorisedView(size, views)
		ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), vv)
	}
}

// cleanup closes both link endpoint channels, which ends the routePackets
// goroutines.
func (c *testContext) cleanup() {
	close(c.linkEP0.C)
	close(c.linkEP1.C)
}

// TestLinkResolution sends an ICMPv6 echo request from stack 0 to stack 1
// through an endpoint and waits until a neighbor solicit, a neighbor advert,
// an echo request and an echo reply have each been observed at least once.
func TestLinkResolution(t *testing.T) {
	c := newTestContext(t)
	defer c.cleanup()
	r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber)
	if err != nil {
		t.Fatal(err)
	}
	defer r.Release()

	hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
	pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
	pkt.SetType(header.ICMPv6EchoRequest)
	pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
	payload := tcpip.SlicePayload(hdr.View())

	// We can't send our payload directly over the route because that
	// doesn't provoke NDP discovery.
	var wq waiter.Queue
	ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
	if err != nil {
		t.Fatal(err)
	}

	// This actually takes about 10 milliseconds, so no need to wait for
	// a multi-minute go test timeout if something is broken.
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	// Retry the write while the link address is still unresolved, until it
	// succeeds or the deadline passes.
	for {
		if ctx.Err() != nil {
			break
		}
		if _, _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
			// There's something asynchronous going on; yield to let it do its thing.
			runtime.Gosched()
		} else if err == nil {
			break
		} else {
			t.Fatal(err)
		}
	}

	// Count observed ICMPv6 packets by type until all four expected types
	// have been seen or the deadline passes.
	stats := make(map[header.ICMPv6Type]int)
	for {
		select {
		case <-ctx.Done():
			t.Errorf("timeout waiting for ICMP, got: %#+v", stats)
			return
		case icmpInfo := <-c.icmpCh:
			switch icmpInfo.typ {
			case header.ICMPv6NeighborAdvert:
				// The advert must come from the solicited address.
				if got, want := icmpInfo.src, lladdr1; got != want {
					t.Errorf("got ICMPv6NeighborAdvert.sourceAddress = %v, want = %v", got, want)
				}
			}
			stats[icmpInfo.typ]++

			if stats[header.ICMPv6NeighborSolicit] > 0 &&
				stats[header.ICMPv6NeighborAdvert] > 0 &&
				stats[header.ICMPv6EchoRequest] > 0 &&
				stats[header.ICMPv6EchoReply] > 0 {
				return
			}
		}
	}
}
+package ipv6 + +import ( + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv6 protocol name. + ProtocolName = "ipv6" + + // ProtocolNumber is the ipv6 protocol number. + ProtocolNumber = header.IPv6ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // PayloadLength field of the ipv6 header. + maxPayloadSize = 0xffff + + // defaultIPv6HopLimit is the default hop limit for IPv6 Packets + // egressed by Netstack. + defaultIPv6HopLimit = 255 +) + +type endpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + linkEP stack.LinkEndpoint + linkAddrCache stack.LinkAddressCache + dispatcher stack.TransportDispatcher +} + +// DefaultTTL is the default hop limit for this endpoint. +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv6 endpoint ID. +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// MaxHeaderLength returns the maximum length needed by ipv6 headers (and +// underlying protocols). +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. 
+func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + length := uint16(hdr.UsedLength() + payload.Size()) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: length, + NextHeader: uint8(protocol), + HopLimit: ttl, + SrcAddr: r.LocalAddress, + DstAddr: r.RemoteAddress, + }) + r.Stats().IP.PacketsSent.Increment() + + return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber) +} + +// HandlePacket is called by the link layer when new ipv6 packets arrive for +// this endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { + h := header.IPv6(vv.First()) + if !h.IsValid(vv.Size()) { + return + } + + vv.TrimFront(header.IPv6MinimumSize) + vv.CapLength(int(h.PayloadLength())) + + p := h.TransportProtocol() + if p == header.ICMPv6ProtocolNumber { + e.handleICMP(r, vv) + return + } + + r.Stats().IP.PacketsDelivered.Increment() + e.dispatcher.DeliverTransportPacket(r, p, vv) +} + +// Close cleans up resources associated with the endpoint. +func (*endpoint) Close() {} + +type protocol struct{} + +// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv6 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv6 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv6MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. 
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.IPv6(v) + return h.SourceAddress(), h.DestinationAddress() +} + +// NewEndpoint creates a new ipv6 endpoint. +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return &endpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{LocalAddress: addr}, + linkEP: linkEP, + linkAddrCache: linkAddrCache, + dispatcher: dispatcher, + }, nil +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + mtu -= header.IPv6MinimumSize + if mtu <= maxPayloadSize { + return mtu + } + return maxPayloadSize +} + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/netstack/tcpip/ports/ports.go b/netstack/tcpip/ports/ports.go new file mode 100644 index 0000000..121ba21 --- /dev/null +++ b/netstack/tcpip/ports/ports.go @@ -0,0 +1,185 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ports provides PortManager that manages allocating, reserving and releasing ports. +package ports + +import ( + "log" + "math" + "math/rand" + "sync" + + "tcpip/netstack/tcpip" +) + +const ( + // 临时端口的最小值 + FirstEphemeral = 16000 + + anyIPAddress tcpip.Address = "" +) + +// 端口的唯一标识: 网络层协议-传输层协议-端口号 +type portDescriptor struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber + port uint16 +} + +// 管理端口的对象,由它来保留和释放端口 +type PortManager struct { + mu sync.RWMutex + // 用一个map接口来保存端口被占用 + allocatedPorts map[portDescriptor]bindAddresses +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]struct{} + +// isAvailable checks whether an IP address is available to bind to. +func (b bindAddresses) isAvailable(addr tcpip.Address) bool { + if addr == anyIPAddress { + return len(b) == 0 + } + + // If all addresses for this portDescriptor are already bound, no + // address is available. + if _, ok := b[anyIPAddress]; ok { + return false + } + + if _, ok := b[addr]; ok { + return false + } + return true +} + +// NewPortManager 新建一个端口管理器 +func NewPortManager() *PortManager { + return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)} +} + +// PickEphemeralPort 从端口管理器中随机分配一个端口,并调用testPort来检测是否可用。 +func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) { + count := uint16(math.MaxUint16 - FirstEphemeral + 1) + offset := uint16(rand.Int31n(int32(count))) + + for i := uint16(0); i < count; i++ { + port = FirstEphemeral + (offset+i)%count + ok, err := testPort(port) + if err != nil { + return 0, err + } + + if ok { + return port, nil + } + } + + return 0, tcpip.ErrNoPortAvailable +} + +// IsPortAvailable tests if the given port is available on all given protocols. 
+func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.isPortAvailableLocked(networks, transport, addr, port) +} + +// isPortAvailableLocked 根据参数判断该端口号是否已经被占用了 +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) bool { + for _, network := range networks { + desc := portDescriptor{network, transport, port} + if addrs, ok := s.allocatedPorts[desc]; ok { + if !addrs.isAvailable(addr) { + return false + } + } + } + return true +} + +// ReservePort 将端口和IP地址绑定在一起,这样别的程序就无法使用已经被绑定的端口。 +// 如果传入的端口不为0,那么会尝试绑定该端口,若该端口没有被占用,那么绑定成功。 +// 如果传人的端口等于0,那么就是告诉协议栈自己分配端口,端口管理器就会随机返回一个端口。 +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) { + s.mu.Lock() + defer s.mu.Unlock() + + // If a port is specified, just try to reserve it for all network + // protocols. + if port != 0 { + if !s.reserveSpecificPort(networks, transport, addr, port) { + return 0, tcpip.ErrPortInUse + } + reservedPort = port + log.Printf("new transport: %d, port: %d", transport, reservedPort) + return + } + + // 随机分配一个未占用的端口 + reservedPort, err = s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + return s.reserveSpecificPort(networks, transport, addr, p), nil + }) + log.Printf("new transport: %d, port: %d", transport, reservedPort) + return +} + +// reserveSpecificPort 尝试根据协议号和IP地址绑定一个端口 +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port) { + return false + } + + // Reserve port on all network protocols. 
+ // 根据给定的网络层协议号(IPV4或IPV6),绑定端口 + for _, network := range networks { + desc := portDescriptor{network, transport, port} + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) + s.allocatedPorts[desc] = m + } + // 注册该地址被绑定了 + m[addr] = struct{}{} + } + + return true +} + +// ReleasePort releases the reservation on a port/IP combination so that it can +// be reserved by other endpoints. +// 释放绑定的端口,以便别的程序复用。 +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) { + s.mu.Lock() + defer s.mu.Unlock() + + // 删除绑定关系 + for _, network := range networks { + desc := portDescriptor{network, transport, port} + if m, ok := s.allocatedPorts[desc]; ok { + log.Printf("delete transport: %d, port: %d", transport, port) + delete(m, addr) + if len(m) == 0 { + delete(s.allocatedPorts, desc) + } + } + } +} diff --git a/netstack/tcpip/ports/ports_test.go b/netstack/tcpip/ports/ports_test.go new file mode 100644 index 0000000..a736883 --- /dev/null +++ b/netstack/tcpip/ports/ports_test.go @@ -0,0 +1,145 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package ports + +import ( + "testing" + + "tcpip/netstack/tcpip" +) + +const ( + fakeTransNumber tcpip.TransportProtocolNumber = 1 + fakeNetworkNumber tcpip.NetworkProtocolNumber = 2 + + fakeIPAddress = tcpip.Address("\x08\x08\x08\x08") + fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09") +) + +func TestPortReservation(t *testing.T) { + pm := NewPortManager() + net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber} + + for _, test := range []struct { + port uint16 + ip tcpip.Address + want *tcpip.Error + }{ + { + port: 80, + ip: fakeIPAddress, + want: nil, + }, + { + port: 80, + ip: fakeIPAddress1, + want: nil, + }, + { + /* N.B. Order of tests matters! */ + port: 80, + ip: anyIPAddress, + want: tcpip.ErrPortInUse, + }, + { + port: 22, + ip: anyIPAddress, + want: nil, + }, + { + port: 22, + ip: fakeIPAddress, + want: tcpip.ErrPortInUse, + }, + { + port: 0, + ip: fakeIPAddress, + want: nil, + }, + { + port: 0, + ip: fakeIPAddress, + want: nil, + }, + } { + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port) + if err != test.want { + t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want) + } + if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { + t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) + } + } + + // Release port 22 from any IP address, then try to reserve fake IP + // address on 22. 
+ pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22) + + if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil { + t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err) + } +} + +func TestPickEphemeralPort(t *testing.T) { + pm := NewPortManager() + customErr := &tcpip.Error{} + for _, test := range []struct { + name string + f func(port uint16) (bool, *tcpip.Error) + wantErr *tcpip.Error + wantPort uint16 + }{ + { + name: "no-port-available", + f: func(port uint16) (bool, *tcpip.Error) { + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + { + name: "port-tester-error", + f: func(port uint16) (bool, *tcpip.Error) { + return false, customErr + }, + wantErr: customErr, + }, + { + name: "only-port-16042-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port == FirstEphemeral+42 { + return true, nil + } + return false, nil + }, + wantPort: FirstEphemeral + 42, + }, + { + name: "only-port-under-16000-available", + f: func(port uint16) (bool, *tcpip.Error) { + if port < FirstEphemeral { + return true, nil + } + return false, nil + }, + wantErr: tcpip.ErrNoPortAvailable, + }, + } { + t.Run(test.name, func(t *testing.T) { + if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr { + t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr) + } + }) + } +} diff --git a/netstack/tcpip/seqnum/seqnum.go b/netstack/tcpip/seqnum/seqnum.go new file mode 100644 index 0000000..636ec40 --- /dev/null +++ b/netstack/tcpip/seqnum/seqnum.go @@ -0,0 +1,67 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package seqnum defines the types and methods for TCP sequence numbers such +// that they fit in 32-bit words and work properly when overflows occur. +package seqnum + +// Value represents the value of a sequence number. +type Value uint32 + +// Size represents the size (length) of a sequence number window. +type Size uint32 + +// LessThan checks if v is before w, i.e., v < w. +func (v Value) LessThan(w Value) bool { + return int32(v-w) < 0 +} + +// LessThanEq returns true if v==w or v is before i.e., v < w. +func (v Value) LessThanEq(w Value) bool { + if v == w { + return true + } + return v.LessThan(w) +} + +// InRange checks if v is in the range [a,b), i.e., a <= v < b. +func (v Value) InRange(a, b Value) bool { + return v-a < b-a +} + +// InWindow checks if v is in the window that starts at 'first' and spans 'size' +// sequence numbers. +func (v Value) InWindow(first Value, size Size) bool { + return v.InRange(first, first.Add(size)) +} + +// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). +func Overlap(a Value, b Size, x Value, y Size) bool { + return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) +} + +// Add calculates the sequence number following the [v, v+s) window. +func (v Value) Add(s Size) Value { + return v + Value(s) +} + +// Size calculates the size of the window defined by [v, w). +func (v Value) Size(w Value) Size { + return Size(w - v) +} + +// UpdateForward updates v such that it becomes v + s. 
+func (v *Value) UpdateForward(s Size) { + *v += Value(s) +} diff --git a/netstack/tcpip/stack/linkaddrcache.go b/netstack/tcpip/stack/linkaddrcache.go new file mode 100644 index 0000000..60ec1fd --- /dev/null +++ b/netstack/tcpip/stack/linkaddrcache.go @@ -0,0 +1,308 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "log" + "sync" + "time" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" +) + +const linkAddrCacheSize = 512 // max cache entries + +// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. +// +// The entries are stored in a ring buffer, oldest entry replaced first. +// +// This struct is safe for concurrent use. +type linkAddrCache struct { + // ageLimit is how long a cache entry is valid for. + ageLimit time.Duration + + // resolutionTimeout is the amount of time to wait for a link request to + // resolve an address. + resolutionTimeout time.Duration + + // resolutionAttempts is the number of times an address is attempted to be + // resolved before failing. + resolutionAttempts int + + mu sync.Mutex + cache map[tcpip.FullAddress]*linkAddrEntry + next int // array index of next available entry + entries [linkAddrCacheSize]linkAddrEntry +} + +// entryState controls the state of a single entry in the cache. +type entryState int + +const ( + // incomplete means that there is an outstanding request to resolve the + // address. 
This is the initial state. + incomplete entryState = iota + // ready means that the address has been resolved and can be used. + ready + // failed means that address resolution timed out and the address + // could not be resolved. + failed + // expired means that the cache entry has expired and the address must be + // resolved again. + expired +) + +// String implements Stringer. +func (s entryState) String() string { + switch s { + case incomplete: + return "incomplete" + case ready: + return "ready" + case failed: + return "failed" + case expired: + return "expired" + default: + return fmt.Sprintf("unknown(%d)", s) + } +} + +// A linkAddrEntry is an entry in the linkAddrCache. +// This struct is thread-compatible. +type linkAddrEntry struct { + addr tcpip.FullAddress + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState + + // wakers is a set of waiters for address resolution result. Anytime + // state transitions out of 'incomplete' these waiters are notified. + wakers map[*sleep.Waker]struct{} + + done chan struct{} +} + +func (e *linkAddrEntry) state() entryState { + if e.s != expired && time.Now().After(e.expiration) { + // Force the transition to ensure waiters are notified. + e.changeState(expired) + } + return e.s +} + +func (e *linkAddrEntry) changeState(ns entryState) { + if e.s == ns { + return + } + + // Validate state transition. + switch e.s { + case incomplete: + // All transitions are valid. + case ready, failed: + if ns != expired { + panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) + } + case expired: + // Terminal state. + panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) + default: + panic(fmt.Sprintf("invalid state: %s", e.s)) + } + + // Notify whoever is waiting on address resolution when transitioning + // out of 'incomplete'. 
+ if e.s == incomplete { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + if e.done != nil { + close(e.done) + } + } + e.s = ns +} + +func (e *linkAddrEntry) addWaker(w *sleep.Waker) { + e.wakers[w] = struct{}{} +} + +func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { + delete(e.wakers, w) +} + +// add adds a k -> v mapping to the cache. +func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { + log.Printf("add link cache: %v-%v", k, v) + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.cache[k] + if ok { + s := entry.state() + if s != expired && entry.linkAddr == v { + // Disregard repeated calls. + return + } + // Check if entry is waiting for address resolution. + if s == incomplete { + entry.linkAddr = v + } else { + // Otherwise create a new entry to replace it. + entry = c.makeAndAddEntry(k, v) + } + } else { + entry = c.makeAndAddEntry(k, v) + } + + entry.changeState(ready) +} + +// makeAndAddEntry is a helper function to create and add a new +// entry to the cache map and evict older entry as needed. +func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { + // Take over the next entry. + entry := &c.entries[c.next] + if c.cache[entry.addr] == entry { + delete(c.cache, entry.addr) + } + + // Mark the soon-to-be-replaced entry as expired, just in case there is + // someone waiting for address resolution on it. + entry.changeState(expired) + + *entry = linkAddrEntry{ + addr: k, + linkAddr: v, + expiration: time.Now().Add(c.ageLimit), + wakers: make(map[*sleep.Waker]struct{}), + done: make(chan struct{}), + } + + c.cache[k] = entry + c.next = (c.next + 1) % len(c.entries) + return entry +} + +// get reports any known link address for k. 
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { + log.Printf("link addr get linkRes: %#v, addr: %+v", linkRes, k) + if linkRes != nil { + if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + return addr, nil, nil + } + } + + c.mu.Lock() + defer c.mu.Unlock() + // 尝试从缓存中得到MAC地址 + if entry, ok := c.cache[k]; ok { + switch s := entry.state(); s { + case expired: + case ready: + return entry.linkAddr, nil, nil + case failed: + return "", nil, tcpip.ErrNoLinkAddress + case incomplete: + // Address resolution is still in progress. + entry.addWaker(waker) + return "", entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } + } + + if linkRes == nil { + return "", nil, tcpip.ErrNoLinkAddress + } + + // Add 'incomplete' entry in the cache to mark that resolution is in progress. + e := c.makeAndAddEntry(k, "") + e.addWaker(waker) + + go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) + + return "", e.done, tcpip.ErrWouldBlock +} + +// removeWaker removes a waker previously added through get(). +func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { + c.mu.Lock() + defer c.mu.Unlock() + + if entry, ok := c.cache[k]; ok { + entry.removeWaker(waker) + } +} + +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { + for i := 0; ; i++ { + // Send link request, then wait for the timeout limit and check + // whether the request succeeded. 
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + + select { + case <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(k, i); stop { + return + } + case <-done: + return + } + } +} + +// checkLinkRequest checks whether previous attempt to resolve address has succeeded +// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request +// can stop, false if another request should be sent. +func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.cache[k] + if !ok { + // Entry was evicted from the cache. + return true + } + + switch s := entry.state(); s { + case ready, failed, expired: + // Entry was made ready by resolver or failed. Either way we're done. + return true + case incomplete: + if attempt+1 >= c.resolutionAttempts { + // Max number of retries reached, mark entry as failed. + entry.changeState(failed) + return true + } + // No response yet, need to send another ARP request. + return false + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } +} + +func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { + return &linkAddrCache{ + ageLimit: ageLimit, + resolutionTimeout: resolutionTimeout, + resolutionAttempts: resolutionAttempts, + cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), + } +} diff --git a/netstack/tcpip/stack/linkaddrcache_test.go b/netstack/tcpip/stack/linkaddrcache_test.go new file mode 100644 index 0000000..e6882c3 --- /dev/null +++ b/netstack/tcpip/stack/linkaddrcache_test.go @@ -0,0 +1,266 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "fmt" + "sync" + "testing" + "time" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" +) + +type testaddr struct { + addr tcpip.FullAddress + linkAddr tcpip.LinkAddress +} + +var testaddrs []testaddr + +type testLinkAddressResolver struct { + cache *linkAddrCache + delay time.Duration +} + +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + go func() { + if r.delay > 0 { + time.Sleep(r.delay) + } + r.fakeRequest(addr) + }() + return nil +} + +func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { + for _, ta := range testaddrs { + if ta.addr.Addr == addr { + r.cache.add(ta.addr, ta.linkAddr) + break + } + } +} + +func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "broadcast" { + return "mac_broadcast", true + } + return "", false +} + +func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 1 +} + +func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + for { + if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + return got, err + } + s.Fetch(true) + } +} + +func init() { + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + testaddrs = append(testaddrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: 
tcpip.LinkAddress("Link" + addr), + }) + } +} + +func TestCacheOverflow(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + for i := len(testaddrs) - 1; i >= 0; i-- { + e := testaddrs[i] + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + } + } + // Expect to find at least half of the most recent entries. + for i := 0; i < linkAddrCacheSize/2; i++ { + e := testaddrs[i] + got, _, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + } + } + // The earliest entries should no longer be in the cache. + for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { + e := testaddrs[i] + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + } + } +} + +func TestCacheConcurrent(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + + var wg sync.WaitGroup + for r := 0; r < 16; r++ { + wg.Add(1) + go func() { + for _, e := range testaddrs { + c.add(e.addr, e.linkAddr) + c.get(e.addr, nil, "", nil, nil) // make work for gotsan + } + wg.Done() + }() + } + wg.Wait() + + // All goroutines add in the same order and add more values than + // can fit in the cache, so our eviction strategy requires that + // the last entry be present and the first be missing. 
+ e := testaddrs[len(testaddrs)-1] + got, _, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + e = testaddrs[0] + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheAgeLimit(t *testing.T) { + c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + e := testaddrs[0] + c.add(e.addr, e.linkAddr) + time.Sleep(50 * time.Millisecond) + if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheReplace(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + e := testaddrs[0] + l2 := e.linkAddr + "2" + c.add(e.addr, e.linkAddr) + got, _, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + c.add(e.addr, l2) + got, _, err = c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != l2 { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + } +} + +func TestCacheResolution(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + linkRes := &testLinkAddressResolver{cache: c} + for i, ta := range testaddrs { + got, err := getBlocking(c, ta.addr, linkRes) + if err != nil { + t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + } + if got != ta.linkAddr { + t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + } + } + + 
// Check that after resolved, address stays in the cache and never returns WouldBlock. + for i := 0; i < 10; i++ { + e := testaddrs[len(testaddrs)-1] + got, _, err := c.get(e.addr, linkRes, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + } +} + +func TestCacheResolutionFailed(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + linkRes := &testLinkAddressResolver{cache: c} + + // First, sanity check that resolution is working... + e := testaddrs[0] + got, err := getBlocking(c, e.addr, linkRes) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + e.addr.Addr += "2" + if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheResolutionTimeout(t *testing.T) { + resolverDelay := 50 * time.Millisecond + expiration := resolverDelay / 2 + c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + + e := testaddrs[0] + if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +// TestStaticResolution checks that static link addresses are resolved immediately and don't +// send resolution requests. 
+func TestStaticResolution(t *testing.T) { + c := newLinkAddrCache(1<<63-1, time.Millisecond, 1) + linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute} + + addr := tcpip.Address("broadcast") + want := tcpip.LinkAddress("mac_broadcast") + got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) + } + if got != want { + t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) + } +} diff --git a/netstack/tcpip/stack/nic.go b/netstack/tcpip/stack/nic.go new file mode 100644 index 0000000..e33a7cd --- /dev/null +++ b/netstack/tcpip/stack/nic.go @@ -0,0 +1,668 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "strings" + "sync" + "sync/atomic" + + "tcpip/netstack/ilist" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" +) + +// NIC represents a "network interface card" to which the networking stack is +// attached. 
+// 代表一个网卡对象 +type NIC struct { + stack *Stack + // 每个网卡的唯一标识号 + id tcpip.NICID + // 网卡名,可有可无 + name string + // 链路层端 + linkEP LinkEndpoint + + // 传输层的解复用 + demux *transportDemuxer + + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + // 网络层端的记录 + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + // 子网的记录 + subnets []tcpip.Subnet +} + +// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior. +type PrimaryEndpointBehavior int + +const ( + // CanBePrimaryEndpoint indicates the endpoint can be used as a primary + // endpoint for new connections with no local address. This is the + // default when calling NIC.AddAddress. + CanBePrimaryEndpoint PrimaryEndpointBehavior = iota + + // FirstPrimaryEndpoint indicates the endpoint should be the first + // primary endpoint considered. If there are multiple endpoints with + // this behavior, the most recently-added one will be first. + FirstPrimaryEndpoint + + // NeverPrimaryEndpoint indicates the endpoint should never be a + // primary endpoint. + NeverPrimaryEndpoint +) + +// 根据参数新建一个NIC +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { + return &NIC{ + stack: stack, + id: id, + name: name, + linkEP: ep, + demux: newTransportDemuxer(stack), + primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + } +} + +// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it +// to start delivering packets. +// attachLinkEndpoint添加当前的NIC到链路层设备,激活该NIC +func (n *NIC) attachLinkEndpoint() { + n.linkEP.Attach(n) +} + +// setPromiscuousMode enables or disables promiscuous mode. 
+// 设备网卡为混杂模式 +func (n *NIC) setPromiscuousMode(enable bool) { + n.mu.Lock() + n.promiscuous = enable + n.mu.Unlock() +} + +// 判断网卡是否开启混杂模式 +func (n *NIC) isPromiscuousMode() bool { + n.mu.RLock() + rv := n.promiscuous + n.mu.RUnlock() + return rv +} + +// setSpoofing enables or disables address spoofing. +func (n *NIC) setSpoofing(enable bool) { + n.mu.Lock() + n.spoofing = enable + n.mu.Unlock() +} + +func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { + n.mu.RLock() + defer n.mu.RUnlock() + + var r *referencedNetworkEndpoint + + // Check for a primary endpoint. + if list, ok := n.primary[protocol]; ok { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + if ref.holdsInsertRef && ref.tryIncRef() { + r = ref + break + } + } + + } + + // If no primary endpoints then check for other endpoints. + if r == nil { + for _, ref := range n.endpoints { + if ref.holdsInsertRef && ref.tryIncRef() { + r = ref + break + } + } + } + + if r == nil { + return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress + } + + address := r.ep.ID().LocalAddress + r.decRef() + + // Find the least-constrained matching subnet for the address, if one + // exists, and return it. + var subnet tcpip.Subnet + for _, s := range n.subnets { + if s.Contains(address) && !subnet.Contains(s.ID()) { + subnet = s + } + } + return address, subnet, nil +} + +// primaryEndpoint returns the primary endpoint of n for the given network +// protocol. +// 根据网络层协议号找到对应的网络层端 +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + list := n.primary[protocol] + if list == nil { + return nil + } + + for e := list.Front(); e != nil; e = e.Next() { + r := e.(*referencedNetworkEndpoint) + // TODO: allow broadcast address when SO_BROADCAST is set. 
+ switch r.ep.ID().LocalAddress { + case header.IPv4Broadcast, header.IPv4Any: + continue + } + if r.tryIncRef() { + return r + } + } + + return nil +} + +// findEndpoint finds the endpoint, if any, with the given address. +// 根据address参数查找对应的网络层端 +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, + peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + id := NetworkEndpointID{address} + + n.mu.RLock() + ref := n.endpoints[id] + if ref != nil && !ref.tryIncRef() { + ref = nil + } + spoofing := n.spoofing + n.mu.RUnlock() + + if ref != nil || !spoofing { + return ref + } + + // Try again with the lock in exclusive mode. If we still can't get the + // endpoint, create a new "temporary" endpoint. It will only exist while + // there's a route through it. + n.mu.Lock() + ref = n.endpoints[id] + if ref == nil || !ref.tryIncRef() { + ref, _ = n.addAddressLocked(protocol, address, peb, true) + if ref != nil { + ref.holdsInsertRef = false + } + } + n.mu.Unlock() + return ref +} + +// 在NIC上添加addr地址,注册和初始化网络层协议 +// 相当于给网卡添加ip地址 +func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, + peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { + // 查看是否支持该协议,若不支持则返回错误 + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + // Create the new network endpoint. + // 比如netProto为ipv4,会调用ipv4.NewEndpoint,新建一个网络层端 + ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) + if err != nil { + return nil, err + } + + // 获取网络层端的id,其实就是ip地址 + id := *ep.ID() + if ref, ok := n.endpoints[id]; ok { + // 不是替换,且该id否存在,返回错误 + if !replace { + return nil, tcpip.ErrDuplicateAddress + } + + n.removeEndpointLocked(ref) + } + + ref := &referencedNetworkEndpoint{ + refs: 1, + ep: ep, + nic: n, + protocol: protocol, + holdsInsertRef: true, + } + + // Set up cache if link address resolution exists for this protocol. 
+ if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if _, ok := n.stack.linkAddrResolvers[protocol]; ok { + ref.linkCache = n.stack + } + } + + // 注册该网络端 + n.endpoints[id] = ref + + l, ok := n.primary[protocol] + if !ok { + l = &ilist.List{} + n.primary[protocol] = l + } + + switch peb { + case CanBePrimaryEndpoint: + l.PushBack(ref) + case FirstPrimaryEndpoint: + l.PushFront(ref) + } + + return ref, nil +} + +// AddAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +// AddAddress向n添加一个新地址,以便它开始接受针对给定地址(和网络协议)的数据包 +// 最终会生成一个网络端注册到 endpoints +func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint) +} + +// AddAddressWithOptions is the same as AddAddress, but allows you to specify +// whether the new endpoint can be primary or not. +// AddAddressWithOptions与AddAddress相同,但允许您指定新端点是否可以是主端点。 +func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, + peb PrimaryEndpointBehavior) *tcpip.Error { + // Add the endpoint. + n.mu.Lock() + _, err := n.addAddressLocked(protocol, addr, peb, false) + n.mu.Unlock() + + return err +} + +// Addresses returns the addresses associated with this NIC. +// 获取网卡已分配的协议和对应的地址 +func (n *NIC) Addresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) + for nid, ep := range n.endpoints { + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: ep.protocol, + Address: nid.LocalAddress, + }) + } + return addrs +} + +// AddSubnet adds a new subnet to n, so that it starts accepting packets +// targeted at the given address and network protocol. 
+// AddSubnet向n添加一个新子网,以便它开始接受针对给定地址和网络协议的数据包。 +func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { + n.mu.Lock() + n.subnets = append(n.subnets, subnet) + n.mu.Unlock() +} + +// RemoveSubnet removes the given subnet from n. +// 从n中删除一个子网 +func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { + n.mu.Lock() + + // Use the same underlying array. + tmp := n.subnets[:0] + for _, sub := range n.subnets { + if sub != subnet { + tmp = append(tmp, sub) + } + } + n.subnets = tmp + + n.mu.Unlock() +} + +// ContainsSubnet reports whether this NIC contains the given subnet. +// 判断 subnet 这个子网是否在该网卡下 +func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { + for _, s := range n.Subnets() { + if s == subnet { + return true + } + } + return false +} + +// Subnets returns the Subnets associated with this NIC. +// 获取该网卡的所有子网 +func (n *NIC) Subnets() []tcpip.Subnet { + n.mu.RLock() + defer n.mu.RUnlock() + sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + for nid := range n.endpoints { + sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) + if err != nil { + // This should never happen as the mask has been carefully crafted to + // match the address. + panic("Invalid endpoint subnet: " + err.Error()) + } + sns = append(sns, sn) + } + return append(sns, n.subnets...) +} + +// 删除一个网络端 +func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { + id := *r.ep.ID() + + // Nothing to do if the reference has already been replaced with a + // different one. 
+ if n.endpoints[id] != r { + return + } + + if r.holdsInsertRef { + panic("Reference count dropped to zero before being removed") + } + + delete(n.endpoints, id) + wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() + if wasInList { + n.primary[r.protocol].Remove(r) + } + + r.ep.Close() +} + +func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { + n.mu.Lock() + n.removeEndpointLocked(r) + n.mu.Unlock() +} + +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + r := n.endpoints[NetworkEndpointID{addr}] + if r == nil || !r.holdsInsertRef { + n.mu.Unlock() + return tcpip.ErrBadLocalAddress + } + + r.holdsInsertRef = false + n.mu.Unlock() + + r.decRef() + + return nil +} + +// DeliverNetworkPacket finds the appropriate network protocol endpoint and +// hands the packet over for further processing. This function is called when +// the NIC receives a packet from the physical interface. +// Note that the ownership of the slice backing vv is retained by the caller. +// This rule applies only to the slice itself, not to the items of the slice; +// the ownership of the items is not retained by the caller. 
+// DeliverNetworkPacket 找到适当的网络协议端点并将数据包交给它进一步处理。 +// 当NIC从物理接口接收数据包时,将调用此函数。比如protocol是arp协议号, 那么会找到arp.HandlePacket来处理数据报。 +// protocol是ipv4协议号, 那么会找到ipv4.HandlePacket来处理数据报。 +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, + protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.stack.stats.UnknownProtocolRcvdPackets.Increment() + return + } + + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { + n.stack.stats.IP.PacketsReceived.Increment() + } + + if len(vv.First()) < netProto.MinimumPacketSize() { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + + src, dst := netProto.ParseAddresses(vv.First()) + + // 根据网络协议和数据包的目的地址,找到网络端 + // 然后将数据包分发给网络层 + if ref := n.getRef(protocol, dst); ref != nil { + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r.RemoteLinkAddress = remoteLinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() + return + } + + // This NIC doesn't care about the packet. Find a NIC that cares about the + // packet and forward it to the NIC. + // + // TODO: Should we be forwarding the packet even if promiscuous? + if n.stack.Forwarding() { + r, err := n.stack.FindRoute(0, "", dst, protocol) + if err != nil { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + return + } + defer r.Release() + + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remoteLinkAddr + + // Found a NIC. + n := r.ref.nic + n.mu.RLock() + ref, ok := n.endpoints[NetworkEndpointID{dst}] + n.mu.RUnlock() + if ok && ref.tryIncRef() { + ref.ep.HandlePacket(&r, vv) + ref.decRef() + } else { + // n doesn't have a destination endpoint. + // Send the packet out of n. 
+ hdr := buffer.NewPrependableFromView(vv.First()) + vv.RemoveFirst() + n.linkEP.WritePacket(&r, hdr, vv, protocol) + } + return + } + + n.stack.stats.IP.InvalidAddressesReceived.Increment() +} + +// getRef returns a referenced endpoint for dst: the registered one if it is still alive, otherwise a temporary one when the NIC is promiscuous or one of its subnets contains dst; nil if neither applies. +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + id := NetworkEndpointID{dst} + + n.mu.RLock() + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + + promiscuous := n.promiscuous + // Check if the packet is for a subnet this NIC cares about. + if !promiscuous { + for _, sn := range n.subnets { + if sn.Contains(dst) { + promiscuous = true + break + } + } + } + n.mu.RUnlock() + if promiscuous { + // Try again with the lock in exclusive mode. If we still can't get the + // endpoint, create a new "temporary" one that lives only while a route uses it. + // holdsInsertRef is guarded by n.mu, so clear it before unlocking; ref is nil on error. + n.mu.Lock() + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.Unlock() + return ref + } + ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true) + if err == nil { + ref.holdsInsertRef = false + } + n.mu.Unlock() + return ref + } + + return nil +} + +// DeliverTransportPacket delivers the packets to the appropriate transport +// protocol endpoint.
+// DeliverTransportPacket 按照传输层协议号和传输层ID,将数据包分发给相应的传输层端 +// 比如 protocol=6表示为tcp协议,将会给相应的tcp端分发消息。 +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { + // 先查找协议栈是否注册了该传输层协议 + state, ok := n.stack.transportProtocols[protocol] + if !ok { + n.stack.stats.UnknownProtocolRcvdPackets.Increment() + return + } + + transProto := state.proto + // 如果报文长度比该协议最小报文长度还小,那么丢弃它 + if len(vv.First()) < transProto.MinimumPacketSize() { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + + // 解析报文得到源端口和目的端口 + srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + if err != nil { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + + id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} + // 调用分流器,根据传输层协议和传输层id分发数据报文 + if n.demux.deliverPacket(r, protocol, vv, id) { + return + } + if n.stack.demux.deliverPacket(r, protocol, vv, id) { + return + } + + // Try to deliver to per-stack default handler. + if state.defaultHandler != nil { + if state.defaultHandler(r, id, vv) { + return + } + } + + // We could not find an appropriate destination for this packet, so + // deliver it to the global handler. + if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + n.stack.stats.MalformedRcvdPackets.Increment() + } +} + +// DeliverTransportControlPacket delivers control packets to the appropriate +// transport protocol endpoint. +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, + trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) { + state, ok := n.stack.transportProtocols[trans] + if !ok { + return + } + + transProto := state.proto + + // ICMPv4 only guarantees that 8 bytes of the transport protocol will + // be present in the payload. We know that the ports are within the + // first 8 bytes for all known transport protocols. 
+ if len(vv.First()) < 8 { + return + } + + srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + if err != nil { + return + } + + id := TransportEndpointID{srcPort, local, dstPort, remote} + if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + return + } + if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + return + } +} + +// ID returns the identifier of n. +func (n *NIC) ID() tcpip.NICID { + return n.id +} + +type referencedNetworkEndpoint struct { + ilist.Entry + refs int32 + ep NetworkEndpoint + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkCache is set if link address resolution is enabled for this + // protocol. Set to nil otherwise. + linkCache LinkAddressCache + + // holdsInsertRef is protected by the NIC's mutex. It indicates whether + // the reference count is biased by 1 due to the insertion of the + // endpoint. It is reset to false when RemoveAddress is called on the + // NIC. + holdsInsertRef bool +} + +// decRef decrements the ref count and cleans up the endpoint once it reaches +// zero. +func (r *referencedNetworkEndpoint) decRef() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpoint(r) + } +} + +// incRef increments the ref count. It must only be called when the caller is +// known to be holding a reference to the endpoint, otherwise tryIncRef should +// be used. +func (r *referencedNetworkEndpoint) incRef() { + atomic.AddInt32(&r.refs, 1) +} + +// tryIncRef attempts to increment the ref count from n to n+1, but only if n is +// not zero. That is, it will increment the count if the endpoint is still +// alive, and do nothing if it has already been clean up. 
+func (r *referencedNetworkEndpoint) tryIncRef() bool { + for { + v := atomic.LoadInt32(&r.refs) + if v == 0 { + return false + } + + if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { + return true + } + } +} diff --git a/netstack/tcpip/stack/registration.go b/netstack/tcpip/stack/registration.go new file mode 100644 index 0000000..6a7275a --- /dev/null +++ b/netstack/tcpip/stack/registration.go @@ -0,0 +1,373 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "sync" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/waiter" +) + +// NetworkEndpointID is the identifier of a network layer protocol endpoint. +// Currently the local address is sufficient because all supported protocols +// (i.e., IPv4 and IPv6) have different sizes for their addresses. +type NetworkEndpointID struct { + LocalAddress tcpip.Address +} + +// TransportEndpointID is the identifier of a transport layer protocol endpoint. +// +// +stateify savable +type TransportEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. 
+ RemoteAddress tcpip.Address +} + +// ControlType is the type of network control message. +type ControlType int + +// The following are the allowed values for ControlType. +const ( + ControlPacketTooBig ControlType = iota + ControlPortUnreachable + ControlUnknown +) + +// TransportEndpoint is the interface that needs to be implemented by transport +// protocol (e.g., tcp, udp) endpoints that can handle packets. +type TransportEndpoint interface { + // HandlePacket is called by the stack when new packets arrive to + // this transport endpoint. + HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) + + // HandleControlPacket is called by the stack when new control (e.g., + // ICMP) packets arrive to this transport endpoint. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) +} + +// TransportProtocol is the interface that needs to be implemented by transport +// protocols (e.g., tcp, udp) that want to be part of the networking stack. +type TransportProtocol interface { + // Number returns the transport protocol number. + Number() tcpip.TransportProtocolNumber + + // NewEndpoint creates a new endpoint of the transport protocol. + NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + + // MinimumPacketSize returns the minimum valid packet size of this + // transport protocol. The stack automatically drops any packets smaller + // than this targeted at this protocol. + MinimumPacketSize() int + + // ParsePorts returns the source and destination ports stored in a + // packet of this protocol. + ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + + // HandleUnknownDestinationPacket handles packets targeted at this + // protocol but that don't match any existing endpoint. For example, + // it is targeted at a port that has no listeners.
+ // + // The return value indicates whether the packet was well-formed (for + // stats purposes only). + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + + // SetOption allows enabling/disabling protocol specific features. + // SetOption returns an error if the option is not supported or the + // provided option value is invalid. + SetOption(option interface{}) *tcpip.Error + + // Option allows retrieving protocol specific option values. + // Option returns an error if the option is not supported or the + // provided option value is invalid. + Option(option interface{}) *tcpip.Error +} + +// TransportDispatcher contains the methods used by the network stack to deliver +// packets to the appropriate transport endpoint after it has been handled by +// the network layer. +type TransportDispatcher interface { + // DeliverTransportPacket delivers packets to the appropriate + // transport protocol endpoint. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) + + // DeliverTransportControlPacket delivers control packets to the + // appropriate transport protocol endpoint. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) +} + +// NetworkEndpoint is the interface that needs to be implemented by endpoints +// of network layer protocols (e.g., ipv4, ipv6). +// NetworkEndpoint是需要由网络层协议(例如,ipv4,ipv6)的端点实现的接口。 +type NetworkEndpoint interface { + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) + // for this endpoint. + DefaultTTL() uint8 + + // MTU is the maximum transmission unit for this endpoint. This is + // generally calculated as the MTU of the underlying data link endpoint + // minus the network endpoint max header length. 
+ MTU() uint32 + + // Capabilities returns the set of capabilities supported by the + // underlying link-layer endpoint. + Capabilities() LinkEndpointCapabilities + + // MaxHeaderLength returns the maximum size the network (and lower + // level layers combined) headers can have. Higher levels use this + // information to reserve space in the front of the packets they're + // building. + MaxHeaderLength() uint16 + + // WritePacket writes a packet to the given destination address and + // protocol. + WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error + + // ID returns the network protocol endpoint ID. + ID() *NetworkEndpointID + + // NICID returns the id of the NIC this endpoint belongs to. + NICID() tcpip.NICID + + // HandlePacket is called by the link layer when new packets arrive to + // this network endpoint. + HandlePacket(r *Route, vv buffer.VectorisedView) + + // Close is called when the endpoint is removed from a stack. + Close() +} + +// NetworkProtocol is the interface that needs to be implemented by network +// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. +type NetworkProtocol interface { + // Number returns the network protocol number. + Number() tcpip.NetworkProtocolNumber + + // MinimumPacketSize returns the minimum valid packet size of this + // network protocol. The stack automatically drops any packets smaller + // than this targeted at this protocol. + MinimumPacketSize() int + + // ParseAddresses returns the source and destination addresses stored in a + // packet of this protocol. + ParseAddresses(v buffer.View) (src, dst tcpip.Address) + + // NewEndpoint creates a new endpoint of this protocol. + NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + + // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the + // provided option value is invalid. + SetOption(option interface{}) *tcpip.Error + + // Option allows retrieving protocol specific option values. + // Option returns an error if the option is not supported or the + // provided option value is invalid. + Option(option interface{}) *tcpip.Error +} + +// NetworkDispatcher contains the methods used by the network stack to deliver +// packets to the appropriate network endpoint after it has been handled by +// the data link layer. +type NetworkDispatcher interface { + // DeliverNetworkPacket finds the appropriate network protocol + // endpoint and hands the packet over for further processing. + DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) +} + +// LinkEndpointCapabilities is the type associated with the capabilities +// supported by a link-layer endpoint. It is a set of bitfields. +type LinkEndpointCapabilities uint + +// The following are the supported link endpoint capabilities. +const ( + CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota + CapabilityResolutionRequired + CapabilitySaveRestore + CapabilityDisconnectOk + CapabilityLoopback +) + +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. +// LinkEndpoint是由数据链路层协议(例如,以太网,环回,原始)实现的接口,并由网络层协议用于通过实施者的数据链路端点发送数据包。 +type LinkEndpoint interface { + // MTU is the maximum transmission unit for this endpoint. This is + // usually dictated by the backing physical network; when such a + // physical network doesn't exist, the limit is generally 64k, which + // includes the maximum size of an IP packet. 
+ // MTU是此端点的最大传输单位。这通常由支持物理网络决定; + // 当这种物理网络不存在时,限制通常为64k,其中包括IP数据包的最大大小。 + MTU() uint32 + + // Capabilities returns the set of capabilities supported by the + // endpoint. + // Capabilities返回链路层端点支持的功能集。 + Capabilities() LinkEndpointCapabilities + + // MaxHeaderLength returns the maximum size the data link (and + // lower level layers combined) headers can have. Higher levels use this + // information to reserve space in the front of the packets they're + // building. + // MaxHeaderLength 返回数据链接(和较低级别的图层组合)标头可以具有的最大大小。 + // 较高级别使用此信息来保留它们正在构建的数据包前面预留空间。 + MaxHeaderLength() uint16 + + // LinkAddress returns the link address (typically a MAC) of the + // link endpoint. + // 本地链路层地址 + LinkAddress() tcpip.LinkAddress + + // WritePacket writes a packet with the given protocol through the given + // route. + // + // To participate in transparent bridging, a LinkEndpoint implementation + // should call eth.Encode with header.EthernetFields.SrcAddr set to + // r.LocalLinkAddress if it is provided. + // WritePacket通过给定的路由写入具有给定协议的数据包。 + // 要参与透明桥接,LinkEndpoint实现应调用eth.Encode,并将header.EthernetFields.SrcAddr设置为r.LocalLinkAddress(如果已提供)。 + WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + + // Attach attaches the data link layer endpoint to the network-layer + // dispatcher of the stack. + // Attach 将数据链路层端点附加到协议栈的网络层调度程序。 + Attach(dispatcher NetworkDispatcher) + + // IsAttached returns whether a NetworkDispatcher is attached to the + // endpoint. + // 是否已经添加了网络层调度器 + IsAttached() bool +} + +// A LinkAddressResolver is an extension to a NetworkProtocol that +// can resolve link addresses. +type LinkAddressResolver interface { + // LinkAddressRequest sends a request for the LinkAddress of addr. + // The request is sent on linkEP with localAddr as the source. + // + // A valid response will cause the discovery protocol's network + // endpoint to call AddLinkAddress. 
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + + // ResolveStaticAddress attempts to resolve address without sending + // requests. It either resolves the name immediately or returns the + // empty LinkAddress. + // + // It can be used to resolve broadcast addresses for example. + ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) + + // LinkAddressProtocol returns the network protocol of the + // addresses that this resolver can resolve. + LinkAddressProtocol() tcpip.NetworkProtocolNumber +} + +// A LinkAddressCache caches link addresses. +type LinkAddressCache interface { + // CheckLocalAddress determines if the given local address exists. NOTE(review): the returned + // NICID is presumably the owning NIC, or 0 when the address is absent -- confirm in the implementation. + CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + + // AddLinkAddress adds a link address to the cache. + AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + + // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). + // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver + // registered with the network protocol, the cache attempts to resolve the address + // and returns ErrWouldBlock. Waker is notified when address resolution is + // complete (success or not). + // + // If address resolution is required, ErrNoLinkAddress and a notification channel is + // returned for the top level caller to block. Channel is closed once address resolution + // is complete (success or not). + GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) + + // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) +} + +// TransportProtocolFactory functions are used by the stack to instantiate +// transport protocols. +type TransportProtocolFactory func() TransportProtocol + +// NetworkProtocolFactory provides methods to be used by the stack to +// instantiate network protocols. +type NetworkProtocolFactory func() NetworkProtocol + +var ( + // 传输层协议的注册储存结构 + transportProtocols = make(map[string]TransportProtocolFactory) + // 网络层协议的注册存储结构 + networkProtocols = make(map[string]NetworkProtocolFactory) + + linkEPMu sync.RWMutex + nextLinkEndpointID tcpip.LinkEndpointID = 1 + linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) +) + +// RegisterTransportProtocolFactory registers a new transport protocol factory +// with the stack so that it becomes available to users of the stack. This +// function is intended to be called by init() functions of the protocols. +// RegisterTransportProtocolFactory 在协议栈中注册一个新的传输协议工厂, +// 以便协议栈可以使用它。此函数在由协议的init函数中使用。 +func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { + transportProtocols[name] = p +} + +// RegisterNetworkProtocolFactory registers a new network protocol factory with +// the stack so that it becomes available to users of the stack. This function +// is intended to be called by init() functions of the protocols. +// RegisterNetworkProtocolFactory 在协议栈中注册一个新的网络协议工厂, +// 以便协议栈可以使用它。此函数在由协议的init函数中使用。 +func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { + networkProtocols[name] = p +} + +// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an +// ID that can be used to refer to it. +// 注册一个链路层设备 +func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { + linkEPMu.Lock() + defer linkEPMu.Unlock() + + v := nextLinkEndpointID + nextLinkEndpointID++ + + linkEndpoints[v] = linkEP + + return v +} + +// FindLinkEndpoint finds the link endpoint associated with the given ID. 
+// 根据ID找到网卡设备 +func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { + linkEPMu.RLock() + defer linkEPMu.RUnlock() + + return linkEndpoints[id] +} diff --git a/netstack/tcpip/stack/route.go b/netstack/tcpip/stack/route.go new file mode 100644 index 0000000..496537c --- /dev/null +++ b/netstack/tcpip/stack/route.go @@ -0,0 +1,186 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" +) + +// Route represents a route through the networking stack to a given destination. +// 贯穿整个协议栈的路由,也就是在链路层和网络层都可以路由 +// 如果目标地址是链路层地址,那么在链路层路由, +// 如果目标地址是网络层地址,那么在网络层路由。 +type Route struct { + // 远端网络层地址,ipv4或者ipv6地址 + RemoteAddress tcpip.Address + + // RemoteLinkAddress is the link-layer (MAC) address of the + // final destination of the route. + // 远端网卡MAC地址 + RemoteLinkAddress tcpip.LinkAddress + + // LocalAddress is the local address where the route starts. + // 本地网络层地址,ipv4或者ipv6地址 + LocalAddress tcpip.Address + + // LocalLinkAddress is the link-layer (MAC) address of the + // where the route starts. + // 本地网卡MAC地址 + LocalLinkAddress tcpip.LinkAddress + + // NextHop is the next node in the path to the destination. + // 下一跳网络层地址 + NextHop tcpip.Address + + // NetProto is the network-layer protocol. 
+ // 网络层协议号 + NetProto tcpip.NetworkProtocolNumber + + // ref a reference to the network endpoint through which the route + // starts. + // 相关的网络终端 + ref *referencedNetworkEndpoint +} + +// makeRoute initializes a new route. It takes ownership of the provided +// reference to a network endpoint. +// 根据参数新建一个路由,并关联一个网络层端 +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, + localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route { + return Route{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: localLinkAddr, + RemoteAddress: remoteAddr, + ref: ref, + } +} + +// NICID returns the id of the NIC from which this route originates. +func (r *Route) NICID() tcpip.NICID { + return r.ref.ep.NICID() +} + +// MaxHeaderLength forwards the call to the network endpoint's implementation. +func (r *Route) MaxHeaderLength() uint16 { + return r.ref.ep.MaxHeaderLength() +} + +// Stats returns a mutable copy of current stats. +func (r *Route) Stats() tcpip.Stats { + return r.ref.nic.stack.Stats() +} + +// PseudoHeaderChecksum forwards the call to the network endpoint's +// implementation. +// udp或tcp伪首部校验和的计算 +func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 { + return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress) +} + +// Capabilities returns the link-layer capabilities of the route. +func (r *Route) Capabilities() LinkEndpointCapabilities { + return r.ref.ep.Capabilities() +} + +// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in +// case address resolution requires blocking, e.g. wait for ARP reply. Waker is +// notified when address resolution is complete (success or not). +// +// If address resolution is required, ErrNoLinkAddress and a notification channel is +// returned for the top level caller to block. Channel is closed once address resolution +// is complete (success or not). 
+// +// Resolve returns (nil, nil) right away when IsResolutionRequired reports false. Otherwise the +// address looked up is NextHop, or RemoteAddress if NextHop is empty; in the latter case a route +// to the local address just copies LocalLinkAddress. On a lookup error the cache's channel and +// error are returned unchanged; on success RemoteLinkAddress is set and (nil, nil) is returned. +func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { + if !r.IsResolutionRequired() { + // Nothing to do if there is no cache (which does the resolution on cache miss) or + // link address is already known. + return nil, nil + } + + nextAddr := r.NextHop + if nextAddr == "" { + // Local link address is already known. + if r.RemoteAddress == r.LocalAddress { + r.RemoteLinkAddress = r.LocalLinkAddress + return nil, nil + } + nextAddr = r.RemoteAddress + } + // Ask the link address cache to resolve nextAddr. + linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + if err != nil { + return ch, err + } + r.RemoteLinkAddress = linkAddr + return nil, nil +} + +// RemoveWaker removes a waker that has been added in Resolve(). +func (r *Route) RemoveWaker(waker *sleep.Waker) { + nextAddr := r.NextHop + if nextAddr == "" { + nextAddr = r.RemoteAddress + } + r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) +} + +// IsResolutionRequired returns true if Resolve() must be called to resolve +// the link address before this route can be written to. +func (r *Route) IsResolutionRequired() bool { + return r.ref.linkCache != nil && r.RemoteLinkAddress == "" +} + +// WritePacket writes the packet through the given route. +func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl) + if err == tcpip.ErrNoRoute { + r.Stats().IP.OutgoingPacketErrors.Increment() + } + return err +} + +// DefaultTTL returns the default TTL of the underlying network endpoint. +func (r *Route) DefaultTTL() uint8 { + return r.ref.ep.DefaultTTL() +} + +// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 { + return r.ref.ep.MTU() +} + +// Release frees all resources associated with the route. +func (r *Route) Release() { + if r.ref != nil { + r.ref.decRef() + r.ref = nil + } +} + +// Clone Clone a route such that the original one can be released and the new +// one will remain valid. +func (r *Route) Clone() Route { + r.ref.incRef() + return *r +} diff --git a/netstack/tcpip/stack/stack.go b/netstack/tcpip/stack/stack.go new file mode 100644 index 0000000..8f92ea6 --- /dev/null +++ b/netstack/tcpip/stack/stack.go @@ -0,0 +1,980 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package stack provides the glue between networking protocols and the +// consumers of the networking stack. +// +// For consumers, the only function of interest is New(), everything else is +// provided by the tcpip/public package. +// +// For protocol implementers, RegisterTransportProtocolFactory() and +// RegisterNetworkProtocolFactory() are used to register protocol factories with +// the stack, which will then be used to instantiate protocol objects when +// consumers interact with the stack. 
package stack

import (
	"sync"
	"time"

	"tcpip/netstack/sleep"
	"tcpip/netstack/tcpip"
	"tcpip/netstack/tcpip/buffer"
	"tcpip/netstack/tcpip/header"
	"tcpip/netstack/tcpip/ports"
	"tcpip/netstack/tcpip/seqnum"
	"tcpip/netstack/waiter"
)

const (
	// ageLimit is set to the same cache stale time used in Linux.
	ageLimit = 1 * time.Minute
	// resolutionTimeout is set to the same ARP timeout used in Linux.
	resolutionTimeout = 1 * time.Second
	// resolutionAttempts is set to the same ARP retries used in Linux.
	resolutionAttempts = 3
)

// transportProtocolState pairs a transport protocol instance with the
// optional per-stack default handler set via SetTransportProtocolHandler.
type transportProtocolState struct {
	proto          TransportProtocol
	defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
}

// TCPProbeFunc is the expected function type for a TCP probe function to be
// passed to stack.AddTCPProbe.
type TCPProbeFunc func(s TCPEndpointState)

// TCPCubicState is used to hold a copy of the internal cubic state when the
// TCPProbeFunc is invoked.
type TCPCubicState struct {
	WLastMax                float64
	WMax                    float64
	T                       time.Time
	TimeSinceLastCongestion time.Duration
	C                       float64
	K                       float64
	Beta                    float64
	WC                      float64
	WEst                    float64
}

// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
type TCPEndpointID struct {
	// LocalPort is the local port associated with the endpoint.
	LocalPort uint16

	// LocalAddress is the local [network layer] address associated with
	// the endpoint.
	LocalAddress tcpip.Address

	// RemotePort is the remote port associated with the endpoint.
	RemotePort uint16

	// RemoteAddress is the remote [network layer] address associated with
	// the endpoint.
	RemoteAddress tcpip.Address
}

// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
// TCP endpoint.
type TCPFastRecoveryState struct {
	// Active if true indicates the endpoint is in fast recovery.
	Active bool

	// First is the first unacknowledged sequence number being recovered.
	First seqnum.Value

	// Last is the 'recover' sequence number that indicates the point at
	// which we should exit recovery barring any timeouts etc.
	Last seqnum.Value

	// MaxCwnd is the maximum value we are permitted to grow the congestion
	// window during recovery. This is set at the time we enter recovery.
	MaxCwnd int
}

// TCPReceiverState holds a copy of the internal state of the receiver for
// a given TCP endpoint.
type TCPReceiverState struct {
	// RcvNxt is the TCP variable RCV.NXT.
	RcvNxt seqnum.Value

	// RcvAcc is the TCP variable RCV.ACC.
	RcvAcc seqnum.Value

	// RcvWndScale is the window scaling to use for inbound segments.
	RcvWndScale uint8

	// PendingBufUsed is the number of bytes pending in the receive
	// queue.
	PendingBufUsed seqnum.Size

	// PendingBufSize is the size of the socket receive buffer.
	PendingBufSize seqnum.Size
}

// TCPSenderState holds a copy of the internal state of the sender for
// a given TCP Endpoint.
type TCPSenderState struct {
	// LastSendTime is the time at which we sent the last segment.
	LastSendTime time.Time

	// DupAckCount is the number of Duplicate ACK's received.
	DupAckCount int

	// SndCwnd is the size of the sending congestion window in packets.
	SndCwnd int

	// Ssthresh is the slow start threshold in packets.
	Ssthresh int

	// SndCAAckCount is the number of packets consumed in congestion
	// avoidance mode.
	SndCAAckCount int

	// Outstanding is the number of packets in flight.
	Outstanding int

	// SndWnd is the send window size in bytes.
	SndWnd seqnum.Size

	// SndUna is the next unacknowledged sequence number.
	SndUna seqnum.Value

	// SndNxt is the sequence number of the next segment to be sent.
	SndNxt seqnum.Value

	// RTTMeasureSeqNum is the sequence number being used for the latest RTT
	// measurement.
	RTTMeasureSeqNum seqnum.Value

	// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
	RTTMeasureTime time.Time

	// Closed indicates that the caller has closed the endpoint for sending.
	Closed bool

	// SRTT is the smoothed round-trip time as defined in section 2 of
	// RFC 6298.
	SRTT time.Duration

	// RTO is the retransmit timeout as defined in section 2 of RFC 6298.
	RTO time.Duration

	// RTTVar is the round-trip time variation as defined in section 2 of
	// RFC 6298.
	RTTVar time.Duration

	// SRTTInited if true indicates that a valid RTT measurement has been
	// completed.
	SRTTInited bool

	// MaxPayloadSize is the maximum size of the payload of a given segment.
	// It is initialized on demand.
	MaxPayloadSize int

	// SndWndScale is the number of bits to shift left when reading the send
	// window size from a segment.
	SndWndScale uint8

	// MaxSentAck is the highest acknowledgement number sent till now.
	MaxSentAck seqnum.Value

	// FastRecovery holds the fast recovery state for the endpoint.
	FastRecovery TCPFastRecoveryState

	// Cubic holds the state related to CUBIC congestion control.
	Cubic TCPCubicState
}

// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
type TCPSACKInfo struct {
	// Blocks is the list of SACK blocks currently received by the
	// TCP endpoint.
	Blocks []header.SACKBlock
}

// TCPEndpointState is a copy of the internal state of a TCP endpoint.
type TCPEndpointState struct {
	// ID is a copy of the TransportEndpointID for the endpoint.
	ID TCPEndpointID

	// SegTime denotes the absolute time when this segment was received.
	SegTime time.Time

	// RcvBufSize is the size of the receive socket buffer for the endpoint.
	RcvBufSize int

	// RcvBufUsed is the amount of bytes actually held in the receive socket
	// buffer for the endpoint.
	RcvBufUsed int

	// RcvClosed if true, indicates the endpoint has been closed for reading.
	RcvClosed bool

	// SendTSOk is used to indicate when the TS Option has been negotiated.
	// When sendTSOk is true every non-RST segment should carry a TS as per
	// RFC7323#section-1.1.
	SendTSOk bool

	// RecentTS is the timestamp that should be sent in the TSEcr field of
	// the timestamp for future segments sent by the endpoint. This field is
	// updated if required when a new segment is received by this endpoint.
	RecentTS uint32

	// TSOffset is a randomized offset added to the value of the TSVal field
	// in the timestamp option.
	TSOffset uint32

	// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
	// option in the SYN/SYN-ACK.
	SACKPermitted bool

	// SACK holds TCP SACK related information for this endpoint.
	SACK TCPSACKInfo

	// SndBufSize is the size of the socket send buffer.
	SndBufSize int

	// SndBufUsed is the number of bytes held in the socket send buffer.
	SndBufUsed int

	// SndClosed indicates that the endpoint has been closed for sends.
	SndClosed bool

	// SndBufInQueue is the number of bytes in the send queue.
	SndBufInQueue seqnum.Size

	// PacketTooBigCount is used to notify the main protocol routine how
	// many times a "packet too big" control packet is received.
	PacketTooBigCount int

	// SndMTU is the smallest MTU seen in the control packets received.
	SndMTU int

	// Receiver holds variables related to the TCP receiver for the endpoint.
	Receiver TCPReceiverState

	// Sender holds state related to the TCP Sender for the endpoint.
	Sender TCPSenderState
}

// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
	// transportProtocols, networkProtocols and linkAddrResolvers are
	// populated by New from the requested protocol names.
	transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
	networkProtocols   map[tcpip.NetworkProtocolNumber]NetworkProtocol
	linkAddrResolvers  map[tcpip.NetworkProtocolNumber]LinkAddressResolver

	// demux is the stack-wide transport demuxer, used for endpoints that
	// are registered without a specific NIC.
	demux *transportDemuxer

	stats tcpip.Stats

	// linkAddrCache is the link address cache shared by all NICs of the
	// stack.
	linkAddrCache *linkAddrCache

	// mu is held by the methods that access nics, forwarding, routeTable
	// and tcpProbeFunc.
	mu         sync.RWMutex
	nics       map[tcpip.NICID]*NIC
	forwarding bool

	// routeTable is the route table passed in by the user via
	// SetRouteTable(), it is used by FindRoute() to build a route for a
	// specific destination.
	routeTable []tcpip.Route

	*ports.PortManager

	// If not nil, then any new endpoints will have this probe function
	// invoked every time they receive a TCP segment.
	tcpProbeFunc TCPProbeFunc

	// clock is used to generate user-visible times.
	clock tcpip.Clock
}

// Options contains optional Stack configuration.
type Options struct {
	// Clock is an optional clock source used for timestamping packets.
	//
	// If no Clock is specified, the clock source will be time.Now.
	Clock tcpip.Clock

	// Stats are optional statistic counters.
	Stats tcpip.Stats
}

// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
// Protocol options can be changed by calling the
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
+// 新建一个协议栈对象 +func New(network []string, transport []string, opts Options) *Stack { + clock := opts.Clock + if clock == nil { + clock = &tcpip.StdClock{} + } + + s := &Stack{ + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), + } + + // Add specified network protocols. + // 添加指定的网络层协议端,如IPV4 + for _, name := range network { + netProtoFactory, ok := networkProtocols[name] + if !ok { + continue + } + netProto := netProtoFactory() + s.networkProtocols[netProto.Number()] = netProto + // 判断该协议是否支持链路层地址解析协议接口,如果支持添加到 s.linkAddrResolvers 中, + // 如:ARP协议会添加 IPV4-ARP 的对应关系 + // 后面需要地址解析协议的时候会更改网络层协议号来找到相应的地址解析协议 + if r, ok := netProto.(LinkAddressResolver); ok { + s.linkAddrResolvers[r.LinkAddressProtocol()] = r + } + } + + // Add specified transport protocols. + // 添加传输层协议 + for _, name := range transport { + transProtoFactory, ok := transportProtocols[name] + if !ok { + continue + } + transProto := transProtoFactory() + s.transportProtocols[transProto.Number()] = &transportProtocolState{ + proto: transProto, + } + } + + // Create the global transport demuxer. + // 新建协议栈全局传输层分流器 + s.demux = newTransportDemuxer(s) + + return s +} + +// SetNetworkProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. 
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.SetOption(option) +} + +// NetworkProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// e.g. +// var v ipv4.MyOption +// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v) +// if err != nil { +// ... +// } +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.Option(option) +} + +// SetTransportProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.SetOption(option) +} + +// TransportProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// var v tcp.SACKEnabled +// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { +// ... 
+// } +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.Option(option) +} + +// SetTransportProtocolHandler sets the per-stack default handler for the given +// protocol. +// +// It must be called only during initialization of the stack. Changing it as the +// stack is operating is not supported. +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { + state := s.transportProtocols[p] + if state != nil { + state.defaultHandler = h + } +} + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (s *Stack) NowNanoseconds() int64 { + return s.clock.NowNanoseconds() +} + +// Stats returns a mutable copy of the current stats. +// +// This is not generally exported via the public interface, but is available +// internally. +func (s *Stack) Stats() tcpip.Stats { + return s.stats +} + +// SetForwarding enables or disables the packet forwarding between NICs. +func (s *Stack) SetForwarding(enable bool) { + // TODO: Expose via /proc/sys/net/ipv4/ip_forward. + s.mu.Lock() + s.forwarding = enable + s.mu.Unlock() +} + +// Forwarding returns if the packet forwarding between NICs is enabled. +func (s *Stack) Forwarding() bool { + // TODO: Expose via /proc/sys/net/ipv4/ip_forward. + s.mu.RLock() + defer s.mu.RUnlock() + return s.forwarding +} + +// SetRouteTable assigns the route table to be used by this stack. It +// specifies which NIC to use for given destination address ranges. +func (s *Stack) SetRouteTable(table []tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + + s.routeTable = table +} + +// GetRouteTable returns the route table which is currently in use. 
+func (s *Stack) GetRouteTable() []tcpip.Route { + s.mu.Lock() + defer s.mu.Unlock() + return append([]tcpip.Route(nil), s.routeTable...) +} + +// NewEndpoint creates a new transport layer endpoint of the given protocol. +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + t, ok := s.transportProtocols[transport] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + return t.proto.NewEndpoint(s, network, waiterQueue) +} + +// createNIC creates a NIC with the provided id and link-layer endpoint, and +// optionally enable it. +// 新建一个网卡对象,并且激活它,激活的意思就是准备好从网卡中读取和写入数据。 +func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { + ep := FindLinkEndpoint(linkEP) + if ep == nil { + return tcpip.ErrBadLinkEndpoint + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Make sure id is unique. + if _, ok := s.nics[id]; ok { + return tcpip.ErrDuplicateNICID + } + + n := newNIC(s, id, name, ep) + + s.nics[id] = n + if enabled { + n.attachLinkEndpoint() + } + + return nil +} + +// CreateNIC creates a NIC with the provided id and link-layer endpoint. +// 根据nic id和linkEP id来创建和注册一个网卡对象 +func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, "", linkEP, true) +} + +// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, +// and a human-readable name. +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, name, linkEP, true) +} + +// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, +// but leave it disable. Stack.EnableNIC must be called before the link-layer +// endpoint starts delivering packets to it. 
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, "", linkEP, false) +} + +// CreateDisabledNamedNIC is a combination of CreateNamedNIC and +// CreateDisabledNIC. +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, name, linkEP, false) +} + +// EnableNIC enables the given NIC so that the link-layer endpoint can start +// delivering packets to it. +func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + nic.attachLinkEndpoint() + + return nil +} + +// NICSubnets returns a map of NICIDs to their associated subnets. +func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := map[tcpip.NICID][]tcpip.Subnet{} + + for id, nic := range s.nics { + nics[id] = append(nics[id], nic.Subnets()...) + } + return nics +} + +// NICInfo captures the name and addresses assigned to a NIC. +type NICInfo struct { + Name string + LinkAddress tcpip.LinkAddress + ProtocolAddresses []tcpip.ProtocolAddress + + // Flags indicate the state of the NIC. + Flags NICStateFlags + + // MTU is the maximum transmission unit. + MTU uint32 +} + +// NICInfo returns a map of NICIDs to their associated information. +func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID]NICInfo) + for id, nic := range s.nics { + flags := NICStateFlags{ + Up: true, // Netstack interfaces are always up. 
+ Running: nic.linkEP.IsAttached(), + Promiscuous: nic.isPromiscuousMode(), + Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0, + } + nics[id] = NICInfo{ + Name: nic.name, + LinkAddress: nic.linkEP.LinkAddress(), + ProtocolAddresses: nic.Addresses(), + Flags: flags, + MTU: nic.linkEP.MTU(), + } + } + return nics +} + +// NICStateFlags holds information about the state of an NIC. +type NICStateFlags struct { + // Up indicates whether the interface is running. + Up bool + + // Running indicates whether resources are allocated. + Running bool + + // Promiscuous indicates whether the interface is in promiscuous mode. + Promiscuous bool + + // Loopback indicates whether the interface is a loopback. + Loopback bool +} + +// AddAddress adds a new network-layer address to the specified NIC. +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint) +} + +// AddAddressWithOptions is the same as AddAddress, but allows you to specify +// whether the new endpoint can be primary or not. +func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.AddAddressWithOptions(protocol, addr, peb) +} + +// AddSubnet adds a subnet range to the specified NIC. +func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + nic.AddSubnet(protocol, subnet) + return nil + } + + return tcpip.ErrUnknownNICID +} + +// RemoveSubnet removes the subnet range from the specified NIC. 
+func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + nic.RemoveSubnet(subnet) + return nil + } + + return tcpip.ErrUnknownNICID +} + +// ContainsSubnet reports whether the specified NIC contains the specified +// subnet. +func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.ContainsSubnet(subnet), nil + } + + return false, tcpip.ErrUnknownNICID +} + +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.RemoveAddress(addr) + } + + return tcpip.ErrUnknownNICID +} + +// GetMainNICAddress returns the first primary address (and the subnet that +// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's +// address if no primary addresses exist. Returns an error if the NIC doesn't +// exist or has no endpoints. +func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.getMainNICAddress(protocol) + } + + return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID +} + +// FindRoute creates a route to the given destination address, leaving through +// the given nic and local address (if provided). 
// This is the stack's route lookup. Per the original author's note, TCP for
// example uses it to obtain routing information when establishing a
// connection.
//
// The first route table entry that matches and whose NIC yields a usable
// endpoint wins. ErrNoRoute is returned when no entry qualifies.
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// Walk the route table in order.
	for i := range s.routeTable {
		// Skip entries for a different NIC (when one was requested) and
		// entries that do not match the destination (when one was given).
		if (id != 0 && id != s.routeTable[i].NIC) || (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
			continue
		}

		nic := s.nics[s.routeTable[i].NIC]
		if nic == nil {
			continue
		}

		// Use the endpoint bound to localAddr when one was given, the
		// NIC's primary endpoint for the protocol otherwise.
		var ref *referencedNetworkEndpoint
		if len(localAddr) != 0 {
			ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
		} else {
			ref = nic.primaryEndpoint(netProto)
		}
		if ref == nil {
			continue
		}

		if len(remoteAddr) == 0 {
			// If no remote address was provided, then the route
			// provided will refer to the link local address.
			remoteAddr = ref.ep.ID().LocalAddress
		}

		// The route takes ownership of ref (see makeRoute).
		r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref)
		r.NextHop = s.routeTable[i].Gateway
		return r, nil
	}

	return Route{}, tcpip.ErrNoRoute
}

// CheckNetworkProtocol checks if a given network protocol is enabled in the
// stack.
func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
	_, ok := s.networkProtocols[protocol]
	return ok
}

// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// If a NIC is specified, we try to find the address there only.
	if nicid != 0 {
		nic := s.nics[nicid]
		if nic == nil {
			return 0
		}

		ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
		if ref == nil {
			return 0
		}

		// Only existence matters here, so the reference obtained from
		// findEndpoint is dropped right away.
		ref.decRef()

		return nic.id
	}

	// Go through all the NICs.
	for _, nic := range s.nics {
		ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
		if ref != nil {
			ref.decRef()
			return nic.id
		}
	}

	return 0
}

// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
	s.mu.RLock()
	defer s.mu.RUnlock()

	nic := s.nics[nicID]
	if nic == nil {
		return tcpip.ErrUnknownNICID
	}

	nic.setPromiscuousMode(enable)

	return nil
}

// SetSpoofing enables or disables address spoofing in the given NIC, allowing
// endpoints to bind to any address in the NIC.
func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
	s.mu.RLock()
	defer s.mu.RUnlock()

	nic := s.nics[nicID]
	if nic == nil {
		return tcpip.ErrUnknownNICID
	}

	nic.setSpoofing(enable)

	return nil
}

// AddLinkAddress adds a link address to the stack link cache.
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
	fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
	s.linkAddrCache.add(fullAddr, linkAddr)
	// TODO: provide a way for a transport endpoint to receive a signal
	// that AddLinkAddress for a particular address has been called.
}

// GetLinkAddress implements LinkAddressCache.GetLinkAddress. It looks up the
// link-layer address for addr on the given NIC and returns ErrUnknownNICID
// when the NIC does not exist.
func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
	s.mu.RLock()
	// Look up the NIC object.
	nic := s.nics[nicid]
	if nic == nil {
		s.mu.RUnlock()
		return "", nil, tcpip.ErrUnknownNICID
	}
	s.mu.RUnlock()

	fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
	// Find the link address resolver registered for this network protocol
	// number (per the original author's note, IPv4 maps to ARP).
	linkRes := s.linkAddrResolvers[protocol]
	return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}

// RemoveWaker implements LinkAddressCache.RemoveWaker.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic := s.nics[nicid]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + s.linkAddrCache.removeWaker(fullAddr, waker) + } +} + +// RegisterTransportEndpoint registers the given endpoint with the stack +// transport dispatcher. Received packets that match the provided id will be +// delivered to the given endpoint; specifying a nic is optional, but +// nic-specific IDs have precedence over global ones. +// RegisterTransportEndpoint 协议栈或者NIC的分流器注册给定传输层端点。 +// 收到的与提供的id匹配的数据包将被传送到给定的端点;指定nic是可选的,但特定于nic的ID优先于全局ID。 +// 最终调用 demuxer.registerEndpoint 函数来实现注册。 +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + if nicID == 0 { + return s.demux.registerEndpoint(netProtos, protocol, id, ep) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.demux.registerEndpoint(netProtos, protocol, id, ep) +} + +// UnregisterTransportEndpoint removes the endpoint with the given id from the +// stack transport dispatcher. +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + if nicID == 0 { + s.demux.unregisterEndpoint(netProtos, protocol, id) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterEndpoint(netProtos, protocol, id) + } +} + +// NetworkProtocolInstance returns the protocol instance in the stack for the +// specified network protocol. This method is public for protocol implementers +// and tests to use. 
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol { + if p, ok := s.networkProtocols[num]; ok { + return p + } + return nil +} + +// TransportProtocolInstance returns the protocol instance in the stack for the +// specified transport protocol. This method is public for protocol implementers +// and tests to use. +func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol { + if pState, ok := s.transportProtocols[num]; ok { + return pState.proto + } + return nil +} + +// AddTCPProbe installs a probe function that will be invoked on every segment +// received by a given TCP endpoint. The probe function is passed a copy of the +// TCP endpoint state. +// +// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints +// created prior to this call will not call the probe function. +// +// Further, installing two different probes back to back can result in some +// endpoints calling the first one and some the second one. There is no +// guarantee provided on which probe will be invoked. Ideally this should only +// be called once per stack. +func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { + s.mu.Lock() + s.tcpProbeFunc = probe + s.mu.Unlock() +} + +// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil +// otherwise. +func (s *Stack) GetTCPProbe() TCPProbeFunc { + s.mu.Lock() + p := s.tcpProbeFunc + s.mu.Unlock() + return p +} + +// RemoveTCPProbe removes an installed TCP probe. +// +// NOTE: This only ensures that endpoints created after this call do not +// have a probe attached. Endpoints already created will continue to invoke +// TCP probe. +func (s *Stack) RemoveTCPProbe() { + s.mu.Lock() + s.tcpProbeFunc = nil + s.mu.Unlock() +} + +// JoinGroup joins the given multicast group on the given NIC. 
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
	// TODO: notify network of subscription via igmp protocol.
	// Joining is implemented by adding the multicast address to the NIC with
	// the NeverPrimaryEndpoint option.
	return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
}

// LeaveGroup leaves the given multicast group on the given NIC.
//
// The protocol argument is not consulted: the group address is removed from
// the NIC through RemoveAddress, whose error (if any) is returned unchanged.
func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
	return s.RemoveAddress(nicID, multicastAddr)
}
diff --git a/netstack/tcpip/stack/stack_test.go b/netstack/tcpip/stack/stack_test.go
new file mode 100644
index 0000000..182455e
--- /dev/null
+++ b/netstack/tcpip/stack/stack_test.go
@@ -0,0 +1,849 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package stack_test contains tests for the stack. It is in its own package so
// that the tests can also validate that all definitions needed to implement
// transport and network protocols are properly exported by the stack package.
package stack_test

import (
	"math"
	"strings"
	"testing"

	"tcpip/netstack/tcpip"
	"tcpip/netstack/tcpip/buffer"
	"tcpip/netstack/tcpip/link/channel"
	"tcpip/netstack/tcpip/stack"
)

const (
	// fakeNetNumber is the protocol number the fake network protocol
	// reports from Number().
	fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
	// fakeNetHeaderLen is the size, in bytes, of the fake network header.
	fakeNetHeaderLen = 12

	// fakeControlProtocol is used for control packets that represent
	// destination port unreachable.
	fakeControlProtocol tcpip.TransportProtocolNumber = 2

	// defaultMTU is the MTU, in bytes, used throughout the tests, except
	// where another value is explicitly used. It is chosen to match the MTU
	// of loopback interfaces on linux systems.
	defaultMTU = 65536
)

// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
//
// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
	nicid      tcpip.NICID
	id         stack.NetworkEndpointID
	proto      *fakeNetworkProtocol
	dispatcher stack.TransportDispatcher
	linkEP     stack.LinkEndpoint
}

// MTU returns the link endpoint's MTU reduced by this endpoint's
// MaxHeaderLength.
func (f *fakeNetworkEndpoint) MTU() uint32 {
	return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
}

// NICID returns the id of the NIC this endpoint was created for.
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
	return f.nicid
}

// DefaultTTL returns a fixed TTL of 123.
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
	return 123
}

// ID returns a pointer to the endpoint's network endpoint id.
func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
	return &f.id
}

// HandlePacket counts the incoming packet against the endpoint's local
// address, strips the fake network header and hands the remainder to the
// transport dispatcher, either as a control packet (when the header's protocol
// byte is fakeControlProtocol) or as a regular transport packet.
func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
	// Increment the received packet count in the protocol descriptor.
	f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++

	// Consume the network header.
	// b is captured before the trim so the header bytes stay readable below.
	b := vv.First()
	vv.TrimFront(fakeNetHeaderLen)

	// Handle control packets.
	if b[2] == uint8(fakeControlProtocol) {
		// A control packet carries a second fake header describing the
		// packet it refers to; drop the packet if that header is truncated.
		nb := vv.First()
		if len(nb) < fakeNetHeaderLen {
			return
		}

		vv.TrimFront(fakeNetHeaderLen)
		f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
		return
	}

	// Dispatch the packet to the transport protocol.
	f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
}

// MaxHeaderLength returns the link endpoint's maximum header length plus
// fakeNetHeaderLen.
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
	return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
}

// PseudoHeaderChecksum always returns 0.
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
	return 0
}

// Capabilities returns the capabilities of the underlying link endpoint.
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
	return f.linkEP.Capabilities()
}

// WritePacket counts the outgoing packet against the route's remote address,
// prepends the fake network header to hdr and writes the packet to the link
// endpoint. The final (unnamed) uint8 argument is ignored.
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
	// Increment the sent packet count in the protocol descriptor.
	f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++

	// Add the protocol's header to the packet and send it to the link
	// endpoint.
	// Header layout: byte 0 destination, byte 1 source, byte 2 transport
	// protocol.
	b := hdr.Prepend(fakeNetHeaderLen)
	b[0] = r.RemoteAddress[0]
	b[1] = f.id.LocalAddress[0]
	b[2] = byte(protocol)
	return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
}

// Close is a no-op.
func (*fakeNetworkEndpoint) Close() {}

// fakeNetGoodOption is an option type that fakeNetworkProtocol.SetOption
// accepts and stores.
type fakeNetGoodOption bool

// fakeNetBadOption is an option type that fakeNetworkProtocol does not
// recognise.
type fakeNetBadOption bool

// fakeNetInvalidValueOption is an option type that fakeNetworkProtocol.SetOption
// rejects with tcpip.ErrInvalidOptionValue.
type fakeNetInvalidValueOption int

// fakeNetOptions holds the option state of a fakeNetworkProtocol.
type fakeNetOptions struct {
	good bool
}

// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
+type fakeNetworkProtocol struct { + packetCount [10]int + sendPacketCount [10]int + opts fakeNetOptions +} + +func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fakeNetNumber +} + +func (f *fakeNetworkProtocol) MinimumPacketSize() int { + return fakeNetHeaderLen +} + +func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) +} + +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return &fakeNetworkEndpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{addr}, + proto: f, + dispatcher: dispatcher, + linkEP: linkEP, + }, nil +} + +func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case fakeNetGoodOption: + f.opts.good = bool(v) + return nil + case fakeNetInvalidValueOption: + return tcpip.ErrInvalidOptionValue + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *fakeNetGoodOption: + *v = fakeNetGoodOption(f.opts.good) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func TestNetworkReceive(t *testing.T) { + // Create a stack with the fake network protocol, one nic, and two + // addresses attached to it: 1 & 2. 
+ id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + // Make sure packet with wrong address is not delivered. + buf[0] = 3 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } + if fakeNet.packetCount[2] != 0 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) + } + + // Make sure packet is delivered to first endpoint. + buf[0] = 1 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 0 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) + } + + // Make sure packet is delivered to second endpoint. + buf[0] = 2 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } + + // Make sure packet is not delivered if protocol number is wrong. + linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } + + // Make sure packet that is too small is dropped. 
+ buf.CapLength(2) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } +} + +func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) { + r, err := s.FindRoute(0, "", addr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + defer r.Release() + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) + if err := r.WritePacket(hdr, buffer.VectorisedView{}, fakeTransNumber, 123); err != nil { + t.Errorf("WritePacket failed: %v", err) + return + } +} + +func TestNetworkSend(t *testing.T) { + // Create a stack with the fake network protocol, one nic, and one + // address: 1. The route table sends all packets through the only + // existing nic. + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("NewNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Make sure that the link-layer endpoint received the outbound packet. + sendTo(t, s, "\x03") + if c := linkEP.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } +} + +func TestNetworkSendMultiRoute(t *testing.T) { + // Create a stack with the fake network protocol, two nics, and two + // addresses per nic, the first nic has odd address, the second one has + // even addresses. 
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id1, linkEP1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + id2, linkEP2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + s.SetRouteTable([]tcpip.Route{ + {"\x01", "\x01", "\x00", 1}, + {"\x00", "\x01", "\x00", 2}, + }) + + // Send a packet to an odd destination. + sendTo(t, s, "\x05") + + if c := linkEP1.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } + + // Send a packet to an even destination. 
+ sendTo(t, s, "\x06") + + if c := linkEP2.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } +} + +func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { + r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + + defer r.Release() + + if r.LocalAddress != expectedSrcAddr { + t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) + } + + if r.RemoteAddress != dstAddr { + t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) + } +} + +func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { + _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + if err != tcpip.ErrNoRoute { + t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) + } +} + +func TestRoutes(t *testing.T) { + // Create a stack with the fake network protocol, two nics, and two + // addresses per nic, the first nic has odd address, the second one has + // even addresses. 
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id1, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + id2, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + s.SetRouteTable([]tcpip.Route{ + {"\x01", "\x01", "\x00", 1}, + {"\x00", "\x01", "\x00", 2}, + }) + + // Test routes to odd address. + testRoute(t, s, 0, "", "\x05", "\x01") + testRoute(t, s, 0, "\x01", "\x05", "\x01") + testRoute(t, s, 1, "\x01", "\x05", "\x01") + testRoute(t, s, 0, "\x03", "\x05", "\x03") + testRoute(t, s, 1, "\x03", "\x05", "\x03") + + // Test routes to even address. + testRoute(t, s, 0, "", "\x06", "\x02") + testRoute(t, s, 0, "\x02", "\x06", "\x02") + testRoute(t, s, 2, "\x02", "\x06", "\x02") + testRoute(t, s, 0, "\x04", "\x06", "\x04") + testRoute(t, s, 2, "\x04", "\x06", "\x04") + + // Try to send to odd numbered address from even numbered ones, then + // vice-versa. 
+ testNoRoute(t, s, 0, "\x02", "\x05") + testNoRoute(t, s, 2, "\x02", "\x05") + testNoRoute(t, s, 0, "\x04", "\x05") + testNoRoute(t, s, 2, "\x04", "\x05") + + testNoRoute(t, s, 0, "\x01", "\x06") + testNoRoute(t, s, 1, "\x01", "\x06") + testNoRoute(t, s, 0, "\x03", "\x06") + testNoRoute(t, s, 1, "\x03", "\x06") +} + +func TestAddressRemoval(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + // Write a packet, and check that it gets delivered. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Remove the address, then check that packet doesn't get delivered + // anymore. + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Check that removing the same address fails. 
+ if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + t.Fatalf("RemoveAddress failed: %v", err) + } +} + +func TestDelayedRemovalDueToRoute(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + // Write a packet, and check that it gets delivered. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Get a route, check that packet is still deliverable. + r, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 2 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) + } + + // Remove the address, then check that packet is still deliverable + // because the route is keeping the address alive. + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 3 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + } + + // Check that removing the same address fails. + if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + t.Fatalf("RemoveAddress failed: %v", err) + } + + // Release the route, then check that packet is not deliverable anymore. 
+ r.Release() + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 3 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + } +} + +func TestPromiscuousMode(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + // Write a packet, and check that it doesn't get delivered as we don't + // have a matching endpoint. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } + + // Set promiscuous mode, then check that packet is delivered. + if err := s.SetPromiscuousMode(1, true); err != nil { + t.Fatalf("SetPromiscuousMode failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Check that we can't get a route as there is no local address. + _, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + if err != tcpip.ErrNoRoute { + t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) + } + + // Set promiscuous mode to false, then check that packet can't be + // delivered anymore. 
+ if err := s.SetPromiscuousMode(1, false); err != nil { + t.Fatalf("SetPromiscuousMode failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } +} + +func TestAddressSpoofing(t *testing.T) { + srcAddr := tcpip.Address("\x01") + dstAddr := tcpip.Address("\x02") + + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatalf("SetSpoofing failed: %v", err) + } + r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + if r.LocalAddress != srcAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } +} + +// Set the subnet, then check that packet is delivered. 
+func TestSubnetAcceptsMatchingPacket(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + buf[0] = 1 + fakeNet.packetCount[1] = 0 + subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } +} + +// Set destination outside the subnet, then check it doesn't get delivered. +func TestSubnetRejectsNonmatchingPacket(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + buf := buffer.NewView(30) + + buf[0] = 1 + fakeNet.packetCount[1] = 0 + subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } +} + +func TestNetworkOptions(t *testing.T) { + s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + + // Try 
an unsupported network protocol. + if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { + t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + } + + testCases := []struct { + option interface{} + wantErr *tcpip.Error + verifier func(t *testing.T, p stack.NetworkProtocol) + }{ + {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { + t.Helper() + fakeNet := p.(*fakeNetworkProtocol) + if fakeNet.opts.good != true { + t.Fatalf("fakeNet.opts.good = false, want = true") + } + var v fakeNetGoodOption + if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { + t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) + } + if v != true { + t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) + } + }}, + {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, + {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + } + for _, tc := range testCases { + if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { + t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) + } + if tc.verifier != nil { + tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) + } + } +} + +func TestSubnetAddRemove(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + addr := tcpip.Address("\x01\x01\x01\x01") + mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) + subnet, err := tcpip.NewSubnet(addr, mask) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if contained { + 
t.Fatal("got s.ContainsSubnet(...) = true, want = false") + } + + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if !contained { + t.Fatal("got s.ContainsSubnet(...) = false, want = true") + } + + if err := s.RemoveSubnet(1, subnet); err != nil { + t.Fatalf("RemoveSubnet failed: %v", err) + } + + if contained, err := s.ContainsSubnet(1, subnet); err != nil { + t.Fatalf("ContainsSubnet failed: %v", err) + } else if contained { + t.Fatal("got s.ContainsSubnet(...) = true, want = false") + } +} + +func TestGetMainNICAddress(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + for _, tc := range []struct { + name string + address tcpip.Address + }{ + {"IPv4", "\x01\x01\x01\x01"}, + {"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"}, + } { + t.Run(tc.name, func(t *testing.T) { + address := tc.address + mask := tcpip.AddressMask(strings.Repeat("\xff", len(address))) + subnet, err := tcpip.NewSubnet(address, mask) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, address); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + // Check that we get the right initial address and subnet. 
+ if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + t.Fatalf("GetMainNICAddress failed: %v", err) + } else if gotAddress != address { + t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address) + } else if gotSubnet != subnet { + t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet) + } + + if err := s.RemoveSubnet(1, subnet); err != nil { + t.Fatalf("RemoveSubnet failed: %v", err) + } + + if err := s.RemoveAddress(1, address); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + // Check that we get an error after removal. + if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress) + } + }) + } +} + +func init() { + stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { + return &fakeNetworkProtocol{} + }) +} diff --git a/netstack/tcpip/stack/transport_demuxer.go b/netstack/tcpip/stack/transport_demuxer.go new file mode 100644 index 0000000..63bc0e5 --- /dev/null +++ b/netstack/tcpip/stack/transport_demuxer.go @@ -0,0 +1,192 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package stack + +import ( + "sync" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" +) + +// 网络层协议号和传输层协议号的组合,当作分流器的key值 +type protocolIDs struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber +} + +// transportEndpoints manages all endpoints of a given protocol. It has its own +// mutex so as to reduce interference between protocols. +// transportEndpoints 管理给定协议的所有端点。 +type transportEndpoints struct { + mu sync.RWMutex + endpoints map[TransportEndpointID]TransportEndpoint +} + +// transportDemuxer demultiplexes packets targeted at a transport endpoint +// (i.e., after they've been parsed by the network layer). It does two levels +// of demultiplexing: first based on the network and transport protocols, then +// based on endpoints IDs. +// transportDemuxer 解复用针对传输端点的数据包(即,在它们被网络层解析之后)。 +// 它执行两级解复用:首先基于网络层协议和传输协议,然后基于端点ID。 +type transportDemuxer struct { + protocol map[protocolIDs]*transportEndpoints +} + +// 新建一个分流器 +func newTransportDemuxer(stack *Stack) *transportDemuxer { + d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + + // Add each network and transport pair to the demuxer. + for netProto := range stack.networkProtocols { + for proto := range stack.transportProtocols { + d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} + } + } + + return d +} + +// registerEndpoint registers the given endpoint with the dispatcher such that +// packets that match the endpoint ID are delivered to it. 
+// registerEndpoint 向分发器注册给定端点,以便将与端点ID匹配的数据包传递给它。 +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id) + return err + } + } + + return nil +} + +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + + if _, ok := eps.endpoints[id]; ok { + return tcpip.ErrPortInUse + } + + eps.endpoints[id] = ep + + return nil +} + +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +// unregisterEndpoint 使用给定的id注销端点,使其不再接收任何数据包。 +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + delete(eps.endpoints, id) + eps.mu.Unlock() + } + } +} + +// deliverPacket attempts to deliver the given packet. Returns true if it found +// an endpoint, false otherwise. +// 根据传输层的id来找到对应的传输端,再将数据包交给这个传输端处理 +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { + // 先看看分流器里有没有注册相关协议端,如果没有则返回false + eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] + if !ok { + return false + } + + // 从 eps 中找符合 id 的传输端 + eps.mu.RLock() + ep := d.findEndpointLocked(eps, vv, id) + eps.mu.RUnlock() + + // Fail if we didn't find one. 
+ if ep == nil { + // UDP packet could not be delivered to an unknown destination port. + if protocol == header.UDPProtocolNumber { + r.Stats().UDP.UnknownPortErrors.Increment() + } + return false + } + + // Deliver the packet. + // 找到传输端到,让它来处理数据包 + ep.HandlePacket(r, id, vv) + + return true +} + +// deliverControlPacket attempts to deliver the given control packet. Returns +// true if it found an endpoint, false otherwise. +func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{net, trans}] + if !ok { + return false + } + + // Try to find the endpoint. + eps.mu.RLock() + ep := d.findEndpointLocked(eps, vv, id) + eps.mu.RUnlock() + + // Fail if we didn't find one. + if ep == nil { + return false + } + + // Deliver the packet. + ep.HandleControlPacket(id, typ, extra, vv) + + return true +} + +// 根据传输层id来找到相应的传输层端 +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { + // Try to find a match with the id as provided. + // 从 endpoints 中看有没有匹配到id的传输层端 + if ep := eps.endpoints[id]; ep != nil { + return ep + } + + // Try to find a match with the id minus the local address. + nid := id + // 如果上面的 endpoints 没有找到,那么去掉本地ip地址,看看有没有相应的传输层端 + // 因为有时候传输层监听的时候没有绑定本地ip,也就是 any address,此时的 LocalAddress + // 为空。 + nid.LocalAddress = "" + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with only the local port. 
+ nid.LocalAddress = "" + return eps.endpoints[nid] +} diff --git a/netstack/tcpip/stack/transport_test.go b/netstack/tcpip/stack/transport_test.go new file mode 100644 index 0000000..0abb119 --- /dev/null +++ b/netstack/tcpip/stack/transport_test.go @@ -0,0 +1,422 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "testing" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/link/channel" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +const ( + fakeTransNumber tcpip.TransportProtocolNumber = 1 + fakeTransHeaderLen = 3 +) + +// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts +// received packets; the counts of all endpoints are aggregated in the protocol +// descriptor. +// +// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't +// use it. 
+type fakeTransportEndpoint struct { + id stack.TransportEndpointID + stack *stack.Stack + netProto tcpip.NetworkProtocolNumber + proto *fakeTransportProtocol + peerAddr tcpip.Address + route stack.Route +} + +func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +} + +func (f *fakeTransportEndpoint) Close() { + f.route.Release() +} + +func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + return mask +} + +func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return buffer.View{}, tcpip.ControlMessages{}, nil +} + +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + if len(f.route.RemoteAddress) == 0 { + return 0, nil, tcpip.ErrNoRoute + } + + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil { + return 0, nil, err + } + + return uintptr(len(v)), nil, nil +} + +func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return nil + } + return tcpip.ErrInvalidEndpointState +} + +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + f.peerAddr = addr.Addr + + // Find the route. 
+ r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber) + if err != nil { + return tcpip.ErrNoRoute + } + defer r.Release() + + // Try to register so that we can start receiving packets. + f.id.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f) + if err != nil { + return err + } + + f.route = r.Clone() + + return nil +} + +func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Reset() { +} + +func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, nil +} + +func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + return commit() +} + +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, nil +} + +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, nil +} + +func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) { + // Increment the number of received packets. + f.proto.packetCount++ +} + +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { + // Increment the number of received control packets. + f.proto.controlCount++ +} + +type fakeTransportGoodOption bool + +type fakeTransportBadOption bool + +type fakeTransportInvalidValueOption int + +type fakeTransportProtocolOptions struct { + good bool +} + +// fakeTransportProtocol is a transport-layer protocol descriptor. It +// aggregates the number of packets received via endpoints of this protocol. 
+type fakeTransportProtocol struct { + packetCount int + controlCount int + opts fakeTransportProtocolOptions +} + +func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { + return fakeTransNumber +} + +func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(stack, f, netProto), nil +} + +func (*fakeTransportProtocol) MinimumPacketSize() int { + return fakeTransHeaderLen +} + +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { + return 0, 0, nil +} + +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { + return true +} + +func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case fakeTransportGoodOption: + f.opts.good = bool(v) + return nil + case fakeTransportInvalidValueOption: + return tcpip.ErrInvalidOptionValue + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *fakeTransportGoodOption: + *v = fakeTransportGoodOption(f.opts.good) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func TestTransportReceive(t *testing.T) { + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Create endpoint and connect to remote address. 
+ wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) + + // Create buffer that will hold the packet. + buf := buffer.NewView(30) + + // Make sure packet with wrong protocol is not delivered. + buf[0] = 1 + buf[2] = 0 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.packetCount != 0 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) + } + + // Make sure packet from the wrong source is not delivered. + buf[0] = 1 + buf[1] = 3 + buf[2] = byte(fakeTransNumber) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.packetCount != 0 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) + } + + // Make sure packet is delivered. + buf[0] = 1 + buf[1] = 2 + buf[2] = byte(fakeTransNumber) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.packetCount != 1 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) + } +} + +func TestTransportControlReceive(t *testing.T) { + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Create endpoint and connect to remote address. 
+ wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) + + // Create buffer that will hold the control packet. + buf := buffer.NewView(2*fakeNetHeaderLen + 30) + + // Outer packet contains the control protocol number. + buf[0] = 1 + buf[1] = 0xfe + buf[2] = uint8(fakeControlProtocol) + + // Make sure packet with wrong protocol is not delivered. + buf[fakeNetHeaderLen+0] = 0 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = 0 + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.controlCount != 0 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) + } + + // Make sure packet from the wrong source is not delivered. + buf[fakeNetHeaderLen+0] = 3 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.controlCount != 0 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) + } + + // Make sure packet is delivered. 
+ buf[fakeNetHeaderLen+0] = 2 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) + linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + if fakeTrans.controlCount != 1 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) + } +} + +func TestTransportSend(t *testing.T) { + id, _ := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + // Create endpoint and bind it. + wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + // Create buffer that will hold the payload. + view := buffer.NewView(30) + _, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("write failed: %v", err) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + if fakeNet.sendPacketCount[2] != 1 { + t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) + } +} + +func TestTransportOptions(t *testing.T) { + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + + // Try an unsupported transport protocol. 
+ if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { + t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + } + + testCases := []struct { + option interface{} + wantErr *tcpip.Error + verifier func(t *testing.T, p stack.TransportProtocol) + }{ + {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { + t.Helper() + fakeTrans := p.(*fakeTransportProtocol) + if fakeTrans.opts.good != true { + t.Fatalf("fakeTrans.opts.good = false, want = true") + } + var v fakeTransportGoodOption + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) + } + if v != true { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) + } + + }}, + {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, + {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + } + for _, tc := range testCases { + if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) + } + if tc.verifier != nil { + tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) + } + } +} + +func init() { + stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { + return &fakeTransportProtocol{} + }) +} diff --git a/netstack/tcpip/tcpip.go b/netstack/tcpip/tcpip.go new file mode 100644 index 0000000..532ed93 --- /dev/null +++ b/netstack/tcpip/tcpip.go @@ -0,0 +1,834 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcpip provides the interfaces and related types that users of the +// tcpip stack will use in order to create endpoints used to send and receive +// data over the network stack. +// +// The starting point is the creation and configuration of a stack. A stack can +// be created by calling the New() function of the tcpip/stack/stack package; +// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()), +// adding network addresses (via calls to Stack.AddAddress()), and +// setting a route table (via a call to Stack.SetRouteTable()). +// +// Once a stack is configured, endpoints can be created by calling +// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect +// to peers, listen for connections, accept connections, etc., depending on the +// transport protocol selected. +package tcpip + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/waiter" +) + +// Error represents an error in the netstack error space. Using a special type +// ensures that errors outside of this space are not accidentally introduced. +// +// Note: to support save / restore, it is important that all tcpip errors have +// distinct error messages. +type Error struct { + msg string + + ignoreStats bool +} + +// String implements fmt.Stringer.String. +func (e *Error) String() string { + return e.msg +} + +// IgnoreStats indicates whether this error type should be included in failure +// counts in tcpip.Stats structs. 
+func (e *Error) IgnoreStats() bool { + return e.ignoreStats +} + +// Errors that can be returned by the network stack. +var ( + ErrUnknownProtocol = &Error{msg: "unknown protocol"} + ErrUnknownNICID = &Error{msg: "unknown nic id"} + ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"} + ErrDuplicateNICID = &Error{msg: "duplicate nic id"} + ErrDuplicateAddress = &Error{msg: "duplicate address"} + ErrNoRoute = &Error{msg: "no route"} + ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"} + ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true} + ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"} + ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true} + ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true} + ErrNoPortAvailable = &Error{msg: "no ports are available"} + ErrPortInUse = &Error{msg: "port is in use"} + ErrBadLocalAddress = &Error{msg: "bad local address"} + ErrClosedForSend = &Error{msg: "endpoint is closed for send"} + ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"} + ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true} + ErrConnectionRefused = &Error{msg: "connection was refused"} + ErrTimeout = &Error{msg: "operation timed out"} + ErrAborted = &Error{msg: "operation aborted"} + ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true} + ErrDestinationRequired = &Error{msg: "destination address is required"} + ErrNotSupported = &Error{msg: "operation not supported"} + ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"} + ErrNotConnected = &Error{msg: "endpoint not connected"} + ErrConnectionReset = &Error{msg: "connection reset by peer"} + ErrConnectionAborted = &Error{msg: "connection aborted"} + ErrNoSuchFile = &Error{msg: "no such file"} + ErrInvalidOptionValue = &Error{msg: "invalid option value specified"} + ErrNoLinkAddress = 
&Error{msg: "no remote link address"} + ErrBadAddress = &Error{msg: "bad address"} + ErrNetworkUnreachable = &Error{msg: "network is unreachable"} + ErrMessageTooLong = &Error{msg: "message too long"} + ErrNoBufferSpace = &Error{msg: "no buffer space available"} +) + +// Errors related to Subnet +var ( + errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") + errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") +) + +// ErrSaveRejection indicates a failed save due to unsupported networking state. +// This type of errors is only used for save logic. +type ErrSaveRejection struct { + Err error +} + +// Error returns a sensible description of the save rejection error. +func (e ErrSaveRejection) Error() string { + return "save rejected due to unsupported networking state: " + e.Err.Error() +} + +// A Clock provides the current time. +// +// Times returned by a Clock should always be used for application-visible +// time, but never for netstack internal timekeeping. +type Clock interface { + // NowNanoseconds returns the current real time as a number of + // nanoseconds since the Unix epoch. + NowNanoseconds() int64 + + // NowMonotonic returns a monotonic time value. + NowMonotonic() int64 +} + +// Address is a byte slice cast as a string that represents the address of a +// network node. Or, in the case of unix endpoints, it may represent a path. +type Address string + +// AddressMask is a bitmask for an address. +type AddressMask string + +// String implements Stringer. +func (a AddressMask) String() string { + return Address(a).String() +} + +// Subnet is a subnet defined by its address and mask. +type Subnet struct { + address Address + mask AddressMask +} + +// NewSubnet creates a new Subnet, checking that the address and mask are the same length. 
+func NewSubnet(a Address, m AddressMask) (Subnet, error) { + if len(a) != len(m) { + return Subnet{}, errSubnetLengthMismatch + } + for i := 0; i < len(a); i++ { + if a[i]&^m[i] != 0 { + return Subnet{}, errSubnetAddressMasked + } + } + return Subnet{a, m}, nil +} + +// Contains returns true iff the address is of the same length and matches the +// subnet address and mask. +func (s *Subnet) Contains(a Address) bool { + if len(a) != len(s.address) { + return false + } + for i := 0; i < len(a); i++ { + if a[i]&s.mask[i] != s.address[i] { + return false + } + } + return true +} + +// ID returns the subnet ID. +func (s *Subnet) ID() Address { + return s.address +} + +// Bits returns the number of ones (network bits) and zeros (host bits) in the +// subnet mask. +func (s *Subnet) Bits() (ones int, zeros int) { + for _, b := range []byte(s.mask) { + for i := uint(0); i < 8; i++ { + if b&(1<= 0; j-- { + if b&(1< s.Size() { + size = s.Size() + } + return s[:size], nil +} + +// Size implements Payload. +func (s SlicePayload) Size() int { + return len(s) +} + +// A ControlMessages contains socket control messages for IP sockets. +// +// +stateify savable +type ControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packed used to create + // the read data was received. + Timestamp int64 +} + +// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) +// that exposes functionality like read, write, connect, etc. to users of the +// networking stack. +type Endpoint interface { + // Close puts the endpoint in a closed state and frees all resources + // associated with it. + Close() + + // Read reads data from the endpoint and optionally returns the sender. + // + // This method does not block if there is no data pending. It will also + // either return an error or data, never both. + // + // A timestamp (in ns) is optionally returned. 
A zero value indicates + // that no timestamp was available. + Read(*FullAddress) (buffer.View, ControlMessages, *Error) + + // Write writes data to the endpoint's peer. This method does not block if + // the data cannot be written. + // + // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes + // successfully written to the Endpoint. That is, if a call to + // Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and + // the caller should not use data[:n] after Write returns. + // + // Note that unlike io.Writer.Write, it is not an error for Write to + // perform a partial write. + // + // For UDP and Ping sockets if address resolution is required, + // ErrNoLinkAddress and a notification channel is returned for the caller to + // block. Channel is closed once address resolution is complete (success or + // not). The channel is only non-nil in this case. + Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error) + + // Peek reads data without consuming it from the endpoint. + // + // This method does not block if there is no data pending. + // + // A timestamp (in ns) is optionally returned. A zero value indicates + // that no timestamp was available. + Peek([][]byte) (uintptr, ControlMessages, *Error) + + // Connect connects the endpoint to its peer. Specifying a NIC is + // optional. + // + // There are three classes of return values: + // nil -- the attempt to connect succeeded. + // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started + // but hasn't completed yet. In this case, the caller must call Connect + // or GetSockOpt(ErrorOption) when the endpoint becomes writable to + // get the actual result. The first call to Connect after the socket has + // connected returns nil. Calling connect again results in ErrAlreadyConnected. + // Anything else -- the attempt to connect failed. 
+ Connect(address FullAddress) *Error + + // Shutdown closes the read and/or write end of the endpoint connection + // to its peer. + Shutdown(flags ShutdownFlags) *Error + + // Listen puts the endpoint in "listen" mode, which allows it to accept + // new connections. + Listen(backlog int) *Error + + // Accept returns a new endpoint if a peer has established a connection + // to an endpoint previously set to listen mode. This method does not + // block if no new connections are available. + // + // The returned Queue is the wait queue for the newly created endpoint. + Accept() (Endpoint, *waiter.Queue, *Error) + + // Bind binds the endpoint to a specific local address and port. + // Specifying a NIC is optional. + // + // An optional commit function will be executed atomically with respect + // to binding the endpoint. If this returns an error, the bind will not + // occur and the error will be propagated back to the caller. + Bind(address FullAddress, commit func() *Error) *Error + + // GetLocalAddress returns the address to which the endpoint is bound. + GetLocalAddress() (FullAddress, *Error) + + // GetRemoteAddress returns the address to which the endpoint is + // connected. + GetRemoteAddress() (FullAddress, *Error) + + // Readiness returns the current readiness of the endpoint. For example, + // if waiter.EventIn is set, the endpoint is immediately readable. + Readiness(mask waiter.EventMask) waiter.EventMask + + // SetSockOpt sets a socket option. opt should be one of the *Option types. + SetSockOpt(opt interface{}) *Error + + // GetSockOpt gets a socket option. opt should be a pointer to one of the + // *Option types. + GetSockOpt(opt interface{}) *Error +} + +// WriteOptions contains options for Endpoint.Write. +type WriteOptions struct { + // If To is not nil, write to the given address instead of the endpoint's + // peer. + To *FullAddress + + // More has the same semantics as Linux's MSG_MORE. 
+ More bool + + // EndOfRecord has the same semantics as Linux's MSG_EOR. + EndOfRecord bool +} + +// ErrorOption is used in GetSockOpt to specify that the last error reported by +// the endpoint should be cleared and returned. +type ErrorOption struct{} + +// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send +// buffer size option. +type SendBufferSizeOption int + +// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the +// receive buffer size option. +type ReceiveBufferSizeOption int + +// SendQueueSizeOption is used in GetSockOpt to specify that the number of +// unread bytes in the output buffer should be returned. +type SendQueueSizeOption int + +// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of +// unread bytes in the input buffer should be returned. +type ReceiveQueueSizeOption int + +// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 +// socket is to be restricted to sending and receiving IPv6 packets only. +type V6OnlyOption int + +// NoDelayOption is used by SetSockOpt/GetSockOpt to specify if data should be +// sent out immediately by the transport protocol. For TCP, it determines if the +// Nagle algorithm is on or off. +type NoDelayOption int + +// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind() +// should allow reuse of local address. +type ReuseAddressOption int + +// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether +// SCM_CREDENTIALS socket control messages are enabled. +// +// Only supported on Unix sockets. +type PasscredOption int + +// TimestampOption is used by SetSockOpt/GetSockOpt to specify whether +// SO_TIMESTAMP socket control messages are enabled. +type TimestampOption int + +// TCPInfoOption is used by GetSockOpt to expose TCP statistics. +// +// TODO: Add and populate stat fields. 
+type TCPInfoOption struct { + RTT time.Duration + RTTVar time.Duration +} + +// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether +// TCP keepalive is enabled for this socket. +type KeepaliveEnabledOption int + +// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a +// connection must remain idle before the first TCP keepalive packet is sent. +// Once this time is reached, KeepaliveIntervalOption is used instead. +type KeepaliveIdleOption time.Duration + +// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the +// interval between sending TCP keepalive packets. +type KeepaliveIntervalOption time.Duration + +// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number +// of un-ACKed TCP keepalives that will be sent before the connection is +// closed. +type KeepaliveCountOption int + +// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default +// TTL value for multicast messages. The default is 1. +type MulticastTTLOption uint8 + +// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to +// AddMembershipOption and RemoveMembershipOption. +type MembershipOption struct { + NIC NICID + InterfaceAddr Address + MulticastAddr Address +} + +// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type AddMembershipOption MembershipOption + +// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type RemoveMembershipOption MembershipOption + +// Route is a row in the routing table. It specifies through which NIC (and +// gateway) sets of packets should be routed. A row is considered viable if the +// masked target address matches the destination adddress in the row. 
+type Route struct { + // Destination is the address that must be matched against the masked + // target address to check if this row is viable. + Destination Address + + // Mask specifies which bits of the Destination and the target address + // must match for this row to be viable. + Mask AddressMask + + // Gateway is the gateway to be used if this row is viable. + Gateway Address + + // NIC is the id of the nic to be used if this row is viable. + NIC NICID +} + +// Match determines if r is viable for the given destination address. +func (r *Route) Match(addr Address) bool { + if len(addr) != len(r.Destination) { + return false + } + + for i := 0; i < len(r.Destination); i++ { + if (addr[i] & r.Mask[i]) != r.Destination[i] { + return false + } + } + + return true +} + +// LinkEndpointID represents a data link layer endpoint. +type LinkEndpointID uint64 + +// TransportProtocolNumber is the number of a transport protocol. +type TransportProtocolNumber uint32 + +// NetworkProtocolNumber is the number of a network protocol. +type NetworkProtocolNumber uint32 + +// A StatCounter keeps track of a statistic. +type StatCounter struct { + count uint64 +} + +// Increment adds one to the counter. +func (s *StatCounter) Increment() { + s.IncrementBy(1) +} + +// Value returns the current value of the counter. +func (s *StatCounter) Value() uint64 { + return atomic.LoadUint64(&s.count) +} + +// IncrementBy increments the counter by v. +func (s *StatCounter) IncrementBy(v uint64) { + atomic.AddUint64(&s.count, v) +} + +// IPStats collects IP-specific stats (both v4 and v6). +type IPStats struct { + // PacketsReceived is the total number of IP packets received from the link + // layer in nic.DeliverNetworkPacket. + PacketsReceived *StatCounter + + // InvalidAddressesReceived is the total number of IP packets received + // with an unknown or invalid destination address. 
+ InvalidAddressesReceived *StatCounter + + // PacketsDelivered is the total number of incoming IP packets that + // are successfully delivered to the transport layer via HandlePacket. + PacketsDelivered *StatCounter + + // PacketsSent is the total number of IP packets sent via WritePacket. + PacketsSent *StatCounter + + // OutgoingPacketErrors is the total number of IP packets which failed + // to write to a link-layer endpoint. + OutgoingPacketErrors *StatCounter +} + +// TCPStats collects TCP-specific stats. +type TCPStats struct { + // ActiveConnectionOpenings is the number of connections opened successfully + // via Connect. + ActiveConnectionOpenings *StatCounter + + // PassiveConnectionOpenings is the number of connections opened + // successfully via Listen. + PassiveConnectionOpenings *StatCounter + + // FailedConnectionAttempts is the number of calls to Connect or Listen + // (active and passive openings, respectively) that end in an error. + FailedConnectionAttempts *StatCounter + + // ValidSegmentsReceived is the number of TCP segments received that the + // transport layer successfully parsed. + ValidSegmentsReceived *StatCounter + + // InvalidSegmentsReceived is the number of TCP segments received that + // the transport layer could not parse. + InvalidSegmentsReceived *StatCounter + + // SegmentsSent is the number of TCP segments sent. + SegmentsSent *StatCounter + + // ResetsSent is the number of TCP resets sent. + ResetsSent *StatCounter + + // ResetsReceived is the number of TCP resets received. + ResetsReceived *StatCounter +} + +// UDPStats collects UDP-specific stats. +type UDPStats struct { + // PacketsReceived is the number of UDP datagrams received via + // HandlePacket. + PacketsReceived *StatCounter + + // UnknownPortErrors is the number of incoming UDP datagrams dropped + // because they did not have a known destination port. 
+ UnknownPortErrors *StatCounter + + // ReceiveBufferErrors is the number of incoming UDP datagrams dropped + // due to the receiving buffer being in an invalid state. + ReceiveBufferErrors *StatCounter + + // MalformedPacketsReceived is the number of incoming UDP datagrams + // dropped due to the UDP header being in a malformed state. + MalformedPacketsReceived *StatCounter + + // PacketsSent is the number of UDP datagrams sent via sendUDP. + PacketsSent *StatCounter +} + +// Stats holds statistics about the networking stack. +// +// All fields are optional. +type Stats struct { + // UnknownProtocolRcvdPackets is the number of packets received by the + // stack that were for an unknown or unsupported protocol. + UnknownProtocolRcvdPackets *StatCounter + + // MalformedRcvdPackets is the number of packets received by the stack + // that were deemed malformed. + MalformedRcvdPackets *StatCounter + + // DroppedPackets is the number of packets dropped due to full queues. + DroppedPackets *StatCounter + + // IP breaks out IP-specific stats (both v4 and v6). + IP IPStats + + // TCP breaks out TCP-specific stats. + TCP TCPStats + + // UDP breaks out UDP-specific stats. + UDP UDPStats +} + +func fillIn(v reflect.Value) { + for i := 0; i < v.NumField(); i++ { + v := v.Field(i) + switch v.Kind() { + case reflect.Ptr: + if s, ok := v.Addr().Interface().(**StatCounter); ok { + if *s == nil { + *s = &StatCounter{} + } + } + case reflect.Struct: + fillIn(v) + } + } +} + +// FillIn returns a copy of s with nil fields initialized to new StatCounters. +func (s Stats) FillIn() Stats { + fillIn(reflect.ValueOf(&s).Elem()) + return s +} + +// String implements the fmt.Stringer interface. +func (a Address) String() string { + switch len(a) { + case 4: + return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3])) + case 16: + // Find the longest subsequence of hexadecimal zeros.
+ start, end := -1, -1 + for i := 0; i < len(a); i += 2 { + j := i + for j < len(a) && a[j] == 0 && a[j+1] == 0 { + j += 2 + } + if j > i+2 && j-i > end-start { + start, end = i, j + } + } + + var b strings.Builder + for i := 0; i < len(a); i += 2 { + if i == start { + b.WriteString("::") + i = end + if end >= len(a) { + break + } + } else if i > 0 { + b.WriteByte(':') + } + v := uint16(a[i+0])<<8 | uint16(a[i+1]) + if v == 0 { + b.WriteByte('0') + } else { + const digits = "0123456789abcdef" + for i := uint(3); i < 4; i-- { + if v := v >> (i * 4); v != 0 { + b.WriteByte(digits[v&0xf]) + } + } + } + } + return b.String() + default: + return fmt.Sprintf("%x", []byte(a)) + } +} + +// To4 converts the IPv4 address to a 4-byte representation. +// If the address is not an IPv4 address, To4 returns "". +func (a Address) To4() Address { + const ( + ipv4len = 4 + ipv6len = 16 + ) + if len(a) == ipv4len { + return a + } + if len(a) == ipv6len && + isZeros(a[0:10]) && + a[10] == 0xff && + a[11] == 0xff { + return a[12:16] + } + return "" +} + +// isZeros reports whether a is all zeros. +func isZeros(a Address) bool { + for i := 0; i < len(a); i++ { + if a[i] != 0 { + return false + } + } + return true +} + +// LinkAddress is a byte slice cast as a string that represents a link address. +// It is typically a 6-byte MAC address. +type LinkAddress string + +// String implements the fmt.Stringer interface. +func (a LinkAddress) String() string { + switch len(a) { + case 6: + return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5]) + default: + return fmt.Sprintf("%x", []byte(a)) + } +} + +// ParseMACAddress parses an IEEE 802 address. +// +// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. 
+func ParseMACAddress(s string) (LinkAddress, error) { + parts := strings.FieldsFunc(s, func(c rune) bool { + return c == ':' || c == '-' + }) + if len(parts) != 6 { + return "", fmt.Errorf("inconsistent parts: %s", s) + } + addr := make([]byte, 0, len(parts)) + for _, part := range parts { + u, err := strconv.ParseUint(part, 16, 8) + if err != nil { + return "", fmt.Errorf("invalid hex digits: %s", s) + } + addr = append(addr, byte(u)) + } + return LinkAddress(addr), nil +} + +// ProtocolAddress is an address and the network protocol it is associated +// with. +type ProtocolAddress struct { + // Protocol is the protocol of the address. + Protocol NetworkProtocolNumber + + // Address is a network address. + Address Address +} + +// danglingEndpointsMu protects access to danglingEndpoints. +var danglingEndpointsMu sync.Mutex + +// danglingEndpoints tracks all dangling endpoints no longer owned by the app. +var danglingEndpoints = make(map[Endpoint]struct{}) + +// GetDanglingEndpoints returns a snapshot of all dangling endpoints; the map is sized and read entirely under danglingEndpointsMu. +func GetDanglingEndpoints() []Endpoint { + danglingEndpointsMu.Lock() + es := make([]Endpoint, 0, len(danglingEndpoints)) + for e := range danglingEndpoints { + es = append(es, e) + } + danglingEndpointsMu.Unlock() + return es +} + +// AddDanglingEndpoint adds a dangling endpoint. +func AddDanglingEndpoint(e Endpoint) { + danglingEndpointsMu.Lock() + danglingEndpoints[e] = struct{}{} + danglingEndpointsMu.Unlock() +} + +// DeleteDanglingEndpoint removes a dangling endpoint. +func DeleteDanglingEndpoint(e Endpoint) { + danglingEndpointsMu.Lock() + delete(danglingEndpoints, e) + danglingEndpointsMu.Unlock() +} + +// AsyncLoading is the global barrier for asynchronous endpoint loading +// activities.
+var AsyncLoading sync.WaitGroup diff --git a/netstack/tcpip/tcpip_test.go b/netstack/tcpip/tcpip_test.go new file mode 100644 index 0000000..cfde6aa --- /dev/null +++ b/netstack/tcpip/tcpip_test.go @@ -0,0 +1,195 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpip + +import ( + "net" + "testing" +) + +func TestSubnetContains(t *testing.T) { + tests := []struct { + s Address + m AddressMask + a Address + want bool + }{ + {"\xa0", "\xf0", "\x90", false}, + {"\xa0", "\xf0", "\xa0", true}, + {"\xa0", "\xf0", "\xa5", true}, + {"\xa0", "\xf0", "\xaf", true}, + {"\xa0", "\xf0", "\xb0", false}, + {"\xa0", "\xf0", "", false}, + {"\xa0", "\xf0", "\xa0\x00", false}, + {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, + {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, + {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, + {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, + } + for _, tt := range tests { + s, err := NewSubnet(tt.s, tt.m) + if err != nil { + t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err) + continue + } + if got := s.Contains(tt.a); got != tt.want { + t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want) + } + } +} + +func TestSubnetBits(t *testing.T) { + tests := []struct { + a AddressMask + want1 int + want0 int + }{ + {"\x00", 0, 8}, + {"\x00\x00", 0, 16}, + {"\x36", 4, 4}, + {"\x5c", 4, 4}, + {"\x5c\x5c", 8, 8}, + {"\x5c\x36", 8, 8}, + {"\x36\x5c", 8, 8}, + {"\x36\x36", 8, 8}, + {"\xff", 8, 0}, + 
{"\xff\xff", 16, 0}, + } + for _, tt := range tests { + s := &Subnet{mask: tt.a} + got1, got0 := s.Bits() + if got1 != tt.want1 || got0 != tt.want0 { + t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0) + } + } +} + +func TestSubnetPrefix(t *testing.T) { + tests := []struct { + a AddressMask + want int + }{ + {"\x00", 0}, + {"\x00\x00", 0}, + {"\x36", 0}, + {"\x86", 1}, + {"\xc5", 2}, + {"\xff\x00", 8}, + {"\xff\x36", 8}, + {"\xff\x8c", 9}, + {"\xff\xc8", 10}, + {"\xff", 8}, + {"\xff\xff", 16}, + } + for _, tt := range tests { + s := &Subnet{mask: tt.a} + got := s.Prefix() + if got != tt.want { + t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want) + } + } +} + +func TestSubnetCreation(t *testing.T) { + tests := []struct { + a Address + m AddressMask + want error + }{ + {"\xa0", "\xf0", nil}, + {"\xa0\xa0", "\xf0", errSubnetLengthMismatch}, + {"\xaa", "\xf0", errSubnetAddressMasked}, + {"", "", nil}, + } + for _, tt := range tests { + if _, err := NewSubnet(tt.a, tt.m); err != tt.want { + t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want) + } + } +} + +func TestRouteMatch(t *testing.T) { + tests := []struct { + d Address + m AddressMask + a Address + want bool + }{ + {"\xc2\x80", "\xff\xf0", "\xc2\x80", true}, + {"\xc2\x80", "\xff\xf0", "\xc2\x00", false}, + {"\xc2\x00", "\xff\xf0", "\xc2\x00", true}, + {"\xc2\x00", "\xff\xf0", "\xc2\x80", false}, + } + for _, tt := range tests { + r := Route{Destination: tt.d, Mask: tt.m} + if got := r.Match(tt.a); got != tt.want { + t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want) + } + } +} + +func TestAddressString(t *testing.T) { + for _, want := range []string{ + // Taken from stdlib. + "2001:db8::123:12:1", + "2001:db8::1", + "2001:db8:0:1:0:1:0:1", + "2001:db8:1:0:1:0:1:0", + "2001::1:0:0:1", + "2001:db8:0:0:1::", + "2001:db8::1:0:0:1", + "2001:db8::a:b:c:d", + + // Leading zeros. + "::1", + // Trailing zeros. 
+ "8::", + // No zeros. + "1:1:1:1:1:1:1:1", + // Longer sequence is after other zeros, but not at the end. + "1:0:0:1::1", + // Longer sequence is at the beginning, shorter sequence is at + // the end. + "::1:1:1:0:0", + // Longer sequence is not at the beginning, shorter sequence is + // at the end. + "1::1:1:0:0", + // Longer sequence is at the beginning, shorter sequence is not + // at the end. + "::1:1:0:0:1", + // Neither sequence is at an end, longer is after shorter. + "1:0:0:1::1", + // Shorter sequence is at the beginning, longer sequence is not + // at the end. + "0:0:1:1::1", + // Shorter sequence is at the beginning, longer sequence is at + // the end. + "0:0:1:1:1::", + // Short sequences at both ends, longer one in the middle. + "0:1:1::1:1:0", + // Short sequences at both ends, longer one in the middle. + "0:1::1:0:0", + // Short sequences at both ends, longer one in the middle. + "0:0:1::1:0", + // Longer sequence surrounded by shorter sequences, but none at + // the end. + "1:0:1::1:0:1", + } { + addr := Address(net.ParseIP(want)) + if got := addr.String(); got != want { + t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want) + } + } +} diff --git a/netstack/tcpip/time.s b/netstack/tcpip/time.s new file mode 100644 index 0000000..e07189f --- /dev/null +++ b/netstack/tcpip/time.s @@ -0,0 +1,15 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Empty assembly file so empty func definitions work. diff --git a/netstack/tcpip/time_unsafe.go b/netstack/tcpip/time_unsafe.go new file mode 100644 index 0000000..a6d6a17 --- /dev/null +++ b/netstack/tcpip/time_unsafe.go @@ -0,0 +1,42 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build go1.9 + +package tcpip + +import ( + _ "time" // Used with go:linkname. + _ "unsafe" // Required for go:linkname. +) + +// StdClock implements Clock with the time package. +type StdClock struct{} + +var _ Clock = (*StdClock)(nil) + +//go:linkname now time.now +func now() (sec int64, nsec int32, mono int64) + +// NowNanoseconds implements Clock.NowNanoseconds. +func (*StdClock) NowNanoseconds() int64 { + sec, nsec, _ := now() + return sec*1e9 + int64(nsec) +} + +// NowMonotonic implements Clock.NowMonotonic. +func (*StdClock) NowMonotonic() int64 { + _, _, mono := now() + return mono +} diff --git a/netstack/tcpip/transport/ping/endpoint.go b/netstack/tcpip/transport/ping/endpoint.go new file mode 100644 index 0000000..73063ab --- /dev/null +++ b/netstack/tcpip/transport/ping/endpoint.go @@ -0,0 +1,717 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ping + +import ( + "encoding/binary" + "sync" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +// +stateify savable +// ping包的结构体 +type pingPacket struct { + pingPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView + timestamp int64 + hasTimestamp bool + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View +} + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents a ping endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// 表示ping端,用作用户端和协议栈之间的接口 +type endpoint struct { + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex + rcvReady bool + rcvList pingPacketList + rcvBufSizeMax int + rcvBufSize int + rcvClosed bool + rcvTimestamp bool + + // The following fields are protected by the mu mutex. 
+ mu sync.RWMutex + sndBufSize int + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + bindAddr tcpip.Address + regNICID tcpip.NICID + route stack.Route +} + +// 根据参数,新建一个ping端 +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, + transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + return &endpoint{ + stack: stack, + netProto: netProto, + transProto: transProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +// 关闭ping端,回收资源 +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + switch e.state { + case stateBound, stateConnected: + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id) + } + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. 
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + ts := e.rcvTimestamp + + e.rcvMu.Unlock() + + if addr != nil { + *addr = p.senderAddress + } + + if ts && !p.hasTimestamp { + // Linux uses the current time. + p.timestamp = e.stack.NowNanoseconds() + } + + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return false, err + } + + return true, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. 
(This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + if to == nil { + route = &e.route + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, given that it may + // need to change in Route.Resolve() call below. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicid := to.NIC + if e.bindNICID != 0 { + if nicid != 0 && nicid != e.bindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicid = e.bindNICID + } + + toCopy := *to + to = &toCopy + netProto, err := e.checkV4Mapped(to, true) + if err != nil { + return 0, nil, err + } + + // Find a route to the destination. + r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + } + + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + if ch, err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. Resolution was triggered in the + // background. Better luck next time.
+ route.RemoveWaker(waker) + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + + switch e.netProto { + case header.IPv4ProtocolNumber: + err = sendPing4(route, e.id.LocalPort, v) + + case header.IPv6ProtocolNumber: + err = sendPing6(route, e.id.LocalPort, v) + } + + return uintptr(len(v)), nil, err +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. Only tcpip.TimestampOption is handled; any other option is ignored and nil is returned. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.TimestampOption: + e.rcvMu.Lock() + e.rcvTimestamp = v != 0 + e.rcvMu.Unlock() + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. It returns nil for every handled option and tcpip.ErrUnknownProtocolOption otherwise. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + + case *tcpip.TimestampOption: + e.rcvMu.Lock() + *o = 0 + if e.rcvTimestamp { + *o = 1 + } + e.rcvMu.Unlock() + return nil + } + return tcpip.ErrUnknownProtocolOption +} + +func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv4EchoMinimumSize { + return tcpip.ErrInvalidEndpointState + } + + // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident) + + hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + copy(icmpv4, data) + data = data[header.ICMPv4EchoMinimumSize:] + + // Linux performs these basic checks. + if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + icmpv4.SetChecksum(0) + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) +} + +func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv6EchoMinimumSize { + return tcpip.ErrInvalidEndpointState + } + + // Set the ident. Sequence number is provided by the user. + binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) + + hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(icmpv6, data) + data = data[header.ICMPv6EchoMinimumSize:] + + if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + icmpv6.SetChecksum(0) + icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + + return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + return 0, tcpip.ErrNoRoute + } + + // Fail if we're bound to an address length different from the one we're + // checking. 
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + localPort := uint16(0) + switch e.state { + case stateBound, stateConnected: + localPort = e.id.LocalPort + if e.bindNICID == 0 { + break + } + + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + e.id = id + e.route = r.Clone() + e.regNICID = nicid + + e.state = stateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. 
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags + + if e.state != stateConnected { + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by ping endpoints, it just fails. +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by ping endpoints, it just fails. +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, + id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if id.LocalPort != 0 { + // The endpoint already has a local port, just attempt to + // register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + return id, err + } + + // We need to find a port for the endpoint. + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + id.LocalPort = p + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto} + + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. + if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + if commit != nil { + if err := commit(); err != nil { + // Unregister, the commit failed. + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id) + return err + } + } + + e.id = id + e.regNICID = addr.NIC + + // Mark endpoint as bound. + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr, commit) + if err != nil { + return err + } + + e.bindNICID = addr.NIC + e.bindAddr = addr.Addr + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. 
For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + e.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + return + } + + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + pkt := &pingPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + }, + } + pkt.data = vv.Clone(pkt.views[:]) + e.rcvList.PushBack(pkt) + e.rcvBufSize += vv.Size() + + if e.rcvTimestamp { + pkt.timestamp = e.stack.NowNanoseconds() + pkt.hasTimestamp = true + } + + e.rcvMu.Unlock() + + // Notify any waiters that there's data to be read now. + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, + typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +} diff --git a/netstack/tcpip/transport/ping/ping_packet_list.go b/netstack/tcpip/transport/ping/ping_packet_list.go new file mode 100644 index 0000000..5fb79d3 --- /dev/null +++ b/netstack/tcpip/transport/ping/ping_packet_list.go @@ -0,0 +1,173 @@ +package ping + +// ElementMapper provides an identity mapping by default. 
+// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type pingPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (pingPacketElementMapper) linkerFor(elem *pingPacket) *pingPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type pingPacketList struct { + head *pingPacket + tail *pingPacket +} + +// Reset resets list l to the empty state. +func (l *pingPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *pingPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *pingPacketList) Front() *pingPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *pingPacketList) Back() *pingPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *pingPacketList) PushFront(e *pingPacket) { + pingPacketElementMapper{}.linkerFor(e).SetNext(l.head) + pingPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + pingPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. 
+func (l *pingPacketList) PushBack(e *pingPacket) { + pingPacketElementMapper{}.linkerFor(e).SetNext(nil) + pingPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + pingPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *pingPacketList) PushBackList(m *pingPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + pingPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + pingPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *pingPacketList) InsertAfter(b, e *pingPacket) { + a := pingPacketElementMapper{}.linkerFor(b).Next() + pingPacketElementMapper{}.linkerFor(e).SetNext(a) + pingPacketElementMapper{}.linkerFor(e).SetPrev(b) + pingPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + pingPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *pingPacketList) InsertBefore(a, e *pingPacket) { + b := pingPacketElementMapper{}.linkerFor(a).Prev() + pingPacketElementMapper{}.linkerFor(e).SetNext(a) + pingPacketElementMapper{}.linkerFor(e).SetPrev(b) + pingPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + pingPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *pingPacketList) Remove(e *pingPacket) { + prev := pingPacketElementMapper{}.linkerFor(e).Prev() + next := pingPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + pingPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + pingPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. 
Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type pingPacketEntry struct { + next *pingPacket + prev *pingPacket +} + +// Next returns the entry that follows e in the list. +func (e *pingPacketEntry) Next() *pingPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *pingPacketEntry) Prev() *pingPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *pingPacketEntry) SetNext(elem *pingPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *pingPacketEntry) SetPrev(elem *pingPacket) { + e.prev = elem +} diff --git a/netstack/tcpip/transport/ping/protocol.go b/netstack/tcpip/transport/ping/protocol.go new file mode 100644 index 0000000..865feaa --- /dev/null +++ b/netstack/tcpip/transport/ping/protocol.go @@ -0,0 +1,124 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ping contains the implementation of the ICMP and IPv6-ICMP transport +// protocols for use in ping. To use it in the networking stack, this package +// must be added to the project, and +// activated on the stack by passing ping.ProtocolName (or "ping") and/or +// ping.ProtocolName6 (or "ping6") as one of the transport protocols when +// calling stack.New(). 
Then endpoints can be created by passing +// ping.ProtocolNumber or ping.ProtocolNumber6 as the transport protocol number +// when calling Stack.NewEndpoint(). +package ping + +import ( + "encoding/binary" + "fmt" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +const ( + // ProtocolName4 is the string representation of the ping protocol name. + ProtocolName4 = "ping4" + + // ProtocolNumber4 is the ICMP protocol number. + ProtocolNumber4 = header.ICMPv4ProtocolNumber + + // ProtocolName6 is the string representation of the ping protocol name. + ProtocolName6 = "ping6" + + // ProtocolNumber6 is the IPv6-ICMP protocol number. + ProtocolNumber6 = header.ICMPv6ProtocolNumber +) + +type protocol struct { + number tcpip.TransportProtocolNumber +} + +// Number returns the ICMP protocol number. +func (p *protocol) Number() tcpip.TransportProtocolNumber { + return p.number +} + +func (p *protocol) netProto() tcpip.NetworkProtocolNumber { + switch p.number { + case ProtocolNumber4: + return header.IPv4ProtocolNumber + case ProtocolNumber6: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// NewEndpoint creates a new ping endpoint. +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue), nil +} + +// MinimumPacketSize returns the minimum valid ping packet size. 
+func (p *protocol) MinimumPacketSize() int { + switch p.number { + case ProtocolNumber4: + return header.ICMPv4EchoMinimumSize + case ProtocolNumber6: + return header.ICMPv6EchoMinimumSize + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// ParsePorts returns the source and destination ports stored in the given ping +// packet. +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + switch p.number { + case ProtocolNumber4: + return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil + case ProtocolNumber6: + return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { + return &protocol{ProtocolNumber4} + }) + + stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { + return &protocol{ProtocolNumber6} + }) +} diff --git a/netstack/tcpip/transport/tcp/accept.go b/netstack/tcpip/transport/tcp/accept.go new file mode 100644 index 0000000..3dab278 --- /dev/null +++ b/netstack/tcpip/transport/tcp/accept.go @@ -0,0 +1,443 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "encoding/binary" + "hash" + "io" + "sync" + "time" + + "crypto/sha1" + "tcpip/netstack/rand" + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +const ( + // tsLen is the length, in bits, of the timestamp in the SYN cookie. + tsLen = 8 + + // tsMask is a mask for timestamp values (i.e., tsLen bits). + tsMask = (1 << tsLen) - 1 + + // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. + tsOffset = 24 + + // hashMask is the mask for hash values (i.e., tsOffset bits). + hashMask = (1 << tsOffset) - 1 + + // maxTSDiff is the maximum allowed difference between a received cookie + // timestamp and the current timestamp. If the difference is greater + // than maxTSDiff, the cookie is expired. + maxTSDiff = 2 +) + +var ( + // SynRcvdCountThreshold is the global maximum number of connections + // that are allowed to be in SYN-RCVD state before TCP starts using SYN + // cookies to accept connections. + // + // It is an exported variable only for testing, and should not otherwise + // be used by importers of this package. + SynRcvdCountThreshold uint64 = 1000 + + // mssTable is a slice containing the possible MSS values that we + // encode in the SYN cookie with two bits. 
+ mssTable = []uint16{536, 1300, 1440, 1460} +) + +func encodeMSS(mss uint16) uint32 { + for i := len(mssTable) - 1; i > 0; i-- { + if mss >= mssTable[i] { + return uint32(i) + } + } + return 0 +} + +// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is +// protected by a mutex so that we can increment only when it's guaranteed not +// to go above a threshold. +var synRcvdCount struct { + sync.Mutex + value uint64 + pending sync.WaitGroup +} + +// listenContext is used by a listening endpoint to store state used while +// listening for connections. This struct is allocated by the listen goroutine +// and must not be accessed or have its methods called concurrently as they +// may mutate the stored objects. +type listenContext struct { + stack *stack.Stack + rcvWnd seqnum.Size + nonce [2][sha1.BlockSize]byte + + hasherMu sync.Mutex + hasher hash.Hash + v6only bool + netProto tcpip.NetworkProtocolNumber +} + +// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. +func timeStamp() uint32 { + return uint32(time.Now().Unix()>>6) & tsMask +} + +// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD +// state. It succeeds if the increment doesn't make the count go beyond the +// threshold, and fails otherwise. +func incSynRcvdCount() bool { + synRcvdCount.Lock() + defer synRcvdCount.Unlock() + + if synRcvdCount.value >= SynRcvdCountThreshold { + return false + } + + synRcvdCount.pending.Add(1) + synRcvdCount.value++ + + return true +} + +// decSynRcvdCount atomically decrements the global number of endpoints in +// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount +// succeeded. +func decSynRcvdCount() { + synRcvdCount.Lock() + defer synRcvdCount.Unlock() + + synRcvdCount.value-- + synRcvdCount.pending.Done() +} + +// newListenContext creates a new listen context. 
+func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { + l := &listenContext{ + stack: stack, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + } + + rand.Read(l.nonce[0][:]) + rand.Read(l.nonce[1][:]) + + return l +} + +// cookieHash calculates the cookieHash for the given id, timestamp and nonce +// index. The hash is used to create and validate cookies. +func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { + + // Initialize block with fixed-size data: local ports and v. + var payload [8]byte + binary.BigEndian.PutUint16(payload[0:], id.LocalPort) + binary.BigEndian.PutUint16(payload[2:], id.RemotePort) + binary.BigEndian.PutUint32(payload[4:], ts) + + // Feed everything to the hasher. + l.hasherMu.Lock() + l.hasher.Reset() + l.hasher.Write(payload[:]) + l.hasher.Write(l.nonce[nonceIndex][:]) + io.WriteString(l.hasher, string(id.LocalAddress)) + io.WriteString(l.hasher, string(id.RemoteAddress)) + + // Finalize the calculation of the hash and return the first 4 bytes. + h := make([]byte, 0, sha1.Size) + h = l.hasher.Sum(h) + l.hasherMu.Unlock() + + return binary.BigEndian.Uint32(h[:]) +} + +// createCookie creates a SYN cookie for the given id and incoming sequence +// number. +func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { + ts := timeStamp() + v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) + v += (l.cookieHash(id, ts, 1) + data) & hashMask + return seqnum.Value(v) +} + +// isCookieValid checks if the supplied cookie is valid for the given id and +// sequence number. If it is, it also returns the data originally encoded in the +// cookie when createCookie was called. 
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { + ts := timeStamp() + v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) + cookieTS := v >> tsOffset + if ((ts - cookieTS) & tsMask) > maxTSDiff { + return 0, false + } + + return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true +} + +// createConnectedEndpoint creates a new connected endpoint, with the connection +// parameters given by the arguments. +func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create a new endpoint. + netProto := l.netProto + if netProto == 0 { + netProto = s.route.NetProto + } + n := newEndpoint(l.stack, netProto, nil) + n.v6only = l.v6only + n.id = s.id + n.boundNICID = s.route.NICID() + n.route = s.route.Clone() + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.rcvBufSize = int(l.rcvWnd) + + n.maybeEnableTimestamp(rcvdSynOpts) + n.maybeEnableSACKPermitted(rcvdSynOpts) + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { + n.Close() + return nil, err + } + + n.isRegistered = true + n.state = stateConnected + + // Create sender and receiver. + // + // The receiver at least temporarily has a zero receive window scale, + // but the caller may change it (before starting the protocol loop). + n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) + n.rcv = newReceiver(n, irs, l.rcvWnd, 0) + + return n, nil +} + +// createEndpoint creates a new endpoint in connected state and then performs +// the TCP 3-way handshake. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create new endpoint. 
+ irs := s.sequenceNumber + cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) + ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) + if err != nil { + return nil, err + } + + // Perform the 3-way handshake. + // 执行三次握手 + h, err := newHandshake(ep, l.rcvWnd) + if err != nil { + ep.Close() + return nil, err + } + + // 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack + h.resetToSynRcvd(cookie, irs, opts) + if err := h.execute(); err != nil { + ep.Close() + return nil, err + } + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + + return ep, nil +} + +// deliverAccepted delivers the newly-accepted endpoint to the listener. If the +// endpoint has transitioned out of the listen state, the new endpoint is closed +// instead. +func (e *endpoint) deliverAccepted(n *endpoint) { + e.mu.RLock() + if e.state == stateListen { + e.acceptedChan <- n + e.waiterQueue.Notify(waiter.EventIn) + } else { + n.Close() + } + e.mu.RUnlock() +} + +// handleSynSegment is called in its own goroutine once the listening endpoint +// receives a SYN segment. It is responsible for completing the handshake and +// queueing the new endpoint for acceptance. +// +// A limited number of these goroutines are allowed before TCP starts using SYN +// cookies to accept connections. +// 一旦侦听端点收到SYN段,handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。 +// 在TCP开始使用SYN cookie接受连接之前,允许使用有限数量的这些goroutine。 +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer decSynRcvdCount() + defer s.decRef() + + n, err := ctx.createEndpointAndPerformHandshake(s, opts) + if err != nil { + return + } + // 到这里,三次握手已经完成,那么分发一个新的连接 + e.deliverAccepted(n) +} + +// handleListenSegment is called when a listening endpoint receives a segment +// and needs to handle it. 
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + switch s.flags { + case flagSyn: + // syn报文处理 + // 分析tcp选项 + opts := parseSynSegmentOptions(s) + // 如果处于 syn-recv 状态的连接没有超过 1000,那么不需要生成 cookie, + // 而是直接分配和新建连接。 + if incSynRcvdCount() { + s.incRef() + go e.handleSynSegment(ctx, s, &opts) + } else { + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + // Send SYN with window scaling because we currently + // dont't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(timeStampOffset()), + TSEcr: opts.TSVal, + } + // 返回 syn+ack 报文 + sendSynTCP(&s.route, s.id, flagSyn|flagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + } + + case flagAck: + // 三次握手最后一次 ack 报文 + if data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1); ok && int(data) < len(mssTable) { + // Create newly accepted endpoint and deliver it. + rcvdSynOptions := &header.TCPSynOptions{ + MSS: mssTable[data], + // Disable Window scaling as original SYN is + // lost. + WS: -1, + } + // When syn cookies are in use we enable timestamp only + // if the ack specifies the timestamp option assuming + // that the other end did in fact negotiate the + // timestamp option in the original SYN. + if s.parsedOptions.TS { + rcvdSynOptions.TS = true + rcvdSynOptions.TSVal = s.parsedOptions.TSVal + rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr + } + + // 三次握手已经完成,新建一个tcp连接 + n, err := ctx.createConnectedEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + if err == nil { + // clear the tsOffset for the newly created + // endpoint as the Timestamp was already + // randomly offset when the original SYN-ACK was + // sent above. + n.tsOffset = 0 + e.deliverAccepted(n) + } + } + } +} + +// protocolListenLoop is the main loop of a listening TCP endpoint. 
It runs in +// its own goroutine and is responsible for handling connection requests. +// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行,负责处理连接请求。 +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + defer func() { + // Mark endpoint as closed. This will prevent goroutines running + // handleSynSegment() from attempting to queue new connections + // to the endpoint. + e.mu.Lock() + e.state = stateClosed + + // Do cleanup if needed. + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + e.mu.Unlock() + + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + }() + + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + + ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) + + // 事件触发器的初始化和事件的添加 + s := sleep.Sleeper{} + s.AddWaker(&e.notificationWaker, wakerForNotification) + s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) + // for循环拿到事件index,然后进行相应的处理 + for { + switch index, _ := s.Fetch(true); index { + case wakerForNotification: + n := e.fetchNotifications() + if n&notifyClose != 0 { + return nil + } + if n&notifyDrain != 0 { + for !e.segmentQueue.empty() { + s := e.segmentQueue.dequeue() + e.handleListenSegment(ctx, s) + s.decRef() + } + synRcvdCount.pending.Wait() + close(e.drainDone) + <-e.undrain + } + + case wakerForNewSegment: + // Process at most maxSegmentsPerWake segments. + // 接收和处理tcp报文 + mayRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + mayRequeue = false + break + } + + // 处理tcp数据段 s + e.handleListenSegment(ctx, s) + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up + // in the next iteration.
+ if mayRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + } + } +} diff --git a/netstack/tcpip/transport/tcp/connect.go b/netstack/tcpip/transport/tcp/connect.go new file mode 100644 index 0000000..379ab28 --- /dev/null +++ b/netstack/tcpip/transport/tcp/connect.go @@ -0,0 +1,1158 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "log" + "sync" + "time" + + "tcpip/netstack/rand" + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +// maxSegmentsPerWake is the maximum number of segments to process in the main +// protocol goroutine per wake-up. Yielding [after this number of segments are +// processed] allows other events to be processed as well (e.g., timeouts, +// resets, etc.). +const maxSegmentsPerWake = 100 + +type handshakeState int + +// The following are the possible states of the TCP connection during a 3-way +// handshake. A depiction of the states and transitions can be found in RFC 793, +// page 23. +const ( + handshakeSynSent handshakeState = iota + handshakeSynRcvd + handshakeCompleted +) + +// The following are used to set up sleepers. +const ( + wakerForNotification = iota + wakerForNewSegment + wakerForResend + wakerForResolution +) + +const ( + // Maximum space available for options. 
+ // tcp选项的最大长度 + maxOptionSize = 40 +) + +// handshake holds the state used during a TCP 3-way handshake. +// tcp三次握手时候使用的对象 +type handshake struct { + ep *endpoint + // 握手的状态 + state handshakeState + active bool + flags uint8 + ackNum seqnum.Value + + // iss is the initial send sequence number, as defined in RFC 793. + // 初始序列号 + iss seqnum.Value + + // rcvWnd is the receive window, as defined in RFC 793. + // 接收窗口 + rcvWnd seqnum.Size + + // sndWnd is the send window, as defined in RFC 793. + // 发送窗口 + sndWnd seqnum.Size + + // mss is the maximum segment size received from the peer. + // 最大报文段大小 + mss uint16 + + // sndWndScale is the send window scale, as defined in RFC 1323. A + // negative value means no scaling is supported by the peer. + // 发送窗口扩展因子 + sndWndScale int + + // rcvWndScale is the receive window scale, as defined in RFC 1323. + // 接收窗口扩展因子 + rcvWndScale int +} + +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { + h := handshake{ + ep: ep, + active: true, + rcvWnd: rcvWnd, + rcvWndScale: FindWndScale(rcvWnd), // 接收窗口扩展因子 + } + if err := h.resetState(); err != nil { + return handshake{}, err + } + + return h, nil +} + +// FindWndScale determines the window scale to use for the given maximum window +// size. +// 因为窗口的大小不能超过序列号范围的一半,即窗口最大2^30, +// so (2^16)*(2^maxWnsScale) < 2^30,get maxWnsScale = 14 +func FindWndScale(wnd seqnum.Size) int { + if wnd < 0x10000 { + return 0 + } + + max := seqnum.Size(0xffff) + s := 0 + for wnd > max && s < header.MaxWndScale { + s++ + max <<= 1 + } + + return s +} + +// resetState resets the state of the handshake object such that it becomes +// ready for a new 3-way handshake. 
+func (h *handshake) resetState() *tcpip.Error { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + + // 初始化状态为 SynSent + h.state = handshakeSynSent + h.flags = flagSyn + h.ackNum = 0 + h.mss = 0 + h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + + return nil +} + +// effectiveRcvWndScale returns the effective receive window scale to be used. +// If the peer doesn't support window scaling, the effective rcv wnd scale is +// zero; otherwise it's the value calculated based on the initial rcv wnd. +func (h *handshake) effectiveRcvWndScale() uint8 { + if h.sndWndScale < 0 { + return 0 + } + return uint8(h.rcvWndScale) +} + +// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD +// state. +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { + h.active = false + h.state = handshakeSynRcvd + h.flags = flagSyn | flagAck + h.iss = iss + h.ackNum = irs + 1 + h.mss = opts.MSS + h.sndWndScale = opts.WS +} + +// checkAck checks if the ACK number, if present, of a segment received during +// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in +// response. +func (h *handshake) checkAck(s *segment) bool { + if s.flagIsSet(flagAck) && s.ackNumber != h.iss+1 { + // RFC 793, page 36, states that a reset must be generated when + // the connection is in any non-synchronized state and an + // incoming segment acknowledges something not yet sent. The + // connection remains in the same state. + ack := s.sequenceNumber.Add(s.logicalLen()) + h.ep.sendRaw(buffer.VectorisedView{}, flagRst|flagAck, s.ackNumber, ack, 0) + return false + } + + return true +} + +// synSentState handles a segment received when the TCP 3-way handshake is in +// the SYN-SENT state. 
+// synSentState 是客户端或者服务端接收到第一个握手报文的处理 +// 正常情况下,如果是客户端,此时应该收到 syn+ack 报文,处理后发送 ack 报文给服务端。 +// 如果是服务端,此时接收到syn报文,那么应该回复 syn+ack 报文给客户端,并设置状态为 handshakeSynRcvd。 +func (h *handshake) synSentState(s *segment) *tcpip.Error { + // RFC 793, page 37, states that in the SYN-SENT state, a reset is + // acceptable if the ack field acknowledges the SYN. + if s.flagIsSet(flagRst) { + if s.flagIsSet(flagAck) && s.ackNumber == h.iss+1 { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + // We are in the SYN-SENT state. We only care about segments that have + // the SYN flag. + if !s.flagIsSet(flagSyn) { + return nil + } + + // Parse the SYN options. + rcvSynOpts := parseSynSegmentOptions(s) + + // Remember if the Timestamp option was negotiated. + h.ep.maybeEnableTimestamp(&rcvSynOpts) + + // Remember if the SACKPermitted option was negotiated. + h.ep.maybeEnableSACKPermitted(&rcvSynOpts) + + // Remember the sequence we'll ack from now on. + h.ackNum = s.sequenceNumber + 1 + h.flags |= flagAck + h.mss = rcvSynOpts.MSS + h.sndWndScale = rcvSynOpts.WS + + // If this is a SYN ACK response, we only need to acknowledge the SYN + // and the handshake is completed. + // 客户端接收到了 syn+ack 报文 + if s.flagIsSet(flagAck) { + // 客户端握手完成,发送 ack 报文给服务端 + h.state = handshakeCompleted + // 最后依次 ack 报文丢了也没关系,因为后面一但发送任何数据包都是带ack的 + h.ep.sendRaw(buffer.VectorisedView{}, flagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) + return nil + } + + // A SYN segment was received, but no ACK in it. We acknowledge the SYN + // but resend our own SYN and wait for it to be acknowledged in the + // SYN-RCVD state. + // 服务端收到了 syn 报文,应该回复客户端 syn+ack 报文,且设置状态为 handshakeSynRcvd + h.state = handshakeSynRcvd + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: rcvSynOpts.TS, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + + // We only send SACKPermitted if the other side indicated it + // permits SACK. 
This is not explicitly defined in the RFC but + // this is the behaviour implemented by Linux. + SACKPermitted: rcvSynOpts.SACKPermitted, + } + // 发送 syn+ack 报文,如果该报文在链路中丢了,没有关系,客户端会重新发送 syn 报文 + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + return nil +} + +// synRcvdState handles a segment received when the TCP 3-way handshake is in +// the SYN-RCVD state. +// 正常情况下,会调用该函数来处理第三次 ack 报文 +func (h *handshake) synRcvdState(s *segment) *tcpip.Error { + if s.flagIsSet(flagRst) { + // RFC 793, page 37, states that in the SYN-RCVD state, a reset + // is acceptable if the sequence number is in the window. + if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + // 如果是syn报文,且序列号对应不上,那么返回 rst + if s.flagIsSet(flagSyn) && s.sequenceNumber != h.ackNum-1 { + // We received two SYN segments with different sequence + // numbers, so we reset this and restart the whole + // process, except that we don't reset the timer. + ack := s.sequenceNumber.Add(s.logicalLen()) + seq := seqnum.Value(0) + if s.flagIsSet(flagAck) { + seq = s.ackNumber + } + h.ep.sendRaw(buffer.VectorisedView{}, flagRst|flagAck, seq, ack, 0) + + if !h.active { + return tcpip.ErrInvalidEndpointState + } + + if err := h.resetState(); err != nil { + return err + } + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: h.ep.sendTSOk, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: h.ep.sackPermitted, + } + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + return nil + } + + // We have previously received (and acknowledged) the peer's SYN. If the + // peer acknowledges our SYN, the handshake is completed. 
+ // 如果是 ack 报文,表明三次握手已经完成 + if s.flagIsSet(flagAck) { + + // If the timestamp option is negotiated and the segment does + // not carry a timestamp option then the segment must be dropped + // as per https://tools.ietf.org/html/rfc7323#section-3.2. + if h.ep.sendTSOk && !s.parsedOptions.TS { + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + + // Update timestamp if required. See RFC7323, section-4.3. + h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) + + h.state = handshakeCompleted + return nil + } + + return nil +} + +// 握手的时候处理tcp段 +func (h *handshake) handleSegment(s *segment) *tcpip.Error { + h.sndWnd = s.window + if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + switch h.state { + case handshakeSynRcvd: + // 正常情况下,服务端接收客户端第三次 ack 报文 + return h.synRcvdState(s) + case handshakeSynSent: + // 客户端发送了syn报文后的处理 + return h.synSentState(s) + } + return nil +} + +// processSegments goes through the segment queue and processes up to +// maxSegmentsPerWake (if they're available). +func (h *handshake) processSegments() *tcpip.Error { + for i := 0; i < maxSegmentsPerWake; i++ { + // 从队列中取出一个tcp段 + s := h.ep.segmentQueue.dequeue() + if s == nil { + return nil + } + + // 处理tcp段 + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + + // We stop processing packets once the handshake is completed, + // otherwise we may process packets meant to be processed by + // the main protocol goroutine. + if h.state == handshakeCompleted { + break + } + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if !h.ep.segmentQueue.empty() { + h.ep.newSegmentWaker.Assert() + } + + return nil +} + +func (h *handshake) resolveRoute() *tcpip.Error { + log.Printf("tcp resolveRoute") + // Set up the wakers. 
+ s := sleep.Sleeper{} + resolutionWaker := &sleep.Waker{} + s.AddWaker(resolutionWaker, wakerForResolution) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + defer s.Done() + + // Initial action is to resolve route. + index := wakerForResolution + for { + switch index { + case wakerForResolution: + if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + // Either success (err == nil) or failure. + return err + } + // Resolution not completed. Keep trying... + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n&notifyClose != 0 { + h.ep.route.RemoveWaker(resolutionWaker) + return tcpip.ErrAborted + } + if n&notifyDrain != 0 { + close(h.ep.drainDone) + <-h.ep.undrain + } + } + + // Wait for notification. + index, _ = s.Fetch(true) + } +} + +// execute executes the TCP 3-way handshake. +// 执行tcp 3次握手,客户端和服务端都是调用该函数来实现三次握手 +/* + c s + | | + sync_sent|------sync---->|sync_rcvd + | | + | | + established|<--sync|ack----| + | | + | | + |------ack----->|established +*/ +func (h *handshake) execute() *tcpip.Error { + // 是否需要拿到下一条地址 + if h.ep.route.IsResolutionRequired() { + if err := h.resolveRoute(); err != nil { + return err + } + } + + // Initialize the resend timer. + // 初始化重传定时器 + resendWaker := sleep.Waker{} + // 设置1s超时 + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, func() { + resendWaker.Assert() + }) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + var sackEnabled SACKEnabled + // 是否开启 sack + if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { + // If stack returned an error when checking for SACKEnabled + // status then just default to switching off SACK negotiation.
+ sackEnabled = false + } + + // Send the initial SYN segment and loop until the handshake is + // completed. + // sync报文的选项参数 + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: true, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: bool(sackEnabled), + } + + // Execute is also called in a listen context so we want to make sure we + // only send the TS/SACK option when we received the TS/SACK in the + // initial SYN. + // 表示服务端收到了syn报文 + if h.state == handshakeSynRcvd { + synOpts.TS = h.ep.sendTSOk + synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + } + // 如果是客户端发送 syn 报文,如果是服务端发送 syn+ack 报文 + + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + // 判断握手是否结束,没有结束则循环 + for h.state != handshakeCompleted { + // 获取事件id + switch index, _ := s.Fetch(true); index { + case wakerForResend: + // 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传 + // 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传 + // 超时时间变为上次的2倍 + timeOut *= 2 + if timeOut > 60*time.Second { + return tcpip.ErrTimeout + } + rt.Reset(timeOut) + // 重新发送syn报文 + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n¬ifyClose != 0 { + return tcpip.ErrAborted + } + if n¬ifyDrain != 0 { + for !h.ep.segmentQueue.empty() { + s := h.ep.segmentQueue.dequeue() + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + if h.state == handshakeCompleted { + return nil + } + } + close(h.ep.drainDone) + <-h.ep.undrain + } + + case wakerForNewSegment: + // 处理握手报文 + if err := h.processSegments(); err != nil { + return err + } + } + } + + return nil +} + +func parseSynSegmentOptions(s *segment) header.TCPSynOptions { + synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck)) + if synOpts.TS { + s.parsedOptions.TSVal = synOpts.TSVal + s.parsedOptions.TSEcr = synOpts.TSEcr + } + return synOpts +} + +var optionPool = sync.Pool{ + New: func() 
interface{} { + return make([]byte, maxOptionSize) + }, +} + +func getOptions() []byte { + return optionPool.Get().([]byte) +} + +func putOptions(options []byte) { + // Reslice to full capacity. + optionPool.Put(options[0:cap(options)]) +} + +// tcp选项的编码 +func makeSynOptions(opts header.TCPSynOptions) []byte { + // Emulate linux option order. This is as follows: + // + // if md5: NOP NOP MD5SIG 18 md5sig(16) + // if mss: MSS 4 mss(2) + // if ts and sack_advertise: + // SACK 2 TIMESTAMP 2 timestamp(8) + // elif ts: NOP NOP TIMESTAMP 10 timestamp(8) + // elif sack: NOP NOP SACK 2 + // if wscale: NOP WINDOW 3 ws(1) + // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) + // [for each block] start_seq(4) end_seq(4) + // if fastopen_cookie: + // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) + // else: FASTOPEN (2 + len(cookie)) + // cookie(variable) [padding to four bytes] + // + options := getOptions() + + // Always encode the mss. + offset := header.EncodeMSSOption(uint32(opts.MSS), options) + + // Special ordering is required here. If both TS and SACK are enabled, + // then the SACK option precedes TS, with no padding. If they are + // enabled individually, then we see padding before the option. + if opts.TS && opts.SACKPermitted { + offset += header.EncodeSACKPermittedOption(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.TS { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.SACKPermitted { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKPermittedOption(options[offset:]) + } + + // Initialize the WS option. 
+ if opts.WS >= 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeWSOption(opts.WS, options[offset:]) + } + + // Padding to the end; note that this never apply unless we add a + // fastopen option, we always expect the offset to remain the same. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +// 封装 sendTCP ,发送 syn 报文 +func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { + // The MSS in opts is automatically calculated as this function is + // called from many places and we don't want every call point being + // embedded with the MSS calculation. + + if opts.MSS == 0 { + opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) + } + + options := makeSynOptions(opts) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) + putOptions(options) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +// 发送一个tcp段数据,封装 tcp 首部,并写入网路层 +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { + optLen := len(opts) + // Allocate a buffer for the TCP header. + hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) + + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + // Initialize the header. 
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + tcp.Encode(&header.TCPFields{ + SrcPort: id.LocalPort, + DstPort: id.RemotePort, + SeqNum: uint32(seq), + AckNum: uint32(ack), + DataOffset: uint8(header.TCPMinimumSize + optLen), + Flags: flags, + WindowSize: uint16(rcvWnd), + }) + copy(tcp[header.TCPMinimumSize:], opts) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + length := uint16(hdr.UsedLength() + data.Size()) + // tcp伪首部校验和的计算 + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + + // tcp的可靠性:校验和的计算,用于检测损伤的报文段 + tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) + } + + r.Stats().TCP.SegmentsSent.Increment() + if (flags & flagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + + log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d", + flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), + seq, ack, rcvWnd) + + return r.WritePacket(hdr, data, ProtocolNumber, ttl) +} + +// makeOptions makes an options slice. +func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { + options := getOptions() + offset := 0 + + // N.B. the ordering here matches the ordering used by Linux internally + // and described in the raw makeOptions function. We don't include + // unnecessary cases here (post connection.) + if e.sendTSOk { + // Embed the timestamp if timestamp has been enabled. + // + // We only use the lower 32 bits of the unix time in + // milliseconds. This is similar to what Linux does where it + // uses the lower 32 bits of the jiffies value in the tsVal + // field of the timestamp option. + // + // Further, RFC7323 section-5.4 recommends millisecond + // resolution as the lowest recommended resolution for the + // timestamp clock. + // + // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. 
+ offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + } + if e.sackPermitted && len(sackBlocks) > 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) + } + + // We expect the above to produce an aligned offset. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +// sendRaw sends a TCP segment to the endpoint's peer. +// 发送一个tcp段给对端 +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { + var sackBlocks []header.SACKBlock + if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&flagAck != 0) { + sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] + } + options := e.makeOptions(sackBlocks) + err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options) + putOptions(options) + return err +} + +// 从发送队列中取出数据并发送出去 +func (e *endpoint) handleWrite() *tcpip.Error { + // Move packets from send queue to send list. The queue is accessible + // from other goroutines and protected by the send mutex, while the send + // list is only accessible from the handler goroutine, so it needs no + // mutexes. + e.sndBufMu.Lock() + + // 得到第一个tcp段 + first := e.sndQueue.Front() + if first != nil { + e.snd.writeList.PushBackList(&e.sndQueue) + e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) + e.sndBufInQueue = 0 + } + + e.sndBufMu.Unlock() + + // Initialize the next segment to write if it's currently nil. + if e.snd.writeNext == nil { + e.snd.writeNext = first + } + + // Push out any new packets. 
+ // 将数据发送出去 + e.snd.sendData() + + return nil +} + +// 关闭连接的处理,最终会调用 sendData 来发送 fin 包 +func (e *endpoint) handleClose() *tcpip.Error { + // Drain the send queue. + e.handleWrite() + + // Mark send side as closed. + // 标记发送器关闭 + e.snd.closed = true + + return nil +} + +// resetConnectionLocked sends a RST segment and puts the endpoint in an error +// state with the given error code. This method must only be called from the +// protocol goroutine. +// resetConnectionLocked 发送一个RST段,并将端点置于具有给定错误代码的错误状态。 +// 只能从协议goroutine中调用此方法。 +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { + e.sendRaw(buffer.VectorisedView{}, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + + e.state = stateError + e.hardError = err +} + +// completeWorkerLocked is called by the worker goroutine when it's about to +// exit. It marks the worker as completed and performs cleanup work if requested +// by Close(). +func (e *endpoint) completeWorkerLocked() { + e.workerRunning = false + if e.workerCleanup { + e.cleanupLocked() + } +} + +// handleSegments pulls segments from the queue and processes them. It returns +// no error if the protocol loop should continue, an error otherwise. +// handleSegments 从队列中取出 tcp 段数据,然后处理它们。 +func (e *endpoint) handleSegments() *tcpip.Error { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(flagRst) { + // 如果收到 rst 报文 + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + return tcpip.ErrConnectionReset + } + } else if s.flagIsSet(flagAck) { + // 处理正常的报文 + // Patch the window size in the segment according to the + // send window scale. 
+ s.window <<= e.snd.sndWndScale + + // If the timestamp option is negotiated and the segment + // does not carry a timestamp option then the segment + // must be dropped as per + // https://tools.ietf.org/html/rfc7323#section-3.2. + if e.sendTSOk && !s.parsedOptions.TS { + e.stack.Stats().DroppedPackets.Increment() + s.decRef() + continue + } + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + // 处理tcp数据段,同时给接收器和发送器 + // 为何要给发送器传接收到的数据段呢?主要是为了滑动窗口的滑动和拥塞控制处理 + e.rcv.handleRcvdSegment(s) + e.snd.handleRcvdSegment(s) + } + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + + // Send an ACK for all processed packets if needed. + // tcp可靠性:累积确认 + // 如果发送的最大ack不等于下一个接收的序列号,发送ack + if e.rcv.rcvNxt != e.snd.maxSentAck { + e.snd.sendAck() + } + + e.resetKeepaliveTimer(true) + + return nil +} + +// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP +// keepalive packets periodically when the connection is idle. If we don't hear +// from the other side after a number of tries, we terminate the connection. +func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.keepalive.Lock() + if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + e.keepalive.Unlock() + return nil + } + + if e.keepalive.unacked >= e.keepalive.count { + e.keepalive.Unlock() + return tcpip.ErrConnectionReset + } + + // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with + // seg.seq = snd.nxt-1. + e.keepalive.unacked++ + e.keepalive.Unlock() + e.snd.sendSegment(buffer.VectorisedView{}, flagAck, e.snd.sndNxt-1) + e.resetKeepaliveTimer(false) + return nil +} + +// resetKeepaliveTimer restarts or stops the keepalive timer, depending on +// whether it is enabled for this endpoint. 
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) { + e.keepalive.Lock() + defer e.keepalive.Unlock() + if receivedData { + e.keepalive.unacked = 0 + } + // Start the keepalive timer IFF it's enabled and there is no pending + // data to send. + if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + e.keepalive.timer.disable() + return + } + if e.keepalive.unacked > 0 { + e.keepalive.timer.enable(e.keepalive.interval) + } else { + e.keepalive.timer.enable(e.keepalive.idle) + } +} + +// disableKeepaliveTimer stops the keepalive timer. +func (e *endpoint) disableKeepaliveTimer() { + e.keepalive.Lock() + e.keepalive.timer.disable() + e.keepalive.Unlock() +} + +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +// protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段。 +func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + var closeTimer *time.Timer + var closeWaker sleep.Waker + + // 收尾的一些工作 + epilogue := func() { + // e.mu is expected to be hold upon entering this section. + + if e.snd != nil { + e.snd.resendTimer.cleanup() + } + + if closeTimer != nil { + closeTimer.Stop() + } + + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + + e.mu.Unlock() + + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + + // 如果需要三次握手 + if handshake { + // This is an active connection, so we must initiate the 3-way + // handshake, and then inform potential waiters about its + // completion. 
+ h, err := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + if err == nil { + // 执行握手 + err = h.execute() + } + // 处理握手有错 + if err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.mu.Lock() + e.state = stateError + e.hardError = err + // Lock released below. + epilogue() + + return err + } + + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + // 到这里就表示三次握手已经成功了,那么初始化发送器和接收器 + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + e.rcvListMu.Unlock() + } + + e.keepalive.timer.init(&e.keepalive.waker) + defer e.keepalive.timer.cleanup() + + // Tell waiters that the endpoint is connected and writable. + e.mu.Lock() + e.state = stateConnected + drained := e.drainDone != nil + e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } + + e.waiterQueue.Notify(waiter.EventOut) + + // Set up the functions that will be called when the main protocol loop + // wakes up. 
+ // 触发器的事件,这些函数很重要 + funcs := []struct { + w *sleep.Waker + f func() *tcpip.Error + }{ + { + w: &e.sndWaker, + f: e.handleWrite, + }, + { + w: &e.sndCloseWaker, + f: e.handleClose, + }, + { + w: &e.newSegmentWaker, + f: e.handleSegments, + }, + { + w: &closeWaker, + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted + }, + }, + { + w: &e.snd.resendWaker, + f: func() *tcpip.Error { + // 如果重传触发了,表示在rto时间内没有收到ack包 + // 也就是说假设它丢包了 + if !e.snd.retransmitTimerExpired() { + return tcpip.ErrTimeout + } + return nil + }, + }, + { + w: &e.keepalive.waker, + f: e.keepaliveTimerExpired, + }, + { + w: &e.notificationWaker, + f: func() *tcpip.Error { + n := e.fetchNotifications() + if n¬ifyNonZeroReceiveWindow != 0 { + e.rcv.nonZeroWindow() + } + + if n¬ifyReceiveWindowChanged != 0 { + e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) + } + + if n¬ifyMTUChanged != 0 { + e.sndBufMu.Lock() + count := e.packetTooBigCount + e.packetTooBigCount = 0 + mtu := e.sndMTU + e.sndBufMu.Unlock() + + e.snd.updateMaxPayloadSize(mtu, count) + } + + if n¬ifyReset != 0 { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + } + if n¬ifyClose != 0 && closeTimer == nil { + // Reset the connection 3 seconds after the + // endpoint has been closed. + closeTimer = time.AfterFunc(3*time.Second, func() { + closeWaker.Assert() + }) + } + + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + if err := e.handleSegments(); err != nil { + return err + } + } + close(e.drainDone) + <-e.undrain + } + + if n¬ifyKeepaliveChanged != 0 { + e.resetKeepaliveTimer(true) + } + + return nil + }, + }, + } + + // Initialize the sleeper based on the wakers in funcs. + s := sleep.Sleeper{} + for i := range funcs { + s.AddWaker(funcs[i].w, i) + } + + // The following assertions and notifications are needed for restored + // endpoints. Fresh newly created endpoints have empty states and should + // not invoke any. 
+ // 恢复的端点需要以下断言和通知。新创建的新端点具有空状态,不应调用任何端点。 + e.segmentQueue.mu.Lock() + if !e.segmentQueue.list.Empty() { + e.newSegmentWaker.Assert() + } + e.segmentQueue.mu.Unlock() + + e.rcvListMu.Lock() + if !e.rcvList.Empty() { + e.waiterQueue.Notify(waiter.EventIn) + } + e.rcvListMu.Unlock() + + e.mu.RLock() + if e.workerCleanup { + e.notifyProtocolGoroutine(notifyClose) + } + e.mu.RUnlock() + + // Main loop. Handle segments until both send and receive ends of the + // connection have completed. + // 主循环,处理tcp报文 + // 要使这个主循环结束,也就是tcp连接完全关闭,得同时满足三个条件: + // 1,接收器关闭了 2,发送器关闭了 3,下一个未确认的序列号等于添加到发送列表的下一个段的序列号 + for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + e.workMu.Unlock() + // s.Fetch 会返回事件的index,比如 v=0 的话, + // funcs[v].f()就是调用 e.handleWrite + // 所以这里的函数应该尽量不阻塞,否则会影响其他事件的接收 + v, _ := s.Fetch(true) + e.workMu.Lock() + if err := funcs[v].f(); err != nil { + e.mu.Lock() + e.resetConnectionLocked(err) + // Lock released below. + epilogue() + + return nil + } + } + + // Mark endpoint as closed. + e.mu.Lock() + if e.state != stateError { + e.state = stateClosed + } + // Lock released below. + epilogue() + + return nil +} diff --git a/netstack/tcpip/transport/tcp/cubic.go b/netstack/tcpip/transport/tcp/cubic.go new file mode 100644 index 0000000..4ce2af8 --- /dev/null +++ b/netstack/tcpip/transport/tcp/cubic.go @@ -0,0 +1,255 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
package tcp

import (
	"math"
	"time"
)

// cubicState stores the variables related to the state of the TCP CUBIC
// congestion control algorithm. See: https://tools.ietf.org/html/rfc8312.
type cubicState struct {
	// wLastMax is the previous wMax value, i.e. the congestion window
	// maximum recorded before the most recent one.
	wLastMax float64

	// wMax is the value of the congestion window at the
	// time of last congestion event.
	wMax float64

	// t denotes the time when the current congestion avoidance
	// was entered.
	t time.Time

	// numCongestionEvents tracks the number of congestion events since last
	// RTO.
	numCongestionEvents int

	// c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
	// per RFC. It is the coefficient of the cubic term in Eq. 1 below.
	c float64

	// k is the time period that the above function takes to increase the
	// current window size to W_max if there are no further congestion
	// events and is calculated using the following equation:
	//
	// K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
	k float64

	// beta is the CUBIC multiplication decrease factor. that is, when a
	// congestion event is detected, CUBIC reduces its cwnd to
	// W_cubic(0)=W_max*beta_cubic.
	beta float64

	// wC is window computed by CUBIC at time t. It's calculated using the
	// formula:
	//
	// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
	wC float64

	// wEst is the window computed by CUBIC at time t+RTT i.e
	// W_cubic(t+RTT).
	// NOTE(review): getCwnd stores the TCP-friendly window estimate in this
	// field, which does not match the description above; confirm which one
	// is intended.
	wEst float64

	// s is the sender whose congestion window (sndCwnd) and slow-start
	// threshold (sndSsthresh) this state reads and updates.
	s *sender
}

// newCubicCC returns a partially initialized cubic state with the constants
// beta and c set and t set to current time.
+// newCubicCC 返回部分初始化的 cubic 状态,常量为beta和c,t为当前时间。 +func newCubicCC(s *sender) *cubicState { + return &cubicState{ + t: time.Now(), + beta: 0.7, + c: 0.4, + s: s, + } +} + +// enterCongestionAvoidance is used to initialize cubic in cases where we exit +// SlowStart without a real congestion event taking place. This can happen when +// a connection goes back to slow start due to a retransmit and we exceed the +// previously lowered ssThresh without experiencing packet loss. +// +// Refer: https://tools.ietf.org/html/rfc8312#section-4.8 +// enterCongestionAvoidance 用于在我们退出 SlowStart 而没有发生真正的拥塞事件的情况下初始化 cubic。 +// 当连接由于重新传输而返回慢启动时会发生这种情况,并且我们超过先前降低的ssThresh而不会遇到丢包。 +func (c *cubicState) enterCongestionAvoidance() { + // See: https://tools.ietf.org/html/rfc8312#section-4.7 & + // https://tools.ietf.org/html/rfc8312#section-4.8 + // 初次进入拥塞避免,只记录当前拥塞避免的时间点,当前的窗口wMax=cwnd + if c.numCongestionEvents == 0 { + c.k = 0 + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + } +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window we cross +// the ssThresh then it will return the number of packets that must be consumed +// in congestion avoidance mode. +// updateSlowStart 将根据NewReno使用的慢启动算法更新拥塞窗口。 +// 如果在调整拥塞窗口之后我们越过ssThresh,那么它将返回在拥塞避免模式下必须消耗的数据包的数量。 +func (c *cubicState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := c.s.sndCwnd + packetsAcked + enterCA := false + if newcwnd >= c.s.sndSsthresh { + newcwnd = c.s.sndSsthresh + c.s.sndCAAckCount = 0 + enterCA = true + } + + packetsAcked -= newcwnd - c.s.sndCwnd + c.s.sndCwnd = newcwnd + if enterCA { + // 进入拥塞避免 + c.enterCongestionAvoidance() + } + return packetsAcked +} + +// Update updates cubic's internal state variables. It must be called on every +// ACK received. 
+// Refer: https://tools.ietf.org/html/rfc8312#section-4 +// Update 更新 cubic 内部状态变量,每次收到ack都必须调用 +func (c *cubicState) Update(packetsAcked int) { + if c.s.sndCwnd < c.s.sndSsthresh { + packetsAcked = c.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } else { + // 拥塞避免阶段 + c.s.rtt.Lock() + srtt := c.s.rtt.srtt + c.s.rtt.Unlock() + c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + } +} + +// cubicCwnd computes the CUBIC congestion window after t seconds from last +// congestion event. +// cubicCwnd 在上次拥塞事件发生t秒后计算CUBIC拥塞窗口。 +func (c *cubicState) cubicCwnd(t float64) float64 { + return c.c*math.Pow(t, 3.0) + c.wMax +} + +// getCwnd returns the current congestion window as computed by CUBIC. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +// getCwnd 返回由CUBIC计算的当前拥塞窗口。 +func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { + elapsed := time.Since(c.t).Seconds() + + // Compute the window as per Cubic after 'elapsed' time + // since last congestion event. + c.wC = c.cubicCwnd(elapsed - c.k) + + // Compute the TCP friendly estimate of the congestion window. + c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + + // Make sure in the TCP friendly region CUBIC performs at least + // as well as Reno. + if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + // TCP Friendly region of cubic. + return int(c.wEst) + } + + // In Concave/Convex region of CUBIC, calculate what CUBIC window + // will be after 1 RTT and use that to grow congestion window + // for every ack. + tEst := (time.Since(c.t) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.k) + // As per 4.3 for each received ACK cwnd must be incremented + // by (w_cubic(t+RTT) - cwnd/cwnd. + cwnd := float64(sndCwnd) + for i := 0; i < packetsAcked; i++ { + // Concave/Convex regions of cubic have the same formulas. 
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3 + cwnd += (wtRtt - cwnd) / cwnd + } + return int(cwnd) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +// 收到三次重复ack,调用 HandleNDupAcks +func (c *cubicState) HandleNDupAcks() { + // See: https://tools.ietf.org/html/rfc8312#section-4.5 + c.numCongestionEvents++ + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + c.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionContrl.HandleRTOExpired. +// 发生重传时调用 HandleRTOExpired。 +func (c *cubicState) HandleRTOExpired() { + // See: https://tools.ietf.org/html/rfc8312#section-4.6 + c.t = time.Now() + c.numCongestionEvents = 0 + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + + // We lost a packet, so reduce ssthresh. + c.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + c.s.sndCwnd = 1 +} + +// fastConvergence implements the logic for Fast Convergence algorithm as +// described in https://tools.ietf.org/html/rfc8312#section-4.6. +// 快速收敛 +func (c *cubicState) fastConvergence() { + if c.wMax < c.wLastMax { + c.wLastMax = c.wMax + c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + } else { + c.wLastMax = c.wMax + } + // Recompute k as wMax may have changed. + c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) +} + +// PostRecovery implemements congestionControl.PostRecovery. +// 更新t为当前的时间,当发送方退出快速恢复阶段时,将调用 PostRecovery +func (c *cubicState) PostRecovery() { + c.t = time.Now() +} + +// reduceSlowStartThreshold returns new SsThresh as described in +// https://tools.ietf.org/html/rfc8312#section-4.7. 
+// 按 cubic 的算法更新慢启动和拥塞避免之间的阈值 +func (c *cubicState) reduceSlowStartThreshold() { + c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) +} diff --git a/netstack/tcpip/transport/tcp/dual_stack_test.go b/netstack/tcpip/transport/tcp/dual_stack_test.go new file mode 100644 index 0000000..51d128e --- /dev/null +++ b/netstack/tcpip/transport/tcp/dual_stack_test.go @@ -0,0 +1,560 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/checker" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/network/ipv4" + "tcpip/netstack/tcpip/seqnum" + "tcpip/netstack/tcpip/transport/tcp" + "tcpip/netstack/tcpip/transport/tcp/testing/context" + "tcpip/netstack/waiter" +) + +func TestV4MappedConnectOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Start connection attempt, it must fail. + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrNoRoute { + t.Fatalf("Unexpected return value from Connect: %v", err) + } +} + +func testV4Connect(t *testing.T, c *context.Context) { + // Start connection attempt. 
+ we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV4MappedConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. 
+ testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func testV6Connect(t *testing.T, c *context.Context) { + // Start connection attempt to IPv6 address. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv6(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. 
+ select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV6Connect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to local address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV4RefuseOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. 
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func testV4Accept(t *testing.T, c *context.Context) { + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv4(t, b, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1), + ), + ) + + // Send ACK. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Make sure we get the same error when calling the original ep and the + // new one. This validates that v4-mapped endpoints are still able to + // query the V6Only flag, whereas pure v4 endpoints are not. + var v tcpip.V6OnlyOption + expected := c.EP.GetSockOpt(&v) + if err := nep.GetSockOpt(&v); err != expected { + t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) + } + + // Check the peer address. + addr, err := nep.GetRemoteAddress() + if err != nil { + t.Fatalf("GetRemoteAddress failed failed: %v", err) + } + + if addr.Addr != context.TestAddr { + t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) + } +} + +func TestV4AcceptOnV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped wildcard. 
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV6AcceptOnV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv6(t, b, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1), + ), + ) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. 
+ select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Make sure we can still query the v6 only status of the new endpoint, + // that is, that it is in fact a v6 socket. + var v tcpip.V6OnlyOption + if err := nep.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt failed failed: %v", err) + } + + // Check the peer address. + addr, err := nep.GetRemoteAddress() + if err != nil { + t.Fatalf("GetRemoteAddress failed failed: %v", err) + } + + if addr.Addr != context.TestV6Addr { + t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + } +} + +func TestV4AcceptOnV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} diff --git a/netstack/tcpip/transport/tcp/endpoint.go b/netstack/tcpip/transport/tcp/endpoint.go new file mode 100644 index 0000000..eedf24c --- /dev/null +++ b/netstack/tcpip/transport/tcp/endpoint.go @@ -0,0 +1,1637 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// SACKInfo holds TCP SACK related information for a given endpoint.
//
// +stateify savable
type SACKInfo struct {
	// Blocks is the fixed-size storage for the SACK blocks tracked per
	// endpoint; its capacity, MaxSACKBlocks, is the maximum number of
	// blocks that can be tracked.
	Blocks [MaxSACKBlocks]header.SACKBlock

	// NumBlocks is the number of valid SACK blocks currently stored in
	// the Blocks array above.
	NumBlocks int
}
+ workMu tmutex.Mutex + + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + + // lastError represents the last error that the endpoint reported; + // access to it is protected by the following mutex. + lastErrorMu sync.Mutex + lastError *tcpip.Error + + // The following fields are used to manage the receive queue. The + // protocol goroutine adds ready-for-delivery segments to rcvList, + // which are returned by Read() calls to users. + // + // Once the peer has closed its send side, rcvClosed is set to true + // to indicate to users that no more data is coming. + // + // rcvListMu can be taken after the endpoint mu below. + rcvListMu sync.Mutex + rcvList segmentList + rcvClosed bool + rcvBufSize int + rcvBufUsed int + + // The following fields are protected by the mutex. + mu sync.RWMutex + id stack.TransportEndpointID + state endpointState + isPortReserved bool + isRegistered bool + boundNICID tcpip.NICID + route stack.Route + v6only bool + isConnectNotified bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber + + // hardError is meaningful only when state is stateError, it stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by mu. + hardError *tcpip.Error + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // workerCleanup specifies if the worker goroutine must perform cleanup + // before exitting. 
This can only be set to true when workerRunning is + // also true, and they're both protected by the mutex. + workerCleanup bool + + // sendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1 + sendTSOk bool + + // recentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + recentTS uint32 + + // tsOffset is a randomized offset added to the value of the + // TSVal field in the timestamp option. + tsOffset uint32 + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // sackPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + sackPermitted bool + + // sack holds TCP SACK related information for this endpoint. + sack SACKInfo + + // The options below aren't implemented, but we remember the user + // settings because applications expect to be able to set/query these + // options. + noDelay bool + reuseAddr bool + + // segmentQueue is used to hand received segments to the protocol + // goroutine. Segments are queued as long as the queue is not full, + // and dropped when it is. + segmentQueue segmentQueue + + // The following fields are used to manage the send buffer. When + // segments are ready to be sent, they are added to sndQueue and the + // protocol goroutine is signaled via sndWaker. + // + // When the send side is closed, the protocol goroutine is notified via + // sndCloseWaker, and sndClosed is set to true. + sndBufMu sync.Mutex + sndBufSize int + sndBufUsed int + sndClosed bool + sndBufInQueue seqnum.Size + sndQueue segmentList + sndWaker sleep.Waker + sndCloseWaker sleep.Waker + + // cc stores the name of the Congestion Control algorithm to use for + // this endpoint. 
+ cc CongestionControlOption + + // The following are used when a "packet too big" control packet is + // received. They are protected by sndBufMu. They are used to + // communicate to the main protocol goroutine how many such control + // messages have been received since the last notification was processed + // and what was the smallest MTU seen. + packetTooBigCount int + sndMTU int + + // newSegmentWaker is used to indicate to the protocol goroutine that + // it needs to wake up and handle new segments queued to it. + newSegmentWaker sleep.Waker + + // notificationWaker is used to indicate to the protocol goroutine that + // it needs to wake up and check for notifications. + notificationWaker sleep.Waker + + // notifyFlags is a bitmask of flags used to indicate to the protocol + // goroutine what it was notified; this is only accessed atomically. + notifyFlags uint32 + + // keepalive manages TCP keepalive state. When the connection is idle + // (no data sent or received) for keepaliveIdle, we start sending + // keepalives every keepalive.interval. If we send keepalive.count + // without hearing a response, the connection is closed. + keepalive keepalive + + // acceptedChan is used by a listening endpoint protocol goroutine to + // send newly accepted connections to the endpoint so that they can be + // read by Accept() calls. + acceptedChan chan *endpoint + + // The following are only used from the protocol goroutine, and + // therefore don't need locks to protect them. + rcv *receiver + snd *sender + + // The goroutine drain completion notification channel. + drainDone chan struct{} + + // The goroutine undrain notification channel. + undrain chan struct{} + + // probe if not nil is invoked on every received segment. It is passed + // a copy of the current state of the endpoint. + probe stack.TCPProbeFunc + + // The following are only used to assist the restore run to re-connect. 
+ bindAddress tcpip.Address + connectingAddress tcpip.Address +} + +// keepalive is a synchronization wrapper used to appease stateify. See the +// comment in endpoint, where it is used. +// +// +stateify savable +type keepalive struct { + sync.Mutex + enabled bool + idle time.Duration + interval time.Duration + count int + unacked int + timer timer + waker sleep.Waker +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSize: DefaultBufferSize, + sndBufSize: DefaultBufferSize, + sndMTU: int(math.MaxInt32), + noDelay: false, + reuseAddr: true, + keepalive: keepalive{ + // Linux defaults. + idle: 2 * time.Hour, + interval: 75 * time.Second, + count: 9, + }, + } + + var ss SendBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + e.sndBufSize = ss.Default + } + + var rs ReceiveBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + e.rcvBufSize = rs.Default + } + + var cs CongestionControlOption + if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + e.cc = cs + } + + if p := stack.GetTCPProbe(); p != nil { + e.probe = p + } + + e.segmentQueue.setLimit(2 * e.rcvBufSize) + e.workMu.Init() + e.workMu.Lock() + e.tsOffset = timeStampOffset() + return e +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +// Readiness 返回端点的当前读数据的准备情况。例如,如果设置了waiter.EventIn,则端点可立即读取。 +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + result := waiter.EventMask(0) + + e.mu.RLock() + defer e.mu.RUnlock() + + switch e.state { + case stateInitial, stateBound, stateConnecting: + // Ready for nothing. + + case stateClosed, stateError: + // Ready for anything. 
+ result = mask + + case stateListen: + // Check if there's anything in the accepted channel. + if (mask & waiter.EventIn) != 0 { + if len(e.acceptedChan) > 0 { + result |= waiter.EventIn + } + } + + case stateConnected: + // Determine if the endpoint is writable if requested. + if (mask & waiter.EventOut) != 0 { + e.sndBufMu.Lock() + if e.sndClosed || e.sndBufUsed < e.sndBufSize { + result |= waiter.EventOut + } + e.sndBufMu.Unlock() + } + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvListMu.Lock() + if e.rcvBufUsed > 0 || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvListMu.Unlock() + } + } + + return result +} + +func (e *endpoint) fetchNotifications() uint32 { + return atomic.SwapUint32(&e.notifyFlags, 0) +} + +func (e *endpoint) notifyProtocolGoroutine(n uint32) { + for { + v := atomic.LoadUint32(&e.notifyFlags) + if v&n == n { + // The flags are already set. + return + } + + if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) { + if v == 0 { + // We are causing a transition from no flags to + // at least one flag set, so we must cause the + // protocol goroutine to wake up. + e.notificationWaker.Assert() + } + return + } + } +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. It must be called only once and with no other concurrent calls to +// the endpoint. +func (e *endpoint) Close() { + // Issue a shutdown so that the peer knows we won't send any more data + // if we're connected, or stop accepting if we're listening. + e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + + e.mu.Lock() + + // We always release ports inline so that they are immediately available + // for reuse after Close() is called. If also registered, it means this + // is a listening socket, so we must unregister as well otherwise the + // next user would fail in Listen() when trying to register. 
+ if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.isPortReserved = false + + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.isRegistered = false + } + } + + // Either perform the local cleanup or kick the worker to make sure it + // knows it needs to cleanup. + tcpip.AddDanglingEndpoint(e) + if !e.workerRunning { + e.cleanupLocked() + } else { + e.workerCleanup = true + e.notifyProtocolGoroutine(notifyClose) + } + + e.mu.Unlock() +} + +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by + // the client. + if e.acceptedChan != nil { + close(e.acceptedChan) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + e.acceptedChan = nil + } + e.workerCleanup = false + + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + e.route.Release() + tcpip.DeleteDanglingEndpoint(e) +} + +// Read reads data from the endpoint. +// 从tcp的接收队列中读取数据 +func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.mu.RLock() + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. Also note that a RST being received + // would cause the state to become stateError so we should allow the + // reads to proceed before returning a ECONNRESET. 
+ e.rcvListMu.Lock() + bufUsed := e.rcvBufUsed + if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + e.rcvListMu.Unlock() + he := e.hardError + e.mu.RUnlock() + if s == stateError { + return buffer.View{}, tcpip.ControlMessages{}, he + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + } + + v, err := e.readLocked() + e.rcvListMu.Unlock() + + e.mu.RUnlock() + + return v, tcpip.ControlMessages{}, err +} + +// 从tcp的接收队列中读取数据,并从接收队列中删除已读数据 +func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return buffer.View{}, tcpip.ErrClosedForReceive + } + return buffer.View{}, tcpip.ErrWouldBlock + } + + s := e.rcvList.Front() + views := s.data.Views() + v := views[s.viewToDeliver] + s.viewToDeliver++ + + if s.viewToDeliver >= len(views) { + e.rcvList.Remove(s) + s.decRef() + } + + scale := e.rcv.rcvWndScale + // tcp流量控制:检测接收窗口是否为0 + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufUsed -= len(v) + // 检测糊涂窗口,主动发送窗口不为0的通告给对方 + if wasZero && !e.zeroReceiveWindow(scale) { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + return v, nil +} + +// Write writes data to the endpoint's peer. +// 接收上层的数,通过tcp连接发送到对端 +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint cannot be written to if it's not connected. + // 判断tcp状态,必须已经建立了连接才能发送数据 + if e.state != stateConnected { + switch e.state { + case stateError: + return 0, nil, e.hardError + default: + return 0, nil, tcpip.ErrClosedForSend + } + } + + // Nothing to do if the buffer is empty. 
+ // 检查负载的长度,如果为0,直接返回 + if p.Size() == 0 { + return 0, nil, nil + } + + e.sndBufMu.Lock() + + // Check if the connection has already been closed for sends. + if e.sndClosed { + e.sndBufMu.Unlock() + return 0, nil, tcpip.ErrClosedForSend + } + + // Check against the limit. + // tcp流量控制:未被占用发送缓存还剩多少,如果发送缓存已经被用光了,返回 ErrWouldBlock + avail := e.sndBufSize - e.sndBufUsed + if avail <= 0 { + e.sndBufMu.Unlock() + return 0, nil, tcpip.ErrWouldBlock + } + + v, perr := p.Get(avail) + if perr != nil { + e.sndBufMu.Unlock() + return 0, nil, perr + } + + var err *tcpip.Error + if p.Size() > avail { + err = tcpip.ErrWouldBlock + } + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) + + // Add data to the send queue. + // 插入发送队列 + e.sndBufUsed += l + e.sndBufInQueue += seqnum.Size(l) + e.sndQueue.PushBack(s) + + e.sndBufMu.Unlock() + + // 发送数据,最终会调用 sender sendData 来发送数据。 + if e.workMu.TryLock() { + // Do the work inline. + e.handleWrite() + e.workMu.Unlock() + } else { + // Let the protocol goroutine do the work. + e.sndWaker.Assert() + } + return uintptr(l), nil, err +} + +// Peek reads data without consuming it from the endpoint. +// +// This method does not block if there is no data pending. +func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. + if s := e.state; s != stateConnected && s != stateClosed { + if s == stateError { + return 0, tcpip.ControlMessages{}, e.hardError + } + return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + } + return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + + // Make a copy of vec so we can modify the slide headers. 
+ vec = append([][]byte(nil), vec...) + + var num uintptr + + for s := e.rcvList.Front(); s != nil; s = s.Next() { + views := s.data.Views() + + for i := s.viewToDeliver; i < len(views); i++ { + v := views[i] + + for len(v) > 0 { + if len(vec) == 0 { + return num, tcpip.ControlMessages{}, nil + } + if len(vec[0]) == 0 { + vec = vec[1:] + continue + } + + n := copy(vec[0], v) + v = v[n:] + vec[0] = vec[0][n:] + num += uintptr(n) + } + } + } + + return num, tcpip.ControlMessages{}, nil +} + +// zeroReceiveWindow checks if the receive window to be announced now would be +// zero, based on the amount of available buffer and the receive window scaling. +// +// It must be called with rcvListMu held. +// zeroReceiveWindow 根据可用缓冲区的数量和接收窗口缩放,检查现在要宣布的接收窗口是否为零。 +func (e *endpoint) zeroReceiveWindow(scale uint8) bool { + if e.rcvBufUsed >= e.rcvBufSize { + return true + } + + return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.NoDelayOption: + e.mu.Lock() + e.noDelay = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs ReceiveBufferSizeOption + size := int(v) + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if size < rs.Min { + size = rs.Min + } + if size > rs.Max { + size = rs.Max + } + } + + mask := uint32(notifyReceiveWindowChanged) + + e.rcvListMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.rcvWndScale + } + if size>>scale == 0 { + size = 1 << scale + } + + // Make sure 2*size doesn't overflow. 
+ if size > math.MaxInt32/2 { + size = math.MaxInt32 / 2 + } + + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufSize = size + if wasZero && !e.zeroReceiveWindow(scale) { + mask |= notifyNonZeroReceiveWindow + } + e.rcvListMu.Unlock() + + e.segmentQueue.setLimit(2 * size) + + e.notifyProtocolGoroutine(mask) + return nil + + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + size := int(v) + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if size < ss.Min { + size = ss.Min + } + if size > ss.Max { + size = ss.Max + } + } + + e.sndBufMu.Lock() + e.sndBufSize = size + e.sndBufMu.Unlock() + + return nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v != 0 + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + e.keepalive.idle = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + e.keepalive.interval = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + + } + + return nil +} + +// readyReceiveSize returns the number of bytes ready to be received. 
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// The endpoint cannot be in listen state.
+	if e.state == stateListen {
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	e.rcvListMu.Lock()
+	defer e.rcvListMu.Unlock()
+
+	return e.rcvBufUsed, nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+//
+// opt selects the option by its dynamic type. Pointer-typed options are
+// filled in through the pointer and nil is returned. An option type that
+// is not handled below yields tcpip.ErrUnknownProtocolOption.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+	switch o := opt.(type) {
+	case tcpip.ErrorOption:
+		// Reading the stored error also clears it.
+		e.lastErrorMu.Lock()
+		err := e.lastError
+		e.lastError = nil
+		e.lastErrorMu.Unlock()
+		return err
+
+	case *tcpip.SendBufferSizeOption:
+		e.sndBufMu.Lock()
+		*o = tcpip.SendBufferSizeOption(e.sndBufSize)
+		e.sndBufMu.Unlock()
+		return nil
+
+	case *tcpip.ReceiveBufferSizeOption:
+		e.rcvListMu.Lock()
+		*o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
+		e.rcvListMu.Unlock()
+		return nil
+
+	case *tcpip.ReceiveQueueSizeOption:
+		v, err := e.readyReceiveSize()
+		if err != nil {
+			return err
+		}
+
+		*o = tcpip.ReceiveQueueSizeOption(v)
+		return nil
+
+	case *tcpip.NoDelayOption:
+		e.mu.RLock()
+		v := e.noDelay
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.ReuseAddressOption:
+		e.mu.RLock()
+		v := e.reuseAddr
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.V6OnlyOption:
+		// We only recognize this option on v6 endpoints.
+		if e.netProto != header.IPv6ProtocolNumber {
+			return tcpip.ErrUnknownProtocolOption
+		}
+
+		// Only a read is needed, so take the read lock like the other
+		// boolean options above.
+		e.mu.RLock()
+		v := e.v6only
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.TCPInfoOption:
+		*o = tcpip.TCPInfoOption{}
+		e.mu.RLock()
+		snd := e.snd
+		e.mu.RUnlock()
+		// RTT fields stay zero when no sender exists yet.
+		if snd != nil {
+			snd.rtt.Lock()
+			o.RTT = snd.rtt.srtt
+			o.RTTVar = snd.rtt.rttvar
+			snd.rtt.Unlock()
+		}
+
+		return nil
+
+	case *tcpip.KeepaliveEnabledOption:
+		e.keepalive.Lock()
+		v := e.keepalive.enabled
+		e.keepalive.Unlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		// Return here: falling out of the switch would report the
+		// option as unknown even though it was just filled in.
+		return nil
+
+	case *tcpip.KeepaliveIdleOption:
+		e.keepalive.Lock()
+		*o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+		e.keepalive.Unlock()
+		return nil
+
+	case *tcpip.KeepaliveIntervalOption:
+		e.keepalive.Lock()
+		*o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+		e.keepalive.Unlock()
+		return nil
+
+	case *tcpip.KeepaliveCountOption:
+		e.keepalive.Lock()
+		*o = tcpip.KeepaliveCountOption(e.keepalive.count)
+		e.keepalive.Unlock()
+		return nil
+	}
+
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// checkV4Mapped returns the network protocol to use for addr. If addr is a
+// v4-mapped IPv6 address it is rewritten in place to its IPv4 form (or to the
+// empty address when it maps to 0.0.0.0) and IPv4 is returned. It fails with
+// ErrNoRoute on a v6only endpoint, and with ErrInvalidEndpointState when the
+// resulting address length differs from the already-bound local address.
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+	netProto := e.netProto
+	if header.IsV4MappedAddress(addr.Addr) {
+		// Fail if using a v4 mapped address on a v6only endpoint.
+		if e.v6only {
+			return 0, tcpip.ErrNoRoute
+		}
+
+		netProto = header.IPv4ProtocolNumber
+		addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+		if addr.Addr == "\x00\x00\x00\x00" {
+			addr.Addr = ""
+		}
+	}
+
+	// Fail if we're bound to an address length different from the one we're
+	// checking.
+	if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	return netProto, nil
+}
+
+// Connect connects the endpoint to its peer at addr, performing the handshake
+// and starting the protocol goroutine.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+	return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform handshake.
+// In restore of previously connected endpoints, both ends will be passively
+// created (so no new handshaking is done); for stack-accepted connections not
+// yet accepted by the app, they are restored without running the main goroutine
+// here.
+//
+// On the success path it returns tcpip.ErrConnectStarted rather than nil.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	defer func() {
+		if err != nil && !err.IgnoreStats() {
+			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+		}
+	}()
+
+	connectingAddr := addr.Addr
+
+	// Resolve a v4-mapped IPv6 destination to IPv4 if applicable.
+	netProto, err := e.checkV4Mapped(&addr)
+	if err != nil {
+		return err
+	}
+
+	nicid := addr.NIC
+	// Decide what to do based on the endpoint's current state.
+	switch e.state {
+	case stateBound:
+		// If we're already bound to a NIC but the caller is requesting
+		// that we use a different one now, we cannot proceed.
+		if e.boundNICID == 0 {
+			break
+		}
+
+		if nicid != 0 && nicid != e.boundNICID {
+			return tcpip.ErrNoRoute
+		}
+
+		nicid = e.boundNICID
+
+	case stateInitial:
+		// Nothing to do. We'll eventually fill-in the gaps in the ID
+		// (if any) when we find a route.
+
+	case stateConnecting:
+		// A connection request has already been issued but hasn't
+		// completed yet.
+		return tcpip.ErrAlreadyConnecting
+
+	case stateConnected:
+		// The endpoint is already connected. If caller hasn't been notified yet, return success.
+		if !e.isConnectNotified {
+			e.isConnectNotified = true
+			return nil
+		}
+		// Otherwise return that it's already connected.
+		return tcpip.ErrAlreadyConnected
+
+	case stateError:
+		return e.hardError
+
+	default:
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Find a route to the desired destination.
+	r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
+	if err != nil {
+		return err
+	}
+	defer r.Release()
+
+	origID := e.id
+
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	e.id.LocalAddress = r.LocalAddress
+	e.id.RemoteAddress = r.RemoteAddress
+	e.id.RemotePort = addr.Port
+
+	if e.id.LocalPort != 0 {
+		// The endpoint is bound to a port, attempt to register it.
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+		if err != nil {
+			return err
+		}
+	} else {
+		// The endpoint doesn't have a local port yet, so try to get
+		// one. Make sure that it isn't one that will result in the same
+		// address/port for both local and remote (otherwise this
+		// endpoint would be trying to connect to itself).
+		sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+		if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+			if sameAddr && p == e.id.RemotePort {
+				return false, nil
+			}
+			if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) {
+				return false, nil
+			}
+
+			id := e.id
+			id.LocalPort = p
+			// Capture the registration result in its own variable. Inside
+			// this closure the name err refers to the enclosing function's
+			// result, which is always nil at this point, so returning it
+			// would silently drop an unexpected registration failure.
+			switch regErr := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e); regErr {
+			case nil:
+				e.id = id
+				return true, nil
+			case tcpip.ErrPortInUse:
+				return false, nil
+			default:
+				return false, regErr
+			}
+		}); err != nil {
+			return err
+		}
+	}
+
+	// Remove the port reservation. This can happen when Bind is called
+	// before Connect: in such a case we don't want to hold on to
+	// reservations anymore.
+	if e.isPortReserved {
+		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+		e.isPortReserved = false
+	}
+
+	// Record the connection parameters on the endpoint.
+	e.isRegistered = true
+	e.state = stateConnecting
+	e.route = r.Clone()
+	e.boundNICID = nicid
+	e.effectiveNetProtos = netProtos
+	e.connectingAddress = connectingAddr
+
+	// Connect in the restore phase does not perform handshake. Restore its
+	// connection setting here.
+	if !handshake {
+		e.segmentQueue.mu.Lock()
+		for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+			for s := l.Front(); s != nil; s = s.Next() {
+				s.id = e.id
+				s.route = r.Clone()
+				e.sndWaker.Assert()
+			}
+		}
+		e.segmentQueue.mu.Unlock()
+		e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+		e.state = stateConnected
+	}
+
+	if run {
+		e.workerRunning = true
+		e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+		// Start the endpoint's protocol goroutine.
+		go e.protocolMainLoop(handshake)
+	}
+
+	return tcpip.ErrConnectStarted
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+	return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer. It returns ErrNotConnected unless the endpoint is connected or
+// listening.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.shutdownFlags |= flags
+
+	switch e.state {
+	case stateConnected:
+		// Close for write. Nothing is done here for a read-only shutdown.
+		// NOTE(review): the original author's note says reads cannot be
+		// shut directly because connection teardown still has to receive
+		// segments; confirm against the protocol goroutine.
+		if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+			if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+				// We're fully closed, if we have unread data we need to abort
+				// the connection with a RST.
+				e.rcvListMu.Lock()
+				rcvBufUsed := e.rcvBufUsed
+				e.rcvListMu.Unlock()
+
+				if rcvBufUsed > 0 {
+					// Unread data remains: ask the protocol goroutine to reset.
+					e.notifyProtocolGoroutine(notifyReset)
+					return nil
+				}
+			}
+
+			e.sndBufMu.Lock()
+
+			if e.sndClosed {
+				// Already closed.
+				e.sndBufMu.Unlock()
+				break
+			}
+
+			// Queue fin segment.
+			s := newSegmentFromView(&e.route, e.id, nil)
+			e.sndQueue.PushBack(s)
+			e.sndBufInQueue++
+
+			// Mark endpoint as closed.
+			e.sndClosed = true
+
+			e.sndBufMu.Unlock()
+
+			// Tell protocol goroutine to close.
+			e.sndCloseWaker.Assert()
+		}
+
+	case stateListen:
+		// Tell protocolListenLoop to stop.
+		if flags&tcpip.ShutdownRead != 0 {
+			e.notifyProtocolGoroutine(notifyClose)
+		}
+
+	default:
+		return tcpip.ErrNotConnected
+	}
+
+	return nil
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections. backlog is the capacity of the channel that holds
+// connections not yet handed out by Accept.
+func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	defer func() {
+		if err != nil && !err.IgnoreStats() {
+			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+		}
+	}()
+
+	// Allow the backlog to be adjusted if the endpoint is not shutting down.
+	// When the endpoint shuts down, it sets workerCleanup to true, and from
+	// that point onward, acceptedChan is the responsibility of the cleanup()
+	// method (and should not be touched anywhere else, including here).
+	if e.state == stateListen && !e.workerCleanup {
+		// Adjust the size of the channel iff we can fix existing
+		// pending connections into the new one.
+		if len(e.acceptedChan) > backlog {
+			return tcpip.ErrInvalidEndpointState
+		}
+		if cap(e.acceptedChan) == backlog {
+			return nil
+		}
+		origChan := e.acceptedChan
+		e.acceptedChan = make(chan *endpoint, backlog)
+		// Closing lets the range below terminate once the old channel
+		// has been drained into the new one.
+		close(origChan)
+		for ep := range origChan {
+			e.acceptedChan <- ep
+		}
+		return nil
+	}
+
+	// Endpoint must be bound before it can transition to listen mode.
+	if e.state != stateBound {
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Register the endpoint with the stack under its ID.
+	if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
+		return err
+	}
+
+	e.isRegistered = true
+	e.state = stateListen
+	if e.acceptedChan == nil {
+		e.acceptedChan = make(chan *endpoint, backlog)
+	}
+	e.workerRunning = true
+
+	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+	// Run the listen loop in its own goroutine, seeded with the currently
+	// available receive buffer as the window.
+	go e.protocolListenLoop(
+		seqnum.Size(e.receiveBufferAvailable()))
+
+	return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections.
+func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+	e.waiterQueue = waiterQueue
+	e.workerRunning = true
+	go e.protocolMainLoop(false)
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode. It never blocks: when no
+// connection is queued it returns tcpip.ErrWouldBlock.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// Endpoint must be in listen state before it can accept connections.
+	if e.state != stateListen {
+		return nil, nil, tcpip.ErrInvalidEndpointState
+	}
+
+	// Get the new accepted endpoint.
+	var n *endpoint
+	select {
+	case n = <-e.acceptedChan:
+	default:
+		return nil, nil, tcpip.ErrWouldBlock
+	}
+
+	// Start the protocol goroutine for the accepted connection, giving it
+	// a fresh waiter queue that is returned to the caller.
+	wq := &waiter.Queue{}
+	n.startAcceptedLoop(wq)
+
+	return n, wq, nil
+}
+
+// Bind binds the endpoint to a specific local port and optionally address.
+// commit, when non-nil, is invoked after the port is reserved and the address
+// validated; if it fails the reservation is rolled back.
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (err *tcpip.Error) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Don't allow binding once endpoint is not in the initial state
+	// anymore. This is because once the endpoint goes into a connected or
+	// listen state, it is already bound.
+	if e.state != stateInitial {
+		return tcpip.ErrAlreadyBound
+	}
+
+	e.bindAddress = addr.Addr
+	netProto, err := e.checkV4Mapped(&addr)
+	if err != nil {
+		return err
+	}
+
+	// Expand netProtos to include v4 and v6 if the caller is binding to a
+	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
+	// set to false.
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+		netProtos = []tcpip.NetworkProtocolNumber{
+			header.IPv6ProtocolNumber,
+			header.IPv4ProtocolNumber,
+		}
+	}
+
+	// Reserve the port.
+	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
+	if err != nil {
+		return err
+	}
+
+	e.isPortReserved = true
+	e.effectiveNetProtos = netProtos
+	e.id.LocalPort = port
+
+	// Any failures beyond this point must remove the port registration.
+	defer func() {
+		// err is the named result, so this sees whatever error is being
+		// returned and undoes the reservation in that case.
+		if err != nil {
+			e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+			e.isPortReserved = false
+			e.effectiveNetProtos = nil
+			e.id.LocalPort = 0
+			e.id.LocalAddress = ""
+			e.boundNICID = 0
+		}
+	}()
+
+	// If an address is specified, we must ensure that it's one of our
+	// local addresses.
+	if len(addr.Addr) != 0 {
+		nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+		if nic == 0 {
+			return tcpip.ErrBadLocalAddress
+		}
+
+		e.boundNICID = nic
+		e.id.LocalAddress = addr.Addr
+	}
+
+	// Check the commit function.
+	if commit != nil {
+		if err := commit(); err != nil {
+			// The defer takes care of unwind.
+			return err
+		}
+	}
+
+	// Mark endpoint as bound.
+	e.state = stateBound
+
+	return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+	// Snapshot the bound address while holding the read lock.
+	e.mu.RLock()
+	local := tcpip.FullAddress{
+		NIC:  e.boundNICID,
+		Addr: e.id.LocalAddress,
+		Port: e.id.LocalPort,
+	}
+	e.mu.RUnlock()
+
+	return local, nil
+}
+
+// GetRemoteAddress reports the peer address of a connected endpoint. Any
+// state other than connected yields tcpip.ErrNotConnected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	if e.state == stateConnected {
+		remote := tcpip.FullAddress{
+			NIC:  e.boundNICID,
+			Addr: e.id.RemoteAddress,
+			Port: e.id.RemotePort,
+		}
+		return remote, nil
+	}
+
+	return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// HandlePacket is invoked by the stack for every packet addressed to this
+// transport endpoint. The packet is parsed into a segment and queued for the
+// worker goroutine; malformed segments and segments that do not fit in the
+// queue are counted and released.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+	seg := newSegment(r, id, vv)
+
+	// Drop anything that does not parse as a TCP segment.
+	if ok := seg.parse(); !ok {
+		e.stack.Stats().MalformedRcvdPackets.Increment()
+		e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+		seg.decRef()
+		return
+	}
+
+	e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+	if seg.flags&flagRst != 0 {
+		e.stack.Stats().TCP.ResetsReceived.Increment()
+	}
+
+	// Hand the segment to the worker goroutine, or drop it when the queue
+	// has no room.
+	if !e.segmentQueue.enqueue(seg) {
+		e.stack.Stats().DroppedPackets.Increment()
+		seg.decRef()
+		return
+	}
+
+	flags := flagString(seg.flags)
+	remote := fmt.Sprintf("%s:%d", seg.id.RemoteAddress, seg.id.RemotePort)
+	log.Printf("recv tcp %s segment from %s, seq: %d, ack: %d",
+		flags, remote, seg.sequenceNumber, seg.ackNumber)
+	e.newSegmentWaker.Assert()
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+	switch typ {
+	case stack.ControlPacketTooBig:
+		// extra is used as the reported MTU; sndMTU is only ever lowered.
+		e.sndBufMu.Lock()
+		e.packetTooBigCount++
+		if v := int(extra); v < e.sndMTU {
+			e.sndMTU = v
+		}
+		e.sndBufMu.Unlock()
+
+		e.notifyProtocolGoroutine(notifyMTUChanged)
+	}
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+	e.sndBufMu.Lock()
+	notify := e.sndBufUsed >= e.sndBufSize>>1
+	e.sndBufUsed -= v
+	// We only notify when there is half the sndBufSize available after
+	// a full buffer event occurs. This ensures that we don't wake up
+	// writers to queue just 1-2 segments and go back to sleep.
+	notify = notify && e.sndBufUsed < e.sndBufSize>>1
+	e.sndBufMu.Unlock()
+
+	if notify {
+		e.waiterQueue.Notify(waiter.EventOut)
+	}
+}
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+	e.rcvListMu.Lock()
+	if s != nil {
+		// The receive list takes its own reference on the segment.
+		s.incRef()
+		e.rcvBufUsed += s.data.Size()
+		e.rcvList.PushBack(s)
+	} else {
+		e.rcvClosed = true
+	}
+	e.rcvListMu.Unlock()
+
+	e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer. It never returns a negative value.
+func (e *endpoint) receiveBufferAvailable() int {
+	e.rcvListMu.Lock()
+	size := e.rcvBufSize
+	used := e.rcvBufUsed
+	e.rcvListMu.Unlock()
+
+	// We may use more bytes than the buffer size when the receive buffer
+	// shrinks.
+	if used >= size {
+		return 0
+	}
+
+	return size - used
+}
+
+// receiveBufferSize returns the configured receive buffer size, read under
+// rcvListMu.
+func (e *endpoint) receiveBufferSize() int {
+	e.rcvListMu.Lock()
+	size := e.rcvBufSize
+	e.rcvListMu.Unlock()
+
+	return size
+}
+
+// updateRecentTimestamp updates the recent timestamp using the algorithm
+// described in https://tools.ietf.org/html/rfc7323#section-4.3
+func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
+	// Timestamps are compared with seqnum.Value's LessThan rather than a
+	// plain integer comparison.
+	if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+		e.recentTS = tsVal
+	}
+}
+
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+	if synOpts.TS {
+		e.sendTSOk = true
+		e.recentTS = synOpts.TSVal
+	}
+}
+
+// timestamp returns the timestamp value to be used in the TSVal field of the
+// timestamp option for outgoing TCP segments for a given endpoint.
+func (e *endpoint) timestamp() uint32 {
+	return tcpTimeStamp(e.tsOffset)
+}
+
+// tcpTimeStamp returns a timestamp offset by the provided offset. This is
+// not inlined above as it's used when SYN cookies are in use and endpoint
+// is not created at the time when the SYN cookie is sent.
+func tcpTimeStamp(offset uint32) uint32 {
+	now := time.Now()
+	// Wall-clock milliseconds, truncated to 32 bits, plus the offset.
+	return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+}
+
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment. It panics if
+// rand.Read fails.
+func timeStampOffset() uint32 {
+	b := make([]byte, 4)
+	if _, err := rand.Read(b); err != nil {
+		panic(err)
+	}
+	// Initialize a random tsOffset that will be added to the recentTS
+	// everytime the timestamp is sent when the Timestamp option is enabled.
+	//
+	// See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+	// why this is required.
+	//
+	// NOTE: This is not completely to spec as normally this should be
+	// initialized in a manner analogous to how sequence numbers are
+	// randomized per connection basis. But for now this is sufficient.
+	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+	var v SACKEnabled
+	if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+		// Stack doesn't support SACK. So just return.
+		return
+	}
+	if bool(v) && synOpts.SACKPermitted {
+		e.sackPermitted = true
+	}
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+//
+// Each group of fields is copied under its own lock (or none), which is why
+// the snapshot as a whole is not atomic.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+	var s stack.TCPEndpointState
+	s.SegTime = time.Now()
+
+	// Copy EndpointID.
+	e.mu.Lock()
+	s.ID = stack.TCPEndpointID(e.id)
+	e.mu.Unlock()
+
+	// Copy endpoint rcv state.
+	e.rcvListMu.Lock()
+	s.RcvBufSize = e.rcvBufSize
+	s.RcvBufUsed = e.rcvBufUsed
+	s.RcvClosed = e.rcvClosed
+	e.rcvListMu.Unlock()
+
+	// Endpoint TCP Option state.
+	s.SendTSOk = e.sendTSOk
+	s.RecentTS = e.recentTS
+	s.TSOffset = e.tsOffset
+	s.SACKPermitted = e.sackPermitted
+	// Copy only the SACK blocks currently in use into a fresh slice.
+	s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+	copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+
+	// Copy endpoint send state.
+	e.sndBufMu.Lock()
+	s.SndBufSize = e.sndBufSize
+	s.SndBufUsed = e.sndBufUsed
+	s.SndClosed = e.sndClosed
+	s.SndBufInQueue = e.sndBufInQueue
+	s.PacketTooBigCount = e.packetTooBigCount
+	s.SndMTU = e.sndMTU
+	e.sndBufMu.Unlock()
+
+	// Copy receiver state.
+	s.Receiver = stack.TCPReceiverState{
+		RcvNxt:         e.rcv.rcvNxt,
+		RcvAcc:         e.rcv.rcvAcc,
+		RcvWndScale:    e.rcv.rcvWndScale,
+		PendingBufUsed: e.rcv.pendingBufUsed,
+		PendingBufSize: e.rcv.pendingBufSize,
+	}
+
+	// Copy sender state.
+	s.Sender = stack.TCPSenderState{
+		LastSendTime: e.snd.lastSendTime,
+		DupAckCount:  e.snd.dupAckCount,
+		FastRecovery: stack.TCPFastRecoveryState{
+			Active:  e.snd.fr.active,
+			First:   e.snd.fr.first,
+			Last:    e.snd.fr.last,
+			MaxCwnd: e.snd.fr.maxCwnd,
+		},
+		SndCwnd:          e.snd.sndCwnd,
+		Ssthresh:         e.snd.sndSsthresh,
+		SndCAAckCount:    e.snd.sndCAAckCount,
+		Outstanding:      e.snd.outstanding,
+		SndWnd:           e.snd.sndWnd,
+		SndUna:           e.snd.sndUna,
+		SndNxt:           e.snd.sndNxt,
+		RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+		RTTMeasureTime:   e.snd.rttMeasureTime,
+		Closed:           e.snd.closed,
+		RTO:              e.snd.rto,
+		SRTTInited:       e.snd.srttInited,
+		MaxPayloadSize:   e.snd.maxPayloadSize,
+		SndWndScale:      e.snd.sndWndScale,
+		MaxSentAck:       e.snd.maxSentAck,
+	}
+	// The smoothed RTT has its own lock, so it is read separately.
+	e.snd.rtt.Lock()
+	s.Sender.SRTT = e.snd.rtt.srtt
+	e.snd.rtt.Unlock()
+
+	// Cubic state is only filled in when the sender's congestion control
+	// is a *cubicState.
+	if cubic, ok := e.snd.cc.(*cubicState); ok {
+		s.Sender.Cubic = stack.TCPCubicState{
+			WMax:                    cubic.wMax,
+			WLastMax:                cubic.wLastMax,
+			T:                       cubic.t,
+			TimeSinceLastCongestion: time.Since(cubic.t),
+			C:                       cubic.c,
+			K:                       cubic.k,
+			Beta:                    cubic.beta,
+			WC:                      cubic.wC,
+			WEst:                    cubic.wEst,
+		}
+	}
+	return s
+}
diff --git a/netstack/tcpip/transport/tcp/forwarder.go b/netstack/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 0000000..ac82205
--- /dev/null
+++ b/netstack/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,174 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"sync"
+
+	"tcpip/netstack/tcpip"
+	"tcpip/netstack/tcpip/buffer"
+	"tcpip/netstack/tcpip/header"
+	"tcpip/netstack/tcpip/seqnum"
+	"tcpip/netstack/tcpip/stack"
+	"tcpip/netstack/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+	// maxInFlight caps the number of entries in inFlight; further SYNs are
+	// ignored once it is reached.
+	maxInFlight int
+	// handler is run in its own goroutine for every accepted request.
+	handler func(*ForwarderRequest)
+
+	// mu guards inFlight.
+	mu sync.Mutex
+	// inFlight is the set of endpoint IDs with a request being handled.
+	inFlight map[stack.TransportEndpointID]struct{}
+	// listen is the listen context used to create endpoints for requests.
+	listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is reached
+// new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+	// A zero window selects the package default.
+	wnd := rcvWnd
+	if wnd == 0 {
+		wnd = DefaultBufferSize
+	}
+
+	f := &Forwarder{
+		maxInFlight: maxInFlight,
+		handler:     handler,
+		inFlight:    make(map[stack.TransportEndpointID]struct{}),
+	}
+	f.listen = newListenContext(s, seqnum.Size(wnd), true, 0)
+	return f
+}
+
+// HandlePacket claims a packet when it is a well-formed segment whose flags
+// are exactly SYN, and reports true in that case; any other packet is left
+// unhandled and false is returned. A claimed SYN is silently ignored when a
+// request for the same id is already in flight or the in-flight limit has
+// been reached; otherwise the handler is started in a new goroutine.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+	seg := newSegment(r, id, vv)
+	defer seg.decRef()
+
+	// Only well-formed, pure SYN segments are of interest.
+	if ok := seg.parse(); !ok || seg.flags != flagSyn {
+		return false
+	}
+
+	synOpts := parseSynSegmentOptions(seg)
+
+	f.mu.Lock()
+	defer f.mu.Unlock()
+
+	// Ignore duplicates of an in-flight request and anything over the limit.
+	_, pending := f.inFlight[id]
+	if pending || len(f.inFlight) >= f.maxInFlight {
+		return true
+	}
+
+	// Record the request and hand it to the client on a new goroutine. The
+	// extra reference keeps the segment alive after the deferred decRef.
+	f.inFlight[id] = struct{}{}
+	seg.incRef()
+	req := &ForwarderRequest{
+		forwarder:  f,
+		segment:    seg,
+		synOptions: synOpts,
+	}
+	go f.handler(req)
+
+	return true
+}
+
+// ForwarderRequest is a single connection request captured by the forwarder
+// and handed to the client. Clients must eventually call Complete() on it, and
+// may optionally create an endpoint to represent it via CreateEndpoint.
+type ForwarderRequest struct {
+	mu         sync.Mutex
+	forwarder  *Forwarder
+	segment    *segment
+	synOptions header.TCPSynOptions
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the connection request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+	// NOTE(review): r.segment is set to nil by Complete and is read here
+	// without holding r.mu, so calling ID after Complete dereferences nil.
+	return r.segment.id
+}
+
+// Complete completes the request, and optionally sends a RST segment back to the
+// sender. It panics if the request has already been completed.
+func (r *ForwarderRequest) Complete(sendReset bool) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	// A nil segment marks a request that was already completed.
+	if r.segment == nil {
+		panic("Completing already completed forwarder request")
+	}
+
+	// Remove request from the forwarder.
+	r.forwarder.mu.Lock()
+	delete(r.forwarder.inFlight, r.segment.id)
+	r.forwarder.mu.Unlock()
+
+	// If the caller requested, send a reset.
+	if sendReset {
+		replyWithReset(r.segment)
+	}
+
+	// Release all resources.
+	r.segment.decRef()
+	r.segment = nil
+	r.forwarder = nil
+}
+
+// CreateEndpoint creates a TCP endpoint for the connection request, performing
+// the 3-way handshake in the process. It returns ErrInvalidEndpointState if
+// the request has already been completed.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if r.segment == nil {
+		return nil, tcpip.ErrInvalidEndpointState
+	}
+
+	f := r.forwarder
+	// A fresh options struct is passed, so the callee never sees a pointer
+	// to the request's stored synOptions.
+	ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+		MSS:           r.synOptions.MSS,
+		WS:            r.synOptions.WS,
+		TS:            r.synOptions.TS,
+		TSVal:         r.synOptions.TSVal,
+		TSEcr:         r.synOptions.TSEcr,
+		SACKPermitted: r.synOptions.SACKPermitted,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Start the protocol goroutine.
+	ep.startAcceptedLoop(queue)
+
+	return ep, nil
+}
diff --git a/netstack/tcpip/transport/tcp/protocol.go b/netstack/tcpip/transport/tcp/protocol.go
new file mode 100644
index 0000000..04aade7
--- /dev/null
+++ b/netstack/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,241 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+	"strings"
+	"sync"
+
+	"tcpip/netstack/tcpip"
+	"tcpip/netstack/tcpip/buffer"
+	"tcpip/netstack/tcpip/header"
+	"tcpip/netstack/tcpip/seqnum"
+	"tcpip/netstack/tcpip/stack"
+	"tcpip/netstack/waiter"
+)
+
+const (
+	// ProtocolName is the string representation of the tcp protocol name.
+	ProtocolName = "tcp"
+
+	// ProtocolNumber is the tcp protocol number.
+	ProtocolNumber = header.TCPProtocolNumber
+
+	// minBufferSize is the smallest size of a receive or send buffer.
+	minBufferSize = 4 << 10 // 4096 bytes.
+
+	// DefaultBufferSize is the default size of the receive and send buffers.
+	DefaultBufferSize = 1 << 20 // 1MB
+
+	// maxBufferSize is the largest size a receive and send buffer can grow to.
+	maxBufferSize = 4 << 20 // 4MB
+)
+
+// SACKEnabled option can be used to enable SACK support in the TCP
+// protocol. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// SendBufferSizeOption allows the default, min and max send buffer sizes for
+// TCP endpoints to be queried or configured.
+type SendBufferSizeOption struct {
+	Min     int
+	Default int
+	Max     int
+}
+
+// ReceiveBufferSizeOption allows the default, min and max receive buffer size
+// for TCP endpoints to be queried or configured.
+type ReceiveBufferSizeOption struct {
+	Min     int
+	Default int
+	Max     int
+}
+
+// Names of the congestion control algorithms registered by this package.
+const (
+	ccReno  = "reno"
+	ccCubic = "cubic"
+)
+
+// CongestionControlOption sets the current congestion control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption returns the supported congestion control
+// algorithms.
+type AvailableCongestionControlOption string
+
+// protocol implements the transport-protocol interface for TCP and holds the
+// protocol-wide options.
+type protocol struct {
+	// mu is held by SetOption and Option while the option fields below
+	// are written or read.
+	mu                         sync.Mutex
+	sackEnabled                bool
+	sendBufferSize             SendBufferSizeOption
+	recvBufferSize             ReceiveBufferSizeOption
+	congestionControl          string
+	availableCongestionControl []string
+	allowedCongestionControl   []string
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+	return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+	return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet. The returned error is always nil.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+	// NOTE(review): v is not length-checked here; presumably callers only
+	// pass views of at least MinimumPacketSize bytes. Confirm.
+	h := header.TCP(v)
+	return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+	s := newSegment(r, id, vv)
+	defer s.decRef()
+
+	// Malformed segments are reported as not handled.
+	if !s.parse() {
+		return false
+	}
+
+	// There's nothing to do if this is already a reset packet.
+	if s.flagIsSet(flagRst) {
+		return true
+	}
+
+	replyWithReset(s)
+	return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment) {
+	// Get the seqnum from the packet if the ack flag is set.
+	seq := seqnum.Value(0)
+	if s.flagIsSet(flagAck) {
+		seq = s.ackNumber
+	}
+
+	// Acknowledge everything the offending segment occupied.
+	ack := s.sequenceNumber.Add(s.logicalLen())
+
+	sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil)
+}
+
+// SetOption implements TransportProtocol.SetOption.
+//
+// Buffer size options are rejected with ErrInvalidOptionValue unless
+// 0 < Min <= Default <= Max. A congestion control name must be one of the
+// available algorithms. Unrecognised option types return
+// ErrUnknownProtocolOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+	switch v := option.(type) {
+	case SACKEnabled:
+		p.mu.Lock()
+		p.sackEnabled = bool(v)
+		p.mu.Unlock()
+		return nil
+
+	case SendBufferSizeOption:
+		if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+			return tcpip.ErrInvalidOptionValue
+		}
+		p.mu.Lock()
+		p.sendBufferSize = v
+		p.mu.Unlock()
+		return nil
+
+	case ReceiveBufferSizeOption:
+		if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+			return tcpip.ErrInvalidOptionValue
+		}
+		p.mu.Lock()
+		p.recvBufferSize = v
+		p.mu.Unlock()
+		return nil
+
+	case CongestionControlOption:
+		// availableCongestionControl is read here without holding mu.
+		for _, c := range p.availableCongestionControl {
+			if string(v) == c {
+				p.mu.Lock()
+				p.congestionControl = string(v)
+				p.mu.Unlock()
+				return nil
+			}
+		}
+		return tcpip.ErrInvalidOptionValue
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SACKEnabled: + p.mu.Lock() + *v = SACKEnabled(p.sackEnabled) + p.mu.Unlock() + return nil + + case *SendBufferSizeOption: + p.mu.Lock() + *v = p.sendBufferSize + p.mu.Unlock() + return nil + + case *ReceiveBufferSizeOption: + p.mu.Lock() + *v = p.recvBufferSize + p.mu.Unlock() + return nil + case *CongestionControlOption: + p.mu.Lock() + *v = CongestionControlOption(p.congestionControl) + p.mu.Unlock() + return nil + case *AvailableCongestionControlOption: + p.mu.Lock() + *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + p.mu.Unlock() + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func init() { + // tcp协议在协议栈中的注册 + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + } + }) +} diff --git a/netstack/tcpip/transport/tcp/rcv.go b/netstack/tcpip/transport/tcp/rcv.go new file mode 100644 index 0000000..f4139a4 --- /dev/null +++ b/netstack/tcpip/transport/tcp/rcv.go @@ -0,0 +1,251 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// receiver holds the state necessary to receive TCP segments and turn them
// into a stream of bytes.
//
// +stateify savable
type receiver struct {
	ep *endpoint

	// rcvNxt is the next sequence number expected from the peer.
	rcvNxt seqnum.Value

	// rcvAcc is one beyond the last acceptable sequence number. That is,
	// the "largest" sequence value that the receiver has announced to
	// its peer that it's willing to accept. This may be different than
	// rcvNxt + rcvWnd if the receive window is reduced; in that case we
	// have to reduce the window as we receive more data instead of
	// shrinking it.
	rcvAcc seqnum.Value

	// rcvWndScale is the right shift applied to the window before it is
	// advertised (see getSendParams).
	rcvWndScale uint8

	// closed is set once a FIN has been consumed; later segments are
	// ignored by handleRcvdSegment.
	closed bool

	// pendingRcvdSegments holds out-of-order segments, ordered by sequence
	// number, until the gap before them is filled.
	pendingRcvdSegments segmentHeap
	// pendingBufUsed is the sum of logicalLen() of the pending segments.
	pendingBufUsed seqnum.Size
	// pendingBufSize is the limit checked before buffering another
	// out-of-order segment.
	pendingBufSize seqnum.Size
}

// newReceiver creates a receiver for ep. irs is the peer's initial sequence
// number; the first expected sequence number is irs+1 and the initial window
// rcvWnd is also used as the out-of-order buffer limit.
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
	return &receiver{
		ep:             ep,
		rcvNxt:         irs + 1,
		rcvAcc:         irs.Add(rcvWnd + 1),
		rcvWndScale:    rcvWndScale,
		pendingBufSize: rcvWnd,
	}
}

// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
//
// Flow control: it decides whether the segment falls inside the receive
// window [rcvNxt, rcvAcc).
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
	rcvWnd := r.rcvNxt.Size(r.rcvAcc)
	if rcvWnd == 0 {
		// With a zero window only an empty segment at exactly rcvNxt
		// is acceptable.
		return segLen == 0 && segSeq == r.rcvNxt
	}

	return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
		seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
}

// getSendParams returns the parameters needed by the sender when building
// segments to send. As a side effect it advances rcvAcc, the right edge of the
// receive window, when more buffer space has become available; rcvAcc is never
// moved backwards here.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
	// Calculate the window size based on the current buffer size.
	n := r.ep.receiveBufferAvailable()
	acc := r.rcvNxt.Add(seqnum.Size(n))
	if r.rcvAcc.LessThan(acc) {
		r.rcvAcc = acc
	}

	return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
}

// nonZeroWindow is called when the receive window grows from zero to nonzero;
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
	if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
		// We never got around to announcing a zero window size, so we
		// don't need to immediately announce a nonzero one.
		return
	}

	// Immediately send an ack.
	r.ep.snd.sendAck()
}

// consumeSegment attempts to consume a segment that was received by r. The
// segment may have just been received or may have been received earlier but
// wasn't ready to be consumed then.
//
// Returns true if the segment was consumed, false if it cannot be consumed
// yet because of a missing segment.
func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
	if segLen > 0 {
		// If the segment doesn't include the seqnum we're expecting to
		// consume now, we're missing a segment. We cannot proceed until
		// we receive that segment though. In other words rcvNxt must
		// satisfy segSeq <= rcvNxt < segSeq+segLen.
		if !r.rcvNxt.InWindow(segSeq, segLen) {
			return false
		}

		// Trim segment to eliminate already acknowledged data.
		if segSeq.LessThan(r.rcvNxt) {
			diff := segSeq.Size(r.rcvNxt)
			segLen -= diff
			segSeq.UpdateForward(diff)
			s.sequenceNumber.UpdateForward(diff)
			s.data.TrimFront(int(diff))
		}

		// Move segment to ready-to-deliver list. Wakeup any waiters.
		r.ep.readyToRead(s)

	} else if segSeq != r.rcvNxt {
		return false
	}

	// Update the segment that we're expecting to consume. The data up to
	// segSeq+segLen has now arrived in order.
	r.rcvNxt = segSeq.Add(segLen)

	// Trim SACK Blocks to remove any SACK information that covers
	// sequence numbers that have been consumed.
	TrimSACKBlockList(&r.ep.sack, r.rcvNxt)

	// Handle a FIN carried by this segment.
	if s.flagIsSet(flagFin) {
		// The FIN occupies one sequence number, so the next expected
		// sequence number moves forward by one.
		r.rcvNxt++

		// Acknowledge the FIN right away.
		r.ep.snd.sendAck()

		// Tell any readers that no more data will come: mark the
		// receiver closed and notify the endpoint with a nil segment.
		r.closed = true
		r.ep.readyToRead(nil)

		// Flush out any pending segments, except the very first one if
		// it happens to be the one we're handling now because the
		// caller is using it.
		first := 0
		if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
			first = 1
		}

		for i := first; i < len(r.pendingRcvdSegments); i++ {
			r.pendingRcvdSegments[i].decRef()
		}
		r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
	}

	return true
}

// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
//
// In-order segments are consumed immediately (their payload is queued for the
// reader); out-of-order segments are buffered in pendingRcvdSegments and
// reported to the peer through SACK blocks and an immediate ACK.
func (r *receiver) handleRcvdSegment(s *segment) {
	// We don't care about receive processing anymore if the receive side
	// is closed.
	if r.closed {
		return
	}

	segLen := seqnum.Size(s.data.Size())
	segSeq := s.sequenceNumber

	// If the sequence number range is outside the acceptable range, just
	// send an ACK. This is according to RFC 793, page 37.
	if !r.acceptable(segSeq, segLen) {
		r.ep.snd.sendAck()
		return
	}

	// Defer segment processing if it can't be consumed now: the segment
	// is pushed onto the pendingRcvdSegments heap instead.
	if !r.consumeSegment(s, segSeq, segLen) {
		// Only segments that carry data or a FIN are buffered and
		// acknowledged.
		if segLen > 0 || s.flagIsSet(flagFin) {
			// We only store the segment if it's within our buffer
			// size limit.
			if r.pendingBufUsed < r.pendingBufSize {
				r.pendingBufUsed += s.logicalLen()
				// The heap keeps its own reference to the segment.
				s.incRef()
				heap.Push(&r.pendingRcvdSegments, s)
			}

			// Record the out-of-order range in the SACK block list.
			UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)

			// Immediately send an ack so that the peer knows it may
			// have to retransmit.
			r.ep.snd.sendAck()
		}
		return
	}

	// By consuming the current segment, we may have filled a gap in the
	// sequence number domain that allows pending segments to be consumed
	// now. So try to do it.
	for !r.closed && r.pendingRcvdSegments.Len() > 0 {
		s := r.pendingRcvdSegments[0]
		segLen := seqnum.Size(s.data.Size())
		segSeq := s.sequenceNumber

		// Skip segment altogether if it has already been acknowledged.
		if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
			!r.consumeSegment(s, segSeq, segLen) {
			break
		}

		// The segment was consumed or is obsolete: remove it from the
		// pending heap and release the heap's reference.
		heap.Pop(&r.pendingRcvdSegments)
		r.pendingBufUsed -= s.logicalLen()
		s.decRef()
	}
}
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoState stores the variables related to TCP New Reno congestion +// control algorithm. +// +// +stateify savable +type renoState struct { + s *sender +} + +// newRenoCC initializes the state for the NewReno congestion control algorithm. +// 新建 reno 算法对象 +func newRenoCC(s *sender) *renoState { + return &renoState{s: s} +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window +// we cross the SSthreshold then it will return the number of packets that +// must be consumed in congestion avoidance mode. +// updateSlowStart 将根据NewReno使用的慢启动算法更新拥塞窗口。 +// 如果在调整拥塞窗口后我们越过了 SSthreshold ,那么它将返回在拥塞避免模式下必须消耗的数据包的数量。 +func (r *renoState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. 
+ // 在慢启动阶段,每次收到ack,sndCwnd加上已确认的段数 + newcwnd := r.s.sndCwnd + packetsAcked + // 判断增大过后的拥塞窗口是否超过慢启动阀值 sndSsthresh, + // 如果超过 sndSsthresh ,将窗口调整为 sndSsthresh + if newcwnd >= r.s.sndSsthresh { + newcwnd = r.s.sndSsthresh + r.s.sndCAAckCount = 0 + } + // 是否超过 sndSsthresh, packetsAcked>0表示超过 + packetsAcked -= newcwnd - r.s.sndCwnd + // 更新拥塞窗口 + r.s.sndCwnd = newcwnd + return packetsAcked +} + +// updateCongestionAvoidance will update congestion window in congestion +// avoidance mode as described in RFC5681 section 3.1 +// updateCongestionAvoidance 将在拥塞避免模式下更新拥塞窗口, +// 如RFC5681第3.1节所述 +func (r *renoState) updateCongestionAvoidance(packetsAcked int) { + // Consume the packets in congestion avoidance mode. + // sndCAAckCount 累计收到的tcp段数 + r.s.sndCAAckCount += packetsAcked + // 如果累计的段数超过当前的拥塞窗口,那么 sndCwnd 加上 sndCAAckCount/sndCwnd 的整数倍 + if r.s.sndCAAckCount >= r.s.sndCwnd { + r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd + r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +// 当检测到网络拥塞时,调用 reduceSlowStartThreshold。 +// 它将 sndSsthresh 变为 outstanding 的一半。 +// sndSsthresh 最小为2,因为至少要比丢包后的拥塞窗口(cwnd=1)来的大,才会进入慢启动阶段。 +func (r *renoState) reduceSlowStartThreshold() { + r.s.sndSsthresh = r.s.outstanding / 2 + if r.s.sndSsthresh < 2 { + r.s.sndSsthresh = 2 + } +} + +// Update updates the congestion state based on the number of packets that +// were acknowledged. +// Update implements congestionControl.Update. +// packetsAcked 表示已确认的tcp段数 +func (r *renoState) Update(packetsAcked int) { + // 当拥塞窗口没有超过慢启动阀值的时候,使用慢启动来增大窗口, + // 否则进入拥塞避免阶段 + if r.s.sndCwnd < r.s.sndSsthresh { + packetsAcked = r.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } + // 进入拥塞避免阶段 + r.updateCongestionAvoidance(packetsAcked) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. 
+// 当收到三个重复ack时,调用 HandleNDupAcks 来处理。 +func (r *renoState) HandleNDupAcks() { + // A retransmit was triggered due to nDupAckThreshold + // being hit. Reduce our slow start threshold. + // 减小慢启动阀值 + r.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionControl.HandleRTOExpired. +// 当当发生重传包时,调用 HandleRTOExpired 来处理。 +func (r *renoState) HandleRTOExpired() { + // We lost a packet, so reduce ssthresh. + // 减小慢启动阀值 + r.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + // 更新拥塞窗口为1,这样就会重新进入慢启动 + r.s.sndCwnd = 1 +} + +// PostRecovery implements congestionControl.PostRecovery. +func (r *renoState) PostRecovery() { + // noop. +} diff --git a/netstack/tcpip/transport/tcp/sack.go b/netstack/tcpip/transport/tcp/sack.go new file mode 100644 index 0000000..c74b7f6 --- /dev/null +++ b/netstack/tcpip/transport/tcp/sack.go @@ -0,0 +1,103 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" +) + +const ( + // MaxSACKBlocks is the maximum number of SACK blocks stored + // at receiver side. + // MaxSACKBlocks 是接收端存储的最大SACK块数。 + MaxSACKBlocks = 6 +) + +// UpdateSACKBlocks updates the list of SACK blocks to include the segment +// specified by segStart->segEnd. 
// UpdateSACKBlocks updates the list of SACK blocks to include the segment
// specified by segStart->segEnd. If the segment happens to be an out of order
// delivery then the first block in the sack.blocks always includes the
// segment identified by segStart->segEnd.
//
// Existing blocks that are invalid (end <= start) or that start at or before
// rcvNxt are dropped, and blocks that touch or overlap the new range are
// merged into it.
func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
	newSB := header.SACKBlock{Start: segStart, End: segEnd}
	if sack.NumBlocks == 0 {
		sack.Blocks[0] = newSB
		sack.NumBlocks = 1
		return
	}
	// n counts the blocks kept so far; kept blocks are compacted towards
	// the front of sack.Blocks.
	var n = 0
	for i := 0; i < sack.NumBlocks; i++ {
		start, end := sack.Blocks[i].Start, sack.Blocks[i].End
		if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
			// Discard any invalid blocks where end is before start
			// and discard any sack blocks that are before rcvNxt as
			// those have already been acked.
			continue
		}
		if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
			// Merge this SACK block into newSB and discard this SACK
			// block.
			if start.LessThan(newSB.Start) {
				newSB.Start = start
			}
			if newSB.End.LessThan(end) {
				newSB.End = end
			}
		} else {
			// Save this block.
			sack.Blocks[n] = sack.Blocks[i]
			n++
		}
	}
	if rcvNxt.LessThan(newSB.Start) {
		// If this was an out of order segment then make sure that the
		// first SACK block is the one that includes the segment.
		//
		// See the first bullet point in
		// https://tools.ietf.org/html/rfc2018#section-4
		if n == MaxSACKBlocks {
			// If the number of SACK blocks is equal to
			// MaxSACKBlocks then discard the last SACK block.
			n--
		}
		// Shift the kept blocks one slot to the right to make room at
		// index 0.
		for i := n - 1; i >= 0; i-- {
			sack.Blocks[i+1] = sack.Blocks[i]
		}
		sack.Blocks[0] = newSB
		n++
	}
	sack.NumBlocks = n
}
+// tcp的可靠性:TrimSACKBlockList 通过删除/修改 start为 len(h) { + return false + } + + s.options = []byte(h[header.TCPMinimumSize:offset]) + s.parsedOptions = header.ParseTCPOptions(s.options) + s.data.TrimFront(offset) + + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) + + return true +} diff --git a/netstack/tcpip/transport/tcp/segment_heap.go b/netstack/tcpip/transport/tcp/segment_heap.go new file mode 100644 index 0000000..b961a5d --- /dev/null +++ b/netstack/tcpip/transport/tcp/segment_heap.go @@ -0,0 +1,48 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// tcp段的堆,用来暂存失序的tcp段 +// 实现了堆排序 +type segmentHeap []*segment + +// Len returns the length of h. +func (h segmentHeap) Len() int { + return len(h) +} + +// Less determines whether the i-th element of h is less than the j-th element. +func (h segmentHeap) Less(i, j int) bool { + return h[i].sequenceNumber.LessThan(h[j].sequenceNumber) +} + +// Swap swaps the i-th and j-th elements of h. +func (h segmentHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +// Push adds x as the last element of h. +func (h *segmentHeap) Push(x interface{}) { + *h = append(*h, x.(*segment)) +} + +// Pop removes the last element of h and returns it. 
+func (h *segmentHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} diff --git a/netstack/tcpip/transport/tcp/segment_queue.go b/netstack/tcpip/transport/tcp/segment_queue.go new file mode 100644 index 0000000..fb1347c --- /dev/null +++ b/netstack/tcpip/transport/tcp/segment_queue.go @@ -0,0 +1,81 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "sync" + + "tcpip/netstack/tcpip/header" +) + +// segmentQueue is a bounded, thread-safe queue of TCP segments. +// +// +stateify savable +type segmentQueue struct { + mu sync.Mutex + list segmentList + limit int + used int +} + +// empty determines if the queue is empty. +func (q *segmentQueue) empty() bool { + q.mu.Lock() + r := q.used == 0 + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No segments are immediately dropped in case the +// queue becomes full due to the new limit. +func (q *segmentQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given segment to the queue. +// +// Returns true when the segment is successfully added to the queue, in which +// case ownership of the reference is transferred to the queue. And returns +// false if the queue is full, in which case ownership is retained by the +// caller. 
+func (q *segmentQueue) enqueue(s *segment) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used += s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next segment from queue, if one exists. +// Ownership is transferred to the caller, who is responsible for decrementing +// the ref count when done. +func (q *segmentQueue) dequeue() *segment { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used -= s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return s +} diff --git a/netstack/tcpip/transport/tcp/snd.go b/netstack/tcpip/transport/tcp/snd.go new file mode 100644 index 0000000..109533a --- /dev/null +++ b/netstack/tcpip/transport/tcp/snd.go @@ -0,0 +1,797 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "math" + "sync" + "time" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" +) + +const ( + // minRTO is the minimum allowed value for the retransmit timeout. + minRTO = 200 * time.Millisecond + + // InitialCwnd is the initial congestion window. + // 初始拥塞窗口大小 + InitialCwnd = 10 + + // nDupAckThreshold is the number of duplicate ACK's required + // before fast-retransmit is entered. 
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
type congestionControl interface {
	// HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
	// just before entering fast retransmit.
	HandleNDupAcks()

	// HandleRTOExpired is invoked when the retransmit timer expires.
	HandleRTOExpired()

	// Update is invoked when processing inbound acks. It's passed the
	// number of packet's that were acked by the most recent cumulative
	// acknowledgement.
	Update(packetsAcked int)

	// PostRecovery is invoked when the sender is exiting a fast retransmit/
	// recovery phase. This provides congestion control algorithms a way
	// to adjust their state when exiting recovery.
	PostRecovery()
}

/*
                     +-------> sndwnd <-------+
                     |                        |
---------------------+-------------+----------+--------------------
|      acked         | * * * * * * | # # # # #|   unable to send
---------------------+-------------+----------+--------------------
                     ^             ^
                     |             |
                   sndUna        sndNxt
***** data in flight
##### data that may be sent
*/

// sender holds the state necessary to send TCP segments.
//
// +stateify savable
type sender struct {
	ep *endpoint

	// lastSendTime is the timestamp when the last packet was sent.
	lastSendTime time.Time

	// dupAckCount is the number of duplicated acks received. It is used for
	// fast retransmit.
	dupAckCount int

	// fr holds state related to fast recovery.
	fr fastRecovery

	// sndCwnd is the congestion window, in packets.
	sndCwnd int

	// sndSsthresh is the threshold between slow start and congestion
	// avoidance.
	sndSsthresh int

	// sndCAAckCount is the number of packets acknowledged during congestion
	// avoidance. When enough packets have been ack'd (typically cwnd
	// packets), the congestion window is incremented by one.
	sndCAAckCount int

	// outstanding is the number of outstanding packets, that is, packets
	// that have been sent but not yet acknowledged.
	outstanding int

	// sndWnd is the send window size, in bytes.
	sndWnd seqnum.Size

	// sndUna is the next unacknowledged sequence number.
	sndUna seqnum.Value

	// sndNxt is the sequence number of the next segment to be sent.
	sndNxt seqnum.Value

	// sndNxtList is the sequence number of the next segment to be added to
	// the send list.
	sndNxtList seqnum.Value

	// rttMeasureSeqNum is the sequence number being used for the latest RTT
	// measurement.
	rttMeasureSeqNum seqnum.Value

	// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
	rttMeasureTime time.Time

	closed bool
	// writeNext is the next segment in writeList that sendData will
	// transmit.
	writeNext *segment
	// writeList is the list of segments queued for sending.
	writeList   segmentList
	resendTimer timer
	resendWaker sleep.Waker

	// rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
	// "round-trip time variation" and "retransmit timeout", as defined in
	// section 2 of RFC 6298.
	rtt        rtt
	rto        time.Duration
	srttInited bool

	// maxPayloadSize is the maximum size of the payload of a given segment.
	// It is initialized on demand.
	maxPayloadSize int

	// sndWndScale is the number of bits to shift left when reading the send
	// window size from a segment.
	sndWndScale uint8

	// maxSentAck is the maximum acknowledgement actually sent.
	maxSentAck seqnum.Value

	// cc is the congestion control algorithm in use for this sender.
	cc congestionControl
}

// rtt is a synchronization wrapper used to appease stateify. See the comment
// in sender, where it is used.
//
// +stateify savable
type rtt struct {
	sync.Mutex

	srtt   time.Duration
	rttvar time.Duration
}

// fastRecovery holds information related to fast recovery from a packet loss.
//
// +stateify savable
type fastRecovery struct {
	// active whether the endpoint is in fast recovery. The following fields
	// are only meaningful when active is true.
	active bool

	// first and last represent the inclusive sequence number range being
	// recovered.
	first seqnum.Value
	last  seqnum.Value

	// maxCwnd is the maximum value the congestion window may be inflated to
	// due to duplicate acks. This exists to avoid attacks where the
	// receiver intentionally sends duplicate acks to artificially inflate
	// the sender's cwnd.
	maxCwnd int
}
+ if sndWndScale > 0 { + s.sndWndScale = uint8(sndWndScale) + } + + s.updateMaxPayloadSize(int(ep.route.MTU()), 0) + + s.resendTimer.init(&s.resendWaker) + + return s +} + +// tcp拥塞控制:根据算法名,新建拥塞控制算法和初始化 +func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { + switch congestionControlName { + case ccCubic: + return newCubicCC(s) + case ccReno: + fallthrough + default: + return newRenoCC(s) + } +} + +// updateMaxPayloadSize updates the maximum payload size based on the given +// MTU. If this is in response to "packet too big" control packets (indicated +// by the count argument), it also reduces the number of outstanding packets and +// attempts to retransmit the first packet above the MTU size. +func (s *sender) updateMaxPayloadSize(mtu, count int) { + m := mtu - header.TCPMinimumSize + + // Calculate the maximum option size. + // 计算MSS的大小 + var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock + options := s.ep.makeOptions(maxSackBlocks[:]) + m -= len(options) + putOptions(options) + + // We don't adjust up for now. + if m >= s.maxPayloadSize { + return + } + + // Make sure we can transmit at least one byte. + if m <= 0 { + m = 1 + } + + s.maxPayloadSize = m + + s.outstanding -= count + if s.outstanding < 0 { + s.outstanding = 0 + } + + // Rewind writeNext to the first segment exceeding the MTU. Do nothing + // if it is already before such a packet. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if seg == s.writeNext { + // We got to writeNext before we could find a segment + // exceeding the MTU. + break + } + + if seg.data.Size() > m { + // We found a segment exceeding the MTU. Rewind + // writeNext and try to retransmit it. + s.writeNext = seg + break + } + } + + // Since we likely reduced the number of outstanding packets, we may be + // ready to send some more. + s.sendData() +} + +// sendAck sends an ACK segment. 
// sendAck sends an ACK segment: an empty segment with only the ack flag set,
// using sndNxt as its sequence number.
func (s *sender) sendAck() {
	s.sendSegment(buffer.VectorisedView{}, flagAck, s.sndNxt)
}

// updateRTO updates the retransmit timeout when a new round-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
//
// rtt is the newly measured round-trip time. The resulting rto is never
// smaller than minRTO.
func (s *sender) updateRTO(rtt time.Duration) {
	s.rtt.Lock()
	// First measurement: seed srtt and rttvar directly from the sample.
	if !s.srttInited {
		s.rtt.rttvar = rtt / 2
		s.rtt.srtt = rtt
		s.srttInited = true
	} else {
		diff := s.rtt.srtt - rtt
		if diff < 0 {
			diff = -diff
		}
		// Use RFC6298 standard algorithm to update rttvar and srtt when
		// no timestamps are available.
		if !s.ep.sendTSOk {
			s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
			s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
		} else {
			// When we are taking RTT measurements of every ACK then
			// we need to use a modified method as specified in
			// https://tools.ietf.org/html/rfc7323#appendix-G
			if s.outstanding == 0 {
				// Nothing is in flight; leave the estimate and
				// rto untouched.
				s.rtt.Unlock()
				return
			}
			// Netstack measures congestion window/inflight all in
			// terms of packets and not bytes. This is similar to
			// how linux also does cwnd and inflight. In practice
			// this approximation works as expected.
			expectedSamples := math.Ceil(float64(s.outstanding) / 2)

			// alpha & beta values are the original values as recommended in
			// https://tools.ietf.org/html/rfc6298#section-2.3.
			const alpha = 0.125
			const beta = 0.25

			alphaPrime := alpha / expectedSamples
			betaPrime := beta / expectedSamples
			rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
			srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
			s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
			s.rtt.srtt = time.Duration(srtt * float64(time.Second))
		}
	}

	// Recompute the retransmit timeout from the updated estimates.
	s.rto = s.rtt.srtt + 4*s.rtt.rttvar
	s.rtt.Unlock()
	if s.rto < minRTO {
		s.rto = minRTO
	}
}
// TCP congestion control: fast retransmit.
func (s *sender) resendSegment() {
	// Don't use any segments we already sent to measure RTT as they may
	// have been affected by packets being lost.
	s.rttMeasureSeqNum = s.sndNxt

	// Resend the segment.
	if seg := s.writeList.Front(); seg != nil {
		s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
	}
}

// retransmitTimerExpired is called when the retransmit timer expires, and
// unacknowledged segments are assumed lost, and thus need to be resent.
// Returns true if the connection is still usable, or false if the connection
// is deemed lost.
//
// TCP reliability: this is the timeout-retransmission path, run when the
// retransmit timer fires.
// TCP congestion control: a retransmission is treated as packet loss, so the
// congestion control algorithm is told about it.
func (s *sender) retransmitTimerExpired() bool {
	// Check if the timer actually expired or if it's a spurious wake due
	// to a previously orphaned runtime timer.
	if !s.resendTimer.checkExpiration() {
		return true
	}

	// Give up if we've waited more than a minute since the last resend,
	// i.e. once the RTO has reached 60 seconds the connection is reported
	// as lost.
	if s.rto >= 60*time.Second {
		return false
	}

	// Set new timeout. The timer will be restarted by the call to sendData
	// below. Every expiry doubles the RTO.
	s.rto *= 2

	// TCP congestion control: a loss while in fast recovery ends the fast
	// recovery attempt.
	if s.fr.active {
		// We were attempting fast recovery but were not successful.
		// Leave the state. We don't need to update ssthresh because it
		// has already been updated when entered fast-recovery.
		s.leaveFastRecovery()
	}

	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
	// We store the highest sequence number transmitted in cases where
	// we were not in fast recovery.
	s.fr.last = s.sndNxt - 1

	// TCP congestion control: let the algorithm react to the loss.
	s.cc.HandleRTOExpired()

	// Mark the next segment to be sent as the first unacknowledged one and
	// start sending again. Set the number of outstanding packets to 0 so
	// that we'll be able to retransmit.
	//
	// We'll keep on transmitting (or retransmitting) as we get acks for
	// the data we transmit.
	s.outstanding = 0
	s.writeNext = s.writeList.Front()
	// Send the queued data again.
	s.sendData()

	return true
}

// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
//
// Every segment is ultimately handed to sendSegment.
func (s *sender) sendData() {
	limit := s.maxPayloadSize

	// Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
	// "A TCP SHOULD set cwnd to no more than RW before beginning
	// transmission if the TCP has not sent data in the interval exceeding
	// the retransmission timeout."
	if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
		if s.sndCwnd > InitialCwnd {
			s.sndCwnd = InitialCwnd
		}
	}

	// TODO: We currently don't merge multiple send buffers
	// into one segment if they happen to fit. We should do that
	// eventually.
	var seg *segment
	end := s.sndUna.Add(s.sndWnd)
	var dataSent bool
	// Walk the write list starting at writeNext and send segments.
	// TCP congestion control: the s.outstanding < s.sndCwnd condition keeps
	// the number of in-flight packets below the congestion window.
	for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
		// We abuse the flags field to determine if we have already
		// assigned a sequence number to this segment. A segment whose
		// flags are still 0 gets sndNxt as its sequence number and
		// PSH|ACK as its flags.
		if seg.flags == 0 {
			seg.sequenceNumber = s.sndNxt
			seg.flags = flagAck | flagPsh
		}

		var segEnd seqnum.Value
		if seg.data.Size() == 0 { // An empty payload marks the FIN segment.
			if s.writeList.Back() != seg {
				panic("FIN segments must be the final segment in the write list.")
			}
			// Send it as FIN|ACK.
			seg.flags = flagAck | flagFin
			// The FIN consumes one sequence number.
			segEnd = seg.sequenceNumber.Add(1)
		} else {
			// We're sending a non-FIN segment.
			if seg.flags&flagFin != 0 {
				panic("Netstack queues FIN segments without data.")
			}

			// Stop once the segment starts at or beyond the right
			// edge of the send window.
			if !seg.sequenceNumber.LessThan(end) {
				break
			}

			// TCP flow control: the amount sent at once is bounded by
			// both the space left in the send window and the maximum
			// payload size.
			available := int(seg.sequenceNumber.Size(end))
			if available > limit {
				available = limit
			}

			// If the payload is larger than available, split the
			// segment: the remainder becomes a new segment inserted
			// right after this one.
			if seg.data.Size() > available {
				// Split this segment up.
				nSeg := seg.clone()
				nSeg.data.TrimFront(available)
				nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
				s.writeList.InsertAfter(seg, nSeg)
				seg.data.CapLength(available)
			}

			s.outstanding++
			segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
		}

		if !dataSent {
			dataSent = true
			// We are sending data, so we should stop the keepalive timer to
			// ensure that no keepalives are sent while there is pending data.
			s.ep.disableKeepaliveTimer()
		}
		s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)

		// Update sndNxt if we actually sent new data (as opposed to
		// retransmitting some previously sent data).
		if s.sndNxt.LessThan(segEnd) {
			s.sndNxt = segEnd
		}
	}

	// Remember the next segment we'll write.
	s.writeNext = seg

	// Enable the timer if we have pending data and it's not enabled yet.
	// TCP reliability: the retransmit timer is armed when it is not running
	// and sndUna differs from sndNxt.
	if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
		// Arm the timer with the current RTO as its interval.
		s.resendTimer.enable(s.rto)
	}
	// If we have no more pending data, start the keepalive timer.
	if s.sndUna == s.sndNxt {
		s.ep.resetKeepaliveTimer(false)
	}
}

// enterFastRecovery switches the sender into fast recovery and records the
// state that belongs to it (TCP congestion control).
func (s *sender) enterFastRecovery() {
	s.fr.active = true
	// Save state to reflect we're now in fast recovery.
	// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
	// We inflate the cwnd by 3 to account for the 3 packets which triggered
	// the 3 duplicate ACKs and are now not in flight.
	s.sndCwnd = s.sndSsthresh + 3
	s.fr.first = s.sndUna
	s.fr.last = s.sndNxt - 1
	s.fr.maxCwnd = s.sndCwnd + s.outstanding
}

// leaveFastRecovery takes the sender out of fast recovery and resets the
// state that belongs to it (TCP congestion control).
func (s *sender) leaveFastRecovery() {
	s.fr.active = false
	s.fr.first = 0
	s.fr.last = s.sndNxt - 1
	s.fr.maxCwnd = 0
	s.dupAckCount = 0

	// Deflate cwnd. It had been artificially inflated when new dups arrived.
	s.sndCwnd = s.sndSsthresh
	s.cc.PostRecovery()
}

// checkDuplicateAck is called when an ack is received. It manages the state
// related to duplicate acks and determines if a retransmit is needed according
// to the rules in RFC 6582 (NewReno).
//
// The returned rtx is true when the caller should retransmit.
func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
	ack := seg.ackNumber
	// Fast recovery is already in progress.
	if s.fr.active {
		// We are in fast recovery mode. Ignore the ack if it's out of
		// range.
		if !ack.InRange(s.sndUna, s.sndNxt+1) {
			return false
		}

		// Leave fast recovery if it acknowledges all the data covered by
		// this fast recovery session.
		if s.fr.last.LessThan(ack) {
			s.leaveFastRecovery()
			return false
		}

		// Don't count this as a duplicate if it is carrying data or
		// updating the window.
		if seg.logicalLen() != 0 || s.sndWnd != seg.window {
			return false
		}

		// Inflate the congestion window if we're getting duplicate acks
		// for the packet we retransmitted.
		if ack == s.fr.first {
			// We received a dup, inflate the congestion window by 1
			// packet if we're not at the max yet.
			if s.sndCwnd < s.fr.maxCwnd {
				s.sndCwnd++
			}
			return false
		}

		// A partial ack was received. Retransmit this packet and
		// remember it so that we don't retransmit it again. We don't
		// inflate the window because we're putting the same packet back
		// onto the wire.
		//
		// N.B. The retransmit timer will be reset by the caller.
		s.fr.first = ack
		s.dupAckCount = 0
		return true
	}

	// We're not in fast recovery yet. A segment is considered a duplicate
	// only if it doesn't carry any data and doesn't update the send window,
	// because if it does, it wasn't sent in response to an out-of-order
	// segment.
	if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
		s.dupAckCount = 0
		return false
	}

	// Reaching this point means the segment is a duplicate ack.
	s.dupAckCount++

	// Fast recovery is only entered once nDupAckThreshold duplicates have
	// been counted.
	if s.dupAckCount < nDupAckThreshold {
		return false
	}

	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
	//
	// We only do the check here, the incrementing of last to the highest
	// sequence number transmitted till now is done when enterFastRecovery
	// is invoked.
	if !s.fr.last.LessThan(seg.ackNumber) {
		s.dupAckCount = 0
		return false
	}

	// Let the congestion control algorithm handle the duplicate acks.
	s.cc.HandleNDupAcks()
	// Enter fast recovery.
	s.enterFastRecovery()
	s.dupAckCount = 0
	return true
}

// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
func (s *sender) handleRcvdSegment(seg *segment) {
	// Check if we can extract an RTT measurement from this ack: without
	// timestamps, the RTO is updated once the ack number has moved past the
	// sequence number being measured.
	if !s.ep.sendTSOk && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
		s.updateRTO(time.Now().Sub(s.rttMeasureTime))
		s.rttMeasureSeqNum = s.sndNxt
	}

	// Update Timestamp if required. See RFC7323, section-4.3.
	s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)

	// Count the duplicates and do the fast retransmit if needed.
	rtx := s.checkDuplicateAck(seg)

	// Stash away the current window size.
	s.sndWnd = seg.window

	// Ignore ack if it doesn't acknowledge any new data.
	ack := seg.ackNumber
	// The ack is only processed when ack-1 lies in [sndUna, sndNxt).
	if (ack - 1).InRange(s.sndUna, s.sndNxt) {
		s.dupAckCount = 0
		// When an ack is received we must reset the timer. We stop it
		// here and it will be restarted later if needed.
		s.resendTimer.disable()

		// See : https://tools.ietf.org/html/rfc1323#section-3.3.
		// Specifically we should only update the RTO using TSEcr if the
		// following condition holds:
		//
		//    A TSecr value received in a segment is used to update the
		//    averaged RTT measurement only if the segment acknowledges
		//    some new data, i.e., only if it advances the left edge of
		//    the send window.
		if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
			// TSVal/Ecr values sent by Netstack are at a millisecond
			// granularity.
			elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
			s.updateRTO(elapsed)
		}

		// Number of newly acknowledged bytes, i.e. ack - sndUna.
		acked := s.sndUna.Size(ack)
		// Advance the first unacknowledged sequence number.
		s.sndUna = ack

		ackLeft := acked
		originalOutstanding := s.outstanding
		// Drop acknowledged data from the write list; this is the send
		// window sliding forward.
		for ackLeft > 0 {
			// We use logicalLen here because we can have FIN
			// segments (which are always at the end of list) that
			// have no data, but do consume a sequence number.
			seg := s.writeList.Front()
			datalen := seg.logicalLen()

			// Partially acknowledged segment: trim the acknowledged
			// prefix and keep the rest queued.
			if datalen > ackLeft {
				seg.data.TrimFront(int(ackLeft))
				break
			}

			if s.writeNext == seg {
				s.writeNext = seg.Next()
			}
			// Remove the fully acknowledged segment from the write list.
			s.writeList.Remove(seg)
			// One segment was acknowledged, so one fewer is outstanding.
			s.outstanding--
			seg.decRef()
			ackLeft -= datalen
		}

		// Update the send buffer usage and notify potential waiters.
		s.ep.updateSndBufferUsage(int(acked))

		// If we are not in fast recovery then update the congestion
		// window based on the number of acknowledged packets.
		if !s.fr.active {
			// Hand the count to the congestion control algorithm.
			s.cc.Update(originalOutstanding - s.outstanding)
		}

		// It is possible for s.outstanding to drop below zero if we get
		// a retransmit timeout, reset outstanding to zero but later
		// get an ack that cover previously sent data.
		if s.outstanding < 0 {
			s.outstanding = 0
		}
	}

	// Now that we've popped all acknowledged data from the retransmit
	// queue, retransmit if needed. Only the segment at the front of the
	// write list is resent here.
	if rtx {
		// TCP congestion control: fast retransmit.
		s.resendSegment()
	}

	// Send more data now that some of the pending data has been ack'd, or
	// that the window opened up, or the congestion window was inflated due
	// to a duplicate ack during fast recovery. This will also re-enable
	// the retransmit timer if needed.
	s.sendData()
}

// sendSegment sends a new segment containing the given payload, flags and
// sequence number.
//
// It records the send time (also as the RTT measurement start when seq is the
// sequence number being measured) and remembers the ack number being sent
// before handing the segment to the endpoint's sendRaw, whose error it returns.
func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
	s.lastSendTime = time.Now()
	if seq == s.rttMeasureSeqNum {
		s.rttMeasureTime = s.lastSendTime
	}

	rcvNxt, rcvWnd := s.ep.rcv.getSendParams()

	// Remember the max sent ack.
	s.maxSentAck = rcvNxt

	return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}
diff --git a/netstack/tcpip/transport/tcp/tcp_sack_test.go b/netstack/tcpip/transport/tcp/tcp_sack_test.go
new file mode 100644
index 0000000..2940867
--- /dev/null
+++ b/netstack/tcpip/transport/tcp/tcp_sack_test.go
@@ -0,0 +1,350 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "fmt" + "reflect" + "testing" + + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/seqnum" + "tcpip/netstack/tcpip/transport/tcp" + "tcpip/netstack/tcpip/transport/tcp/testing/context" +) + +// createConnectWithSACKPermittedOption creates and connects c.ep with the +// SACKPermitted option enabled if the stack in the context has the SACK support +// enabled. +func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { + return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) +} + +func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { + t.Helper() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) + } +} + +// TestSackPermittedConnect establishes a connection with the SACK option +// enabled. +func TestSackPermittedConnect(t *testing.T) { + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + setStackSACKPermitted(t, c, sackEnabled) + rep := createConnectedWithSACKPermittedOption(c) + data := []byte{1, 2, 3} + + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum + rep.VerifyACKNoSACK() + + // Make an out of order packet and send it. 
+ rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // Restore the saved sequence number so that the + // VerifyXXX calls use the right sequence number for + // checking ACK numbers. + rep.NextSeqNum = savedSeqNum + if sackEnabled { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() + } + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) + } +} + +// TestSackDisabledConnect establishes a connection with the SACK option +// disabled and verifies that no SACKs are sent for out of order segments. +func TestSackDisabledConnect(t *testing.T) { + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) + + data := []byte{1, 2, 3} + + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum + rep.VerifyACKNoSACK() + + // Make an out of order packet and send it. + rep.NextSeqNum += 3 + rep.SendPacket(data, nil) + + // The ACK should contain the older sequence number and + // no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) + } +} + +// TestSackPermittedAccept accepts and establishes a connection with the +// SACKPermitted option enabled if the connection request specifies the +// SACKPermitted option. 
In case of SYN cookies SACK should be disabled as we +// don't encode the SACK information in the cookie. +func TestSackPermittedAccept(t *testing.T) { + type testCase struct { + cookieEnabled bool + sackPermitted bool + wndScale int + wndSize uint16 + } + + testCases := []testCase{ + // When cookie is used window scaling is disabled. + {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + for _, tc := range testCases { + t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { + if tc.cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } else { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + } + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + + savedSeqNum := rep.NextSeqNum + + // Make an out of order packet and send + // it. + rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // The ACK should contain the older + // sequence number. + rep.NextSeqNum = savedSeqNum + if sackEnabled && tc.sackPermitted { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() + } + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. 
+ rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. + rep.VerifyACKNoSACK() + }) + } + }) + } +} + +// TestSackDisabledAccept accepts and establishes a connection with +// the SACKPermitted option disabled and verifies that no SACKs are +// sent for out of order packets. +func TestSackDisabledAccept(t *testing.T) { + type testCase struct { + cookieEnabled bool + wndScale int + wndSize uint16 + } + + testCases := []testCase{ + // When cookie is used window scaling is disabled. + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + for _, tc := range testCases { + t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { + if tc.cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } else { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + } + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + savedSeqNum := rep.NextSeqNum + + // Make an out of order packet and send + // it. + rep.NextSeqNum += 3 + rep.SendPacket(data, nil) + + // The ACK should contain the older + // sequence number and no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. 
+ rep.VerifyACKNoSACK() + }) + } + }) + } +} + +func TestUpdateSACKBlocks(t *testing.T) { + testCases := []struct { + segStart seqnum.Value + segEnd seqnum.Value + rcvNxt seqnum.Value + sackBlocks []header.SACKBlock + updated []header.SACKBlock + }{ + // Trivial cases where current SACK block list is empty and we + // have an out of order delivery. + {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, + {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, + {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, + + // Cases where current SACK block list is not empty and we have + // an out of order delivery. Tests that the updated SACK block + // list has the first block as the one that contains the new + // SACK block representing the segment that was just delivered. + {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, + {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, + {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, + + // Ensure that we only retain header.MaxSACKBlocks and drop the + // oldest one if adding a new block exceeds + // header.MaxSACKBlocks. + {24, 30, 9, + []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, + []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, + + // Cases where segment extends an existing SACK block. 
+ {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, + {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, + + // Cases where segment contains rcvNxt. + {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, + } + + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { + t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) + } + + } +} + +func TestTrimSackBlockList(t *testing.T) { + testCases := []struct { + rcvNxt seqnum.Value + sackBlocks []header.SACKBlock + trimmed []header.SACKBlock + }{ + // Simple cases where we trim whole entries. 
+ {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, + {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, + {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, + {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + // Cases where we need to update a block. + {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, + {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, + {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, + {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + } + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.TrimSACKBlockList(&sack, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { + t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) + } + } +} diff --git a/netstack/tcpip/transport/tcp/tcp_segment_list.go b/netstack/tcpip/transport/tcp/tcp_segment_list.go new file mode 100644 index 0000000..48b4640 --- /dev/null +++ b/netstack/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,173 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. 
+// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. 
+func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. 
+func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/netstack/tcpip/transport/tcp/tcp_test.go b/netstack/tcpip/transport/tcp/tcp_test.go new file mode 100644 index 0000000..428e7dc --- /dev/null +++ b/netstack/tcpip/transport/tcp/tcp_test.go @@ -0,0 +1,3565 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "bytes" + "fmt" + "math" + "testing" + "time" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/checker" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/link/loopback" + "tcpip/netstack/tcpip/link/sniffer" + "tcpip/netstack/tcpip/network/ipv4" + "tcpip/netstack/tcpip/network/ipv6" + "tcpip/netstack/tcpip/ports" + "tcpip/netstack/tcpip/seqnum" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/tcpip/transport/tcp" + "tcpip/netstack/tcpip/transport/tcp/testing/context" + "tcpip/netstack/waiter" +) + +const ( + // defaultMTU is the MTU, in bytes, used throughout the tests, except + // where another value is explicitly used. It is chosen to match the MTU + // of loopback interfaces on linux systems. 
+ defaultMTU = 65535 + + // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an + // IPv4 endpoint when the MTU is set to defaultMTU in the test. + defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize +) + +func TestGiveUpConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + } + + // Close the connection, wait for completion. + ep.Close() + + // Wait for ep to become writable. 
+ <-notifyCh + if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { + t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) + } +} + +func TestConnectIncrementActiveConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ActiveConnectionOpenings.Value() + 1 + + c.CreateConnected(789, 30000, nil) + if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.ActtiveConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.FailedConnectionAttempts.Value() + + c.CreateConnected(789, 30000, nil) + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestActiveFailedConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + want := stats.TCP.FailedConnectionAttempts.Value() + 1 + + if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute { + t.Errorf("got c.EP.Connect(...) 
= %v, want = %v", err, tcpip.ErrNoRoute) + } + + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } +} + +func TestPassiveConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.PassiveConnectionOpenings.Value() + 1 + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if err := ep.Listen(1); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want { + t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %v, want = %v", got, want) + } +} + +func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + c.EP = ep + want := stats.TCP.FailedConnectionAttempts.Value() + 1 + + if err := ep.Listen(1); err != tcpip.ErrInvalidEndpointState { + t.Errorf("got ep.Listen(1) = %v, want = %v", err, tcpip.ErrInvalidEndpointState) + } + + if got := stats.TCP.FailedConnectionAttempts.Value(); got != want { + t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %v, want = %v", got, want) + } +} + +func TestTCPSegmentsSentIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + // SYN and ACK + want := stats.TCP.SegmentsSent.Value() + 2 + c.CreateConnected(789, 30000, nil) + + if got := stats.TCP.SegmentsSent.Value(); got != want { + t.Errorf("got stats.TCP.SegmentsSent.Value() = 
%v, want = %v", got, want) + } +} + +func TestTCPResetsSentIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + stats := c.Stack().Stats() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + want := stats.TCP.SegmentsSent.Value() + 1 + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + // If the AckNum is not the increment of the last sequence number, a RST + // segment is sent back in response. + AckNum: c.IRS + 2, + } + + // Send ACK. 
+ c.SendPacket(nil, ackHeaders) + + c.GetPacket() + if got := stats.TCP.ResetsSent.Value(); got != want { + t.Errorf("got stats.TCP.ResetsSent.Value() = %v, want = %v", got, want) + } +} + +func TestTCPResetsReceivedIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + want := stats.TCP.ResetsReceived.Value() + 1 + ackNum := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.CreateConnected(ackNum, rcvWnd, nil) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: c.IRS.Add(2), + AckNum: ackNum.Add(2), + RcvWnd: rcvWnd, + Flags: header.TCPFlagRst, + }) + + if got := stats.TCP.ResetsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ResetsReceived.Value() = %v, want = %v", got, want) + } +} + +func TestActiveHandshake(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) +} + +func TestNonBlockingClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + ep := c.EP + c.EP = nil + + // Close the endpoint and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %v", diff) + } +} + +func TestConnectResetAfterClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + ep := c.EP + c.EP = nil + + // Close the endpoint, make sure we get a FIN segment, then acknowledge + // to complete closure of sender, but don't send our own FIN. 
+ ep.Close() + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN, and send a RST. + time.Sleep(3 * time.Second) + for { + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + // This is a retransmit of the FIN, ignore it. + continue + } + + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) + break + } +} + +func TestSimpleReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Receive data. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) + } + + // Check that ACK is received. 
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestOutOfOrderReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Send second half of data first, with seqnum 3 ahead of expected. + data := []byte{1, 2, 3, 4, 5, 6} + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that we get an ACK specifying which seqnum is expected. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait 200ms and check that no data has been received. + time.Sleep(200 * time.Millisecond) + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Send the first 3 bytes now. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive data. + read := make([]byte, 0, 6) + for len(read) < len(data) { + v, _, err := c.EP.Read(nil) + if err != nil { + if err == tcpip.ErrWouldBlock { + // Wait for receive to be notified. 
+ select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + continue + } + t.Fatalf("Read failed: %v", err) + } + + read = append(read, v...) + } + + // Check that we received the data in proper order. + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } + + // Check that the whole data is acknowledged. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestOutOfOrderFlood(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create a new connection with initial window size of 10. + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Send 100 packets before the actual one that is expected. + data := []byte{1, 2, 3, 4, 5, 6} + for i := 0; i < 100; i++ { + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 796, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Send packet with seqnum 793. It must be discarded because the + // out-of-order buffer was filled by the previous packets. 
+ c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now send the expected packet, seqnum 790. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that only packet 790 is acknowledged. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestRstOnCloseWithUnreadData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. 
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now that we know we have unread data, let's just close the connection + // and verify that netstack sends an RST rather than a FIN. + c.EP.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // We shouldn't consume a sequence number on RST. + checker.SeqNum(uint32(c.IRS)+1), + )) + + // This final should be ignored because an ACK on a reset doesn't + // mean anything. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + len(data)), + AckNum: c.IRS.Add(seqnum.Size(2)), + RcvWnd: 30000, + }) +} + +func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. 
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Cause a FIN to be generated. + c.EP.Shutdown(tcpip.ShutdownWrite) + + // Make sure we get the FIN but DON't ACK IT. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + checker.SeqNum(uint32(c.IRS)+1), + )) + + // Cause a RST to be generated by closing the read end now since we have + // unread data. + c.EP.Shutdown(tcpip.ShutdownRead) + + // Make sure we get the RST + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // We shouldn't consume a sequence number on RST. + checker.SeqNum(uint32(c.IRS)+1), + )) + + // The ACK to the FIN should now be rejected since the connection has been + // closed by a RST. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + len(data)), + AckNum: c.IRS.Add(seqnum.Size(2)), + RcvWnd: 30000, + }) +} + +func TestFullWindowReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + _, _, err := c.EP.Read(nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("Read failed: %v", err) + } + + // Fill up the window. + data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. 
+ select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that data is acknowledged, and window goes to zero. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(0), + ), + ) + + // Receive data and check it. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, v) { + t.Fatalf("got data = %v, want = %v", v, data) + } + + // Check that we get an ACK for the newly non-zero window. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(10), + ), + ) +} + +func TestNoWindowShrinking(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Start off with a window size of 10, then shrink it to 5. + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + opt = 5 + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrWouldBlock) + } + + // Send 3 bytes, check that the peer acknowledges them. + data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. 
+ select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that data is acknowledged, and that window doesn't go to zero + // just yet because it was previously set to 10. It must go to 7 now. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(7), + ), + ) + + // Send 7 more bytes, check that the window fills up. + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(0), + ), + ) + + // Receive data and check it. + read := make([]byte, 0, 10) + for len(read) < len(data) { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + read = append(read, v...) + } + + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } + + // Check that we get an ACK for the newly non-zero window, which is the + // new size. 
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(5), + ), + ) +} + +func TestSimpleSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestZeroWindowSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 0, nil) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Since the window is currently zero, check that no packet is received. + c.CheckNoPacket("Packet received when window is zero") + + // Open up the window. Data should be received now. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + opt := tcpip.ReceiveBufferSizeOption(65535 * 3) + c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received, and that advertised window is 0xbfff, + // that is, that it is scaled. 
+ b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xbfff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestNonScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is not used when the peer + // doesn't advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + opt := tcpip.ReceiveBufferSizeOption(65535 * 3) + c.CreateConnected(789, 30000, &opt) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestScaledWindowAccept(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Accept(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the window size greater than the maximum non-scaled window. 
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received, and that advertised window is 0xbfff, + // that is, that it is scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xbfff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestNonScaledWindowAccept(t *testing.T) { + // This test ensures that window scaling is not used when the peer + // doesn't advertise it and connection is established with Accept(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create EP and start listening. 
+ wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the window size greater than the maximum non-scaled window. + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. 
+ b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestZeroScaledWindowReceive(t *testing.T) { + // This test ensures that the endpoint sends a non-zero window size + // advertisement when the scaled window transitions from 0 to non-zero, + // but the actual window (not scaled) hasn't gotten to zero. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size such that a window scale of 4 will be used. + const wnd = 65535 * 10 + const ws = uint32(4) + opt := tcpip.ReceiveBufferSizeOption(wnd) + c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + // Write chunks of 50000 bytes. + remain := wnd + sent := 0 + data := make([]byte, 50000) + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(remain>>ws)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Make the window non-zero, but the scaled window zero. 
+ if remain >= 16 { + data = data[:remain-15] + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(0), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Read some data. An ack should be sent in response to that. + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(len(v)>>ws)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { + payloadMultiplier := 10 + dataLen := payloadMultiplier * maxPayload + data := make([]byte, dataLen) + for i := range data { + data[i] = byte(i) + } + + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that data is received in chunks. 
+ bytesReceived := 0 + numPackets := 0 + for bytesReceived != dataLen { + b := c.GetPacket() + numPackets++ + tcp := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcp.Payload()) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[bytesReceived : bytesReceived+payloadLen] + if p := tcp.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) + } + bytesReceived += payloadLen + var options []byte + if c.TimeStampEnabled { + // If timestamp option is enabled, echo back the timestamp and increment + // the TSEcr value included in the packet and send that back as the TSVal. + parsedOpts := tcp.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + options = tsOpt[:] + } + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), + RcvWnd: 30000, + TCPOpts: options, + }) + } + if numPackets == 1 { + t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") + } +} + +func TestSendGreaterThanMTU(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestActiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + c := context.New(t, 65535) + defer c.Cleanup() + + c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestPassiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + const wndScale = 2 + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Try to accept the connection. 
+ we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Check that data gets properly segmented. + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 536 + const mtu = 2000 + c := context.New(t, mtu) + defer c.Cleanup() + + // Set the SynRcvd threshold to zero to force a syn cookie based accept + // to happen. + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 0 + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Check that data gets properly segmented. 
+ testBrokenUpWrite(t, c, maxPayload) +} + +func TestForwarderSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + s := c.Stack() + ch := make(chan *tcpip.Error, 1) + f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { + var err *tcpip.Error + c.EP, err = r.CreateEndpoint(&c.WQ) + ch <- err + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Wait for connection to be available. + select { + case err := <-ch: + if err != nil { + t.Fatalf("Error creating endpoint: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + + // Check that data gets properly segmented. + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSynOptionsOnActiveConnect(t *testing.T) { + const mtu = 1400 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + const wndScale = 2 + if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + // Start connection attempt. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got c.EP.Connect(...) = %v, want = %v", err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. 
+ b := c.GetPacket() + + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + // Wait for retransmit. + time.Sleep(1 * time.Second) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.SrcPort(tcp.SourcePort()), + checker.SeqNum(tcp.SequenceNumber()), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + ), + ) + + // Send SYN-ACK. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestCloseListener(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. 
+ var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Close the listener and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %v", diff) + } +} + +func TestReceiveOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Try to read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + +loop: + for { + switch _, _, err := c.EP.Read(nil); err { + case tcpip.ErrWouldBlock: + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for reset to arrive") + } + case tcpip.ErrConnectionReset: + break loop + default: + t.Fatalf("got c.EP.Read(nil) = %v, want = %v", err, tcpip.ErrConnectionReset) + } + } +} + +func TestSendOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Wait for the RST to be received. + time.Sleep(1 * time.Second) + + // Try to write. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Write(...) 
= %v, want = %v", err, tcpip.ErrConnectionReset) + } +} + +func TestFinImmediately(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Don't acknowledge yet. We should get a retransmit of the FIN. 
+ checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithNoPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and have it acknowledged. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Shutdown, check that we get a FIN. 
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func DisabledTestFinWithPendingDataCwndFull(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write enough segments to fill the congestion window before ACK'ing + // any of them. + view := buffer.NewView(10) + for i := tcp.InitialCwnd; i > 0; i-- { + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + } + + next := uint32(c.IRS) + 1 + for i := tcp.InitialCwnd; i > 0; i-- { + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + } + + // Shutdown the connection, check that the FIN segment isn't sent + // because the congestion window doesn't allow it. Wait until a + // retransmit is received. 
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Send the ACK that will allow the FIN to be sent as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and acknowledge it to get cwnd to 2. 
+ view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Write new data, but don't acknowledge it. + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPartialAck(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and acknowledge it to get cwnd to 2. Also send + // FIN from the test side. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that we get an ACK for the fin. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Write new data, but don't acknowledge it. 
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send an ACK for the data, but not for the FIN yet. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: seqnum.Value(next - 1), + RcvWnd: 30000, + }) + + // Check that we don't get a retransmit of the FIN. + c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) + + // Ack the FIN. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 791, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) +} + +func DisabledTestExponentialIncreaseDuringSlowStart(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. 
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // Double the number of expected packets for the next iteration. + expected *= 2 + } +} + +func DisabledTestCongestionAvoidance(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. 
+ for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond) + } + + // Don't acknowledge the first packet of the last packet train. Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd/2. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected/2. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 (which "consumes" expected/2-1 of the + // acknowledgements), then the congestion avoidance part will consume + // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack + // remains in the "ack count" (which will cause cwnd to be incremented + // once it reaches cwnd acks). + // + // So we're straight into congestion avoidance with cwnd set to + // expected/2 + 1. + // + // Check that packets trains of cwnd packets are sent, and that cwnd is + // incremented by 1 after we acknowledge each packet. + expected = expected/2 + 1 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. 
+ c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In congestion avoidance, the packet trains increase by 1 in + // each iteration. + expected++ + } +} + +// cubicCwnd returns an estimate of a cubic window given the +// origCwnd, wMax, last congestion event time and sRTT. +func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int { + cwnd := float64(origCwnd) + // We wait 50ms between each iteration so sRTT as computed by cubic + // should be close to 50ms. + elapsed := (time.Since(congEventTime) + sRTT).Seconds() + k := math.Cbrt(float64(wMax) * 0.3 / 0.7) + wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax) + cwnd += (wtRTT - cwnd) / cwnd + return int(cwnd) +} + +func TestCubicCongestionAvoidance(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + enableCUBIC(t, c) + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window.
+ for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond) + } + + // Don't acknowledge the first packet of the last packet train. Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd * 0.7. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge all pending data. + c.SendAck(790, bytesRead) + + // Store away the time we sent the ACK and assuming a 200ms RTO + // we estimate that the sender will have an RTO 200ms from now + // and go back into slow start. + packetDropTime := time.Now().Add(200 * time.Millisecond) + + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 essentially putting the connection + // straight into congestion avoidance. + wMax := expected + // Lower expected as per cubic spec after a congestion event. + expected = int(float64(expected) * 0.7) + cwnd := expected + for i := 0; i < iterations; i++ { + // Cubic grows window independent of ACKs. Cubic Window growth + // is a function of time elapsed since last congestion event. + // As a result the congestion window does not grow + // deterministically in response to ACKs. + // + // We need to roughly estimate what the cwnd of the sender is + // based on when we sent the dupacks. 
+ cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond) + + packetsExpected := cwnd + for j := 0; j < packetsExpected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + t.Logf("expected packets received, next trying to receive any extra packets that may come") + + // If our estimate was correct there should be no more pending packets. + // We attempt to read a packet a few times with a short sleep in between + // to ensure that we don't see the sender send any unexpected packets. + unexpectedPackets := 0 + for { + gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload) + if !gotPacket { + break + } + bytesRead += maxPayload + unexpectedPackets++ + time.Sleep(1 * time.Millisecond) + } + if unexpectedPackets != 0 { + t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i) + } + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd(congestion avoidance)", 5*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + } +} + +func DisabledTestFastRecovery(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Do slow start for a few iterations. 
+ expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Send 3 duplicate acks. This should force an immediate retransmit of + // the pending packet and put the sender into fast recovery. + rtxOffset := bytesRead - maxPayload*expected + for i := 0; i < 3; i++ { + c.SendAck(790, rtxOffset) + } + + // Receive the retransmitted packet. + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Now send 7 more duplicate acks. Each of these should cause a window + // inflation by 1 and cause the sender to send an extra packet. + for i := 0; i < 7; i++ { + c.SendAck(790, rtxOffset) + } + + recover := bytesRead + + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the retransmit due to partial ack. + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Receive the 10 extra packets that should have been released due to + // the congestion window inflation in recovery.
+ for i := 0; i < 10; i++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // A partial ACK during recovery should reduce congestion window by the + // number acked. Since we had "expected" packets outstanding before sending + // partial ack and we acked expected/2 , the cwnd and outstanding should + // be expected/2 + 7. Which means the sender should not send any more packets + // till we ack this one. + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", + 50*time.Millisecond) + + // Acknowledge all pending data to recover point. + c.SendAck(790, recover) + + // At this point, the cwnd should reset to expected/2 and there are 10 + // packets outstanding. + // + // NOTE: Technically netstack is incorrect in that we adjust the cwnd on + // the same segment that takes us out of recovery. But because of that + // the actual cwnd at exit of recovery will be expected/2 + 1 as we + // acked a cwnd worth of packets which will increase the cwnd further by + // 1 in congestion avoidance. + // + // Now in the first iteration since there are 10 packets outstanding. + // We would expect to get expected/2 +1 - 10 packets. But subsequent + // iterations will send us expected/2 + 1 + 1 (per iteration). + expected = expected/2 + 1 - 10 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) + + // Acknowledge all the data received so far. 
+ c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. + if i == 0 { + // After the first iteration we expect to get the full + // congestion window worth of packets in every + // iteration. + expected += 10 + } + expected++ + } +} + +func DisabledTestRetransmit(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in two shots. Packets will only be written at the + // MTU size though. + half := data[:len(data)/2] + if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + half = data[len(data)/2:] + if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Wait for a timeout and retransmit. 
+ rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the remaining data, making sure that acknowledged data is not + // retransmitted. + for offset := rtxOffset; offset < len(data); offset += maxPayload { + c.ReceiveAndCheckPacket(data, offset, maxPayload) + c.SendAck(790, offset+maxPayload) + } + + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) +} + +func TestUpdateListenBacklog(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Update the backlog with another Listen() on the same endpoint. + if err := ep.Listen(20); err != nil { + t.Fatalf("Listen failed to update backlog: %v", err) + } + + ep.Close() +} + +func scaledSendWindow(t *testing.T, scale uint8) { + // This test ensures that the endpoint is using the right scaling by + // sending a buffer that is larger than the window size, and ensuring + // that the endpoint doesn't send more than allowed. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize + c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + header.TCPOptionWS, 3, scale, header.TCPOptionNOP, + }) + + // Open up the window with a scaled value. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 1, + }) + + // Send some data. Check that it's capped by the window size. + view := buffer.NewView(65535) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Check that only data that fits in the scaled window is sent. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen((1< maxTotalSize { + buf = buf[:maxTotalSize] + } + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) + icmp.SetType(typ) + icmp.SetCode(code) + + copy(icmp[header.ICMPv4MinimumSize:], p1) + copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + + // Inject packet. + c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) +} + +// BuildSegment builds a TCP segment based on the given Headers and payload. +func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(tcp.ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the TCP header. 
+ t := header.TCP(buf[header.IPv4MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.Checksum([]byte(TestAddr), 0) + xsum = header.Checksum([]byte(StackAddr), xsum) + xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum) + + // Calculate the TCP checksum and set it. + length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload)) + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum, length)) + + // Inject packet. + return buf.ToVectorisedView() +} + +// SendSegment sends a TCP segment that has already been built and written to a +// buffer.VectorisedView. +func (c *Context) SendSegment(s buffer.VectorisedView) { + c.linkEP.Inject(ipv4.ProtocolNumber, s) +} + +// SendPacket builds and sends a TCP segment(with the provided payload & TCP +// headers) in an IPv4 packet via the link layer endpoint. +func (c *Context) SendPacket(payload []byte, h *Headers) { + c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h)) +} + +// SendAck sends an ACK packet. +func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1), + AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), + RcvWnd: 30000, + }) +} + +// ReceiveAndCheckPacket reads a packet from the link layer endpoint and +// verifies that the packet packet payload of packet matches the slice +// of data indicated by offset & size. 
+func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+	c.checkDataPacket(c.GetPacket(), data, offset, size)
+}
+
+// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
+// via GetPacketNonBlocking and verifies that the payload of the packet matches
+// the slice of data indicated by offset & size. It returns true if a packet
+// was received and processed, and false if GetPacketNonBlocking returned nil.
+func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+	b := c.GetPacketNonBlocking()
+	if b == nil {
+		return false
+	}
+	c.checkDataPacket(b, data, offset, size)
+	return true
+}
+
+// checkDataPacket holds the verification shared by ReceiveAndCheckPacket and
+// ReceiveNonBlockingAndCheckPacket. It checks that b is an IPv4/TCP segment
+// addressed to TestPort whose sequence number is c.IRS+1+offset, whose ack
+// number is testInitialSequenceNumber+1, which has ACK set (PSH is ignored by
+// the flag mask), and whose payload equals data[offset:offset+size]. Any
+// mismatch fails the test.
+func (c *Context) checkDataPacket(b, data []byte, offset, size int) {
+	checker.IPv4(c.t, b,
+		checker.PayloadLen(size+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+			checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// The payload starts right after the fixed-size IPv4 and TCP headers;
+	// this assumes the segment carries no TCP options.
+	want := data[offset:][:size]
+	if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(want, got) {
+		c.t.Fatalf("Data is different: expected %v, got %v", want, got)
+	}
+}
+
+// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
+// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
+// only endpoint instead of a default dual stack socket.
+func (c *Context) CreateV6Endpoint(v6only bool) {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// The option is always set explicitly: 1 when v6only is requested,
+	// otherwise its zero value.
+	var v tcpip.V6OnlyOption
+	if v6only {
+		v = 1
+	}
+	if err := c.EP.SetSockOpt(v); err != nil {
+		c.t.Fatalf("SetSockOpt failed failed: %v", err)
+	}
+}
+
+// GetV6Packet reads a single packet from the link layer endpoint of the context
+// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
+// The returned slice is a fresh copy of the packet's header followed by its
+// payload. The test fails if no packet arrives within two seconds.
+func (c *Context) GetV6Packet() []byte {
+	select {
+	case p := <-c.linkEP.C:
+		if p.Proto != ipv6.ProtocolNumber {
+			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+		}
+		// Flatten header and payload into one contiguous buffer.
+		b := make([]byte, len(p.Header)+len(p.Payload))
+		copy(b, p.Header)
+		copy(b[len(p.Header):], p.Payload)
+
+		checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+		return b
+
+	case <-time.After(2 * time.Second):
+		c.t.Fatalf("Packet wasn't written out")
+	}
+
+	return nil
+}
+
+// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
+// the context. The segment goes from TestV6Addr to StackV6Addr and uses the
+// ports, sequence/ack numbers, flags and window from h. h.TCPOpts is not
+// encoded: the TCP header is always header.TCPMinimumSize bytes long.
+func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
+	copy(buf[len(buf)-len(payload):], payload)
+
+	// Initialize the IP header.
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+		NextHeader:    uint8(tcp.ProtocolNumber),
+		HopLimit:      65,
+		SrcAddr:       TestV6Addr,
+		DstAddr:       StackV6Addr,
+	})
+
+	// Initialize the TCP header.
+	t := header.TCP(buf[header.IPv6MinimumSize:])
+	t.Encode(&header.TCPFields{
+		SrcPort:    h.SrcPort,
+		DstPort:    h.DstPort,
+		SeqNum:     uint32(h.SeqNum),
+		AckNum:     uint32(h.AckNum),
+		DataOffset: header.TCPMinimumSize,
+		Flags:      uint8(h.Flags),
+		WindowSize: uint16(h.RcvWnd),
+	})
+
+	// Calculate the TCP pseudo-header checksum.
+	xsum := header.Checksum([]byte(TestV6Addr), 0)
+	xsum = header.Checksum([]byte(StackV6Addr), xsum)
+	xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
+
+	// Calculate the TCP checksum and set it.
+	length := uint16(header.TCPMinimumSize + len(payload))
+	xsum = header.Checksum(payload, xsum)
+	t.SetChecksum(^t.CalculateChecksum(xsum, length))
+
+	// Inject packet.
+	c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
+}
+
+// CreateConnected creates a connected TCP endpoint. It is
+// CreateConnectedWithRawOptions without any TCP options in the SYN-ACK.
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+	c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the SYN-ACK that answers
+// the endpoint's SYN.
+//
+// iss is the sequence number used for the SYN-ACK and rcvWnd the window it
+// advertises. If epRcvBuf is non-nil, the receive buffer of the endpoint is
+// set to that value before connecting. On return c.EP is the connected
+// endpoint, c.IRS the sequence number of its SYN and c.Port its local port.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	if epRcvBuf != nil {
+		if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+			c.t.Fatalf("SetSockOpt failed failed: %v", err)
+		}
+	}
+
+	// Start connection attempt. The waiter is registered before Connect so
+	// the completion notification cannot be missed.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
+	if err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Receive SYN packet.
+	b := c.GetPacket()
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+		),
+	)
+
+	// synSeg is deliberately not called "tcp" so that it does not shadow
+	// the imported tcp package.
+	synSeg := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(synSeg.SequenceNumber())
+
+	// Reply with a SYN-ACK carrying the caller-supplied options.
+	c.SendPacket(nil, &Headers{
+		SrcPort: synSeg.DestinationPort(),
+		DstPort: synSeg.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  rcvWnd,
+		TCPOpts: options,
+	})
+
+	// Receive ACK packet.
+	checker.IPv4(c.t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+1),
+		),
+	)
+
+	// Wait for connection to be established.
+	select {
+	case <-notifyCh:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			c.t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		c.t.Fatalf("Timed out waiting for connection")
+	}
+
+	c.Port = synSeg.SourcePort()
+}
+
+// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
+// sending data and ACK packets easy while being able to manipulate the sequence
+// numbers and timestamp values as needed.
+type RawEndpoint struct {
+	C          *Context
+	SrcPort    uint16
+	DstPort    uint16
+	Flags      int
+	NextSeqNum seqnum.Value
+	AckNum     seqnum.Value
+	WndSize    seqnum.Size
+	RecentTS   uint32 // Stores the latest timestamp to echo back.
+	TSVal      uint32 // TSVal stores the last timestamp sent by this endpoint.
+
+	// SACKPermitted is true if SACKPermitted option was negotiated for this endpoint.
+	SACKPermitted bool
+}
+
+// SendPacketWithTS embeds the provided tsVal in the Timestamp option
+// for the packet to be sent out. tsVal is recorded in r.TSVal and r.RecentTS
+// is echoed back in the option.
+func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
+	r.TSVal = tsVal
+	// Two leading NOPs followed by the timestamp option encoded at offset 2.
+	tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+	header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
+	r.SendPacket(payload, tsOpt[:])
+}
+
+// SendPacket is a small wrapper function to build and send packets. It sends
+// payload with the endpoint's current ports, flags, sequence/ack numbers and
+// window plus the given raw TCP options, then advances NextSeqNum by
+// len(payload).
+func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
+	packetHeaders := &Headers{
+		SrcPort: r.SrcPort,
+		DstPort: r.DstPort,
+		Flags:   r.Flags,
+		SeqNum:  r.NextSeqNum,
+		AckNum:  r.AckNum,
+		RcvWnd:  r.WndSize,
+		TCPOpts: opts,
+	}
+	r.C.SendPacket(payload, packetHeaders)
+	r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal, and stores the TSVal carried by that ack in r.RecentTS.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+	// Read ACK and verify that tsEcr of ACK packet is tsVal.
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPTimestampChecker(true, 0, tsVal),
+		),
+	)
+	// Store the parsed TSVal from the ack as recentTS.
+	tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	opts := tcpSeg.ParsedOptions()
+	r.RecentTS = opts.TSVal
+}
+
+// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
+func (r *RawEndpoint) VerifyACKNoSACK() {
+	r.VerifyACKHasSACK(nil)
+}
+
+// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
+// Passing nil is how VerifyACKNoSACK asserts the absence of SACK blocks.
+func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
+	// Read ACK and check its TCP options against the expected SACK blocks
+	// using checker.TCPSACKBlockChecker.
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPSACKBlockChecker(sackBlocks),
+		),
+	)
+}
+
+// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
+// options enabled and returns a RawEndpoint which represents the other end of
+// the connection.
+//
+// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
+// does not carry an option that was not requested.
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
+	}
+
+	// Start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
+	err = c.EP.Connect(testFullAddr)
+	if err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
+	}
+	// Receive SYN packet.
+	b := c.GetPacket()
+	// Validate that the syn has the timestamp option and a valid
+	// TS value.
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+			checker.TCPSynOptions(header.TCPSynOptions{
+				MSS:           uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
+				TS:            true,
+				WS:            defaultWindowScale,
+				SACKPermitted: c.SACKEnabled(),
+			}),
+		),
+	)
+	tcpSeg := header.TCP(header.IPv4(b).Payload())
+	synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
+
+	// Build options w/ tsVal to be sent in the SYN-ACK.
+	synAckOptions := make([]byte, 40)
+	offset := 0
+	if wantOptions.TS {
+		offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
+	}
+	if wantOptions.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
+	}
+
+	offset += header.AddTCPOptionPadding(synAckOptions, offset)
+
+	// Build SYN-ACK.
+	c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
+	iss := seqnum.Value(testInitialSequenceNumber)
+	c.SendPacket(nil, &Headers{
+		SrcPort: tcpSeg.DestinationPort(),
+		DstPort: tcpSeg.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+		TCPOpts: synAckOptions[:offset],
+	})
+
+	// Read ACK.
+	ackPacket := c.GetPacket()
+
+	// Verify TCP header fields.
+	tcpCheckers := []checker.TransportChecker{
+		checker.DstPort(TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(c.IRS) + 1),
+		checker.AckNum(uint32(iss) + 1),
+	}
+
+	// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
+	// timestamp option was enabled, if not then we verify that
+	// there is no timestamp in the ACK packet.
+	if wantOptions.TS {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
+	} else {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+	}
+
+	checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
+
+	ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	ackOptions := ackSeg.ParsedOptions()
+
+	// Wait for connection to be established.
+	select {
+	case <-notifyCh:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			c.t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		c.t.Fatalf("Timed out waiting for connection")
+	}
+
+	// Store the source port in use by the endpoint.
+	c.Port = tcpSeg.SourcePort()
+
+	// Mark in context that timestamp option is enabled for this endpoint.
+	// NOTE(review): this is set even when wantOptions.TS is false; confirm
+	// that is intended.
+	c.TimeStampEnabled = true
+
+	return &RawEndpoint{
+		C:             c,
+		SrcPort:       tcpSeg.DestinationPort(),
+		DstPort:       tcpSeg.SourcePort(),
+		Flags:         header.TCPFlagAck | header.TCPFlagPsh,
+		NextSeqNum:    iss + 1,
+		AckNum:        c.IRS.Add(1),
+		WndSize:       30000,
+		RecentTS:      ackOptions.TSVal,
+		TSVal:         wantOptions.TSVal,
+		SACKPermitted: wantOptions.SACKPermitted,
+	}
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with the
+// provided options enabled. It also verifies that the SYN-ACK has the expected
+// values for the provided options.
+//
+// The accepted endpoint is stored in c.EP. The function returns a RawEndpoint
+// representing the other end of the accepted endpoint. Any Accept failure,
+// including one reported by the first non-blocking attempt, fails the test.
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	if err := ep.Bind(tcpip.FullAddress{Port: StackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		c.t.Fatalf("Listen failed: %v", err)
+	}
+
+	rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+
+		case <-time.After(1 * time.Second):
+			c.t.Fatalf("Timed out waiting for accept")
+		}
+	}
+	// Covers both the first attempt and the retry, so an error other than
+	// ErrWouldBlock can no longer leave c.EP nil unnoticed.
+	if err != nil {
+		c.t.Fatalf("Accept failed: %v", err)
+	}
+	return rep
+}
+
+// PassiveConnect just disables WindowScaling and delegates the call to
+// PassiveConnectWithOptions.
+func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
+	// A negative WS suppresses the WindowScale option in the SYN.
+	synOptions.WS = -1
+	c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+}
+
+// PassiveConnectWithOptions initiates a new connection (with the specified TCP
+// options enabled) to the port on which the Context.ep is listening for new
+// connections. It also validates that the SYN-ACK has the expected values for
+// the enabled options.
+//
+// NOTE: MSS is not a negotiated option and it can be asymmetric
+// in each direction. This function uses the maxPayload to set the MSS to be
+// sent to the peer on a connect and validates that the MSS in the SYN-ACK
+// response is equal to synOptions.MSS.
+//
+// wndScale is the expected window scale in the SYN-ACK. A WindowScale option
+// is sent in the SYN whenever synOptions.WS >= 0; the shift count encoded in
+// that option is always 3. When synOptions.WS > 0 the window advertised in the
+// final ACK is shifted right by synOptions.WS.
+//
+// It returns a RawEndpoint representing the peer that sent the SYN.
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+	opts := make([]byte, 40)
+	offset := 0
+	offset += header.EncodeMSSOption(uint32(maxPayload), opts)
+
+	if synOptions.WS >= 0 {
+		offset += header.EncodeWSOption(3, opts[offset:])
+	}
+	if synOptions.TS {
+		offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
+	}
+
+	if synOptions.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(opts[offset:])
+	}
+
+	// Quad-align the options with the same helper that
+	// CreateConnectedWithOptions uses, instead of a second hand-written
+	// NOP padding loop.
+	offset += header.AddTCPOptionPadding(opts, offset)
+
+	// Send a SYN request.
+	iss := seqnum.Value(testInitialSequenceNumber)
+	c.SendPacket(nil, &Headers{
+		SrcPort: TestPort,
+		DstPort: StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+		TCPOpts: opts[:offset],
+	})
+
+	// Receive the SYN-ACK reply. Make sure MSS and other expected options
+	// are present.
+	b := c.GetPacket()
+	// synAck is deliberately not called "tcp" so that it does not shadow
+	// the imported tcp package.
+	synAck := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(synAck.SequenceNumber())
+
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(StackPort),
+		checker.DstPort(TestPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(iss) + 1),
+		checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
+	}
+
+	// If TS option was enabled in the original SYN then add a checker to
+	// validate the Timestamp option in the SYN-ACK.
+	if synOptions.TS {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
+	} else {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+	}
+
+	checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
+	rcvWnd := seqnum.Size(30000)
+	ackHeaders := &Headers{
+		SrcPort: TestPort,
+		DstPort: StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+		RcvWnd:  rcvWnd,
+	}
+
+	// If WS was expected to be in effect then scale the advertised window
+	// correspondingly.
+	if synOptions.WS > 0 {
+		ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
+	}
+
+	parsedOpts := synAck.ParsedOptions()
+	if synOptions.TS {
+		// Echo the tsVal back to the peer in the tsEcr field of the
+		// timestamp option.
+		// Increment TSVal by 1 from the value sent in the SYN and echo
+		// the TSVal in the SYN-ACK in the TSEcr field.
+		tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+		header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, tsOpt[2:])
+		ackHeaders.TCPOpts = tsOpt[:]
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	c.Port = StackPort
+
+	return &RawEndpoint{
+		C:             c,
+		SrcPort:       TestPort,
+		DstPort:       StackPort,
+		Flags:         header.TCPFlagPsh | header.TCPFlagAck,
+		NextSeqNum:    iss + 1,
+		AckNum:        c.IRS + 1,
+		WndSize:       rcvWnd,
+		SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
+		RecentTS:      parsedOpts.TSVal,
+		TSVal:         synOptions.TSVal + 1,
+	}
+}
+
+// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
+// for the Stack in the context. It returns false when the option cannot be
+// read from the stack.
+func (c *Context) SACKEnabled() bool {
+	var v tcp.SACKEnabled
+	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
+		// Stack doesn't support SACK. So just return.
+		return false
+	}
+	return bool(v)
+}
diff --git a/netstack/tcpip/transport/tcp/timer.go b/netstack/tcpip/transport/tcp/timer.go
new file mode 100644
index 0000000..866ddd9
--- /dev/null
+++ b/netstack/tcpip/transport/tcp/timer.go
@@ -0,0 +1,142 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"time"
+
+	"tcpip/netstack/sleep"
+)
+
+type timerState int
+
+const (
+	timerStateDisabled timerState = iota
+	timerStateEnabled
+	timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because their timeouts are long
+// (currently at least 200ms), and get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+// Implementation of the timer.
+type timer struct {
+	// state is the current state of the timer, it can be one of the
+	// following values:
+	//     disabled - the timer is disabled.
+	//     orphaned - the timer is disabled, but the runtime timer is
+	//                enabled, which means that it will eventually cause a
+	//                spurious wake (unless it gets enabled again before
+	//                then).
+	//     enabled  - the timer is enabled, but the runtime timer may be set
+	//                to an earlier expiration time due to a previous
+	//                orphaned state.
+	state timerState
+
+	// target is the expiration time of the current timer. It is only
+	// meaningful in the enabled state.
+	target time.Time
+
+	// runtimeTarget is the expiration time of the runtime timer. It is
+	// meaningful in the enabled and orphaned states.
+	runtimeTarget time.Time
+
+	// timer is the runtime timer used to wait on.
+	timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+	t.state = timerStateDisabled
+
+	// Initialize a runtime timer that will assert the waker, then
+	// immediately stop it.
+	t.timer = time.AfterFunc(time.Hour, func() {
+		w.Assert()
+	})
+	t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+	t.timer.Stop()
+}
+
+// checkExpiration reports whether the timer has genuinely expired. Call it
+// every time a sleeper wakes up because the waker was asserted: it tells a
+// spurious wake (caused by a previously orphaned runtime timer) apart from a
+// legitimate expiration.
+func (t *timer) checkExpiration() bool {
+	// An orphaned timer only needed its pending runtime timer consumed;
+	// it is now fully disabled.
+	if t.state == timerStateOrphaned {
+		t.state = timerStateDisabled
+		return false
+	}
+
+	now := time.Now()
+	if !now.Before(t.target) {
+		// The target has been reached: disable the timer for now and
+		// tell the caller it expired.
+		t.state = timerStateDisabled
+		return true
+	}
+
+	// The runtime timer fired ahead of the current target, so re-arm it
+	// for the time that is still remaining.
+	t.runtimeTarget = t.target
+	t.timer.Reset(t.target.Sub(now))
+	return false
+}
+
+// disable turns the timer off. Unless it was already disabled, the timer is
+// left in the orphaned state.
+func (t *timer) disable() {
+	if t.state == timerStateDisabled {
+		return
+	}
+	t.state = timerStateOrphaned
+}
+
+// enabled reports whether the timer is currently in the enabled state.
+func (t *timer) enabled() bool {
+	return t.state == timerStateEnabled
+}
+
+// enable arms the timer to expire d from now, reprogramming the runtime timer
+// only when that is required.
+func (t *timer) enable(d time.Duration) {
+	deadline := time.Now().Add(d)
+	t.target = deadline
+
+	// The runtime timer has to be (re)set when the timer is disabled or
+	// when the new deadline precedes the runtime timer's expiration.
+	if t.state == timerStateDisabled || deadline.Before(t.runtimeTarget) {
+		t.runtimeTarget = deadline
+		t.timer.Reset(d)
+	}
+
+	t.state = timerStateEnabled
+}
diff --git a/netstack/tcpip/transport/tcpconntrack/tcp_conntrack.go b/netstack/tcpip/transport/tcpconntrack/tcp_conntrack.go
new file mode 100644
index 0000000..fb32548
--- /dev/null
+++ b/netstack/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -0,0 +1,348 @@
+// Copyright 2018 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcpconntrack implements a TCP connection tracking object. It allows
+// users with access to a segment stream to figure out when a connection is
+// established, reset, and closed (and in the last case, who closed first).
+package tcpconntrack
+
+import (
+	"tcpip/netstack/tcpip/header"
+	"tcpip/netstack/tcpip/seqnum"
+)
+
+// Result is the outcome reported after a TCB processes an inbound or outbound
+// segment.
+type Result int
+
+const (
+	// ResultDrop means the segment should be dropped.
+	ResultDrop Result = iota
+
+	// ResultConnecting means the connection is still in a connecting
+	// state.
+	ResultConnecting
+
+	// ResultAlive means the connection is still alive (connected).
+	ResultAlive
+
+	// ResultReset means the connection was reset.
+	ResultReset
+
+	// ResultClosedByPeer means the connection was closed gracefully and
+	// the inbound stream was the first to close.
+	ResultClosedByPeer
+
+	// ResultClosedBySelf means the connection was closed gracefully and
+	// the outbound stream was the first to close.
+	ResultClosedBySelf
+)
+
+// TCB is a TCP Control Block. It keeps the state needed to track one TCP
+// connection and to tell the caller when that connection has been closed.
+type TCB struct {
+	inbound  stream
+	outbound stream
+
+	// Handlers for the current state, one per direction.
+	handlerInbound  func(*TCB, header.TCP) Result
+	handlerOutbound func(*TCB, header.TCP) Result
+
+	// firstFin points at whichever stream sent a FIN first.
+	firstFin *stream
+
+	// state is the most recent non-drop Result recorded for the
+	// connection.
+	state Result
+}
+
+// Init sets up the TCB from the initial SYN of the connection.
+func (t *TCB) Init(initialSyn header.TCP) {
+	t.handlerInbound = synSentStateInbound
+	t.handlerOutbound = synSentStateOutbound
+
+	iss := seqnum.Value(initialSyn.SequenceNumber())
+	next := iss.Add(logicalLen(initialSyn))
+	t.outbound.una = iss
+	t.outbound.nxt = next
+	t.outbound.end = next
+
+	// "end" is normally a sequence number, but the initial receive
+	// sequence number is not known yet. Until the peer's SYN arrives it
+	// holds the window size instead.
+	t.inbound.una = 0
+	t.inbound.nxt = 0
+	t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+	t.state = ResultConnecting
+}
+
+// UpdateStateInbound feeds an inbound segment to the current inbound handler
+// and returns its Result. The TCB state is updated unless the segment is
+// dropped.
+func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
+	res := t.handlerInbound(t, tcp)
+	if res == ResultDrop {
+		return res
+	}
+	t.state = res
+	return res
+}
+
+// UpdateStateOutbound feeds an outbound segment to the current outbound
+// handler and returns its Result. The TCB state is updated unless the segment
+// is dropped.
+func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
+	res := t.handlerOutbound(t, tcp)
+	if res == ResultDrop {
+		return res
+	}
+	t.state = res
+	return res
+}
+
+// IsAlive returns true as long as the connection is established(Alive)
+// or connecting state.
+func (t *TCB) IsAlive() bool { + return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) +} + +// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. +func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { + return t.outbound.nxt +} + +// InboundSendSequenceNumber returns the snd.NXT for the inbound stream. +func (t *TCB) InboundSendSequenceNumber() seqnum.Value { + return t.inbound.nxt +} + +// adaptResult modifies the supplied "Result" according to the state of the TCB; +// if r is anything other than "Alive", or if one of the streams isn't closed +// yet, it is returned unmodified. Otherwise it's converted to either +// ClosedBySelf or ClosedByPeer depending on which stream was closed first. +func (t *TCB) adaptResult(r Result) Result { + // Check the unmodified case. + if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { + return r + } + + // Find out which was closed first. + if t.firstFin == &t.outbound { + return ResultClosedBySelf + } + + return ResultClosedByPeer +} + +// synSentStateInbound is the state handler for inbound segments when the +// connection is in SYN-SENT state. +func synSentStateInbound(t *TCB, tcp header.TCP) Result { + flags := tcp.Flags() + ackPresent := flags&header.TCPFlagAck != 0 + ack := seqnum.Value(tcp.AckNumber()) + + // Ignore segment if ack is present but not acceptable. + if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { + return ResultConnecting + } + + // If reset is specified, we will let the packet through no matter what + // but we will also destroy the connection if the ACK is present (and + // implicitly acceptable). + if flags&header.TCPFlagRst != 0 { + if ackPresent { + t.inbound.rstSeen = true + return ResultReset + } + return ResultConnecting + } + + // Ignore segment if SYN is not set. + if flags&header.TCPFlagSyn == 0 { + return ResultConnecting + } + + // Update state informed by this SYN. 
+ irs := seqnum.Value(tcp.SequenceNumber()) + t.inbound.una = irs + t.inbound.nxt = irs.Add(logicalLen(tcp)) + t.inbound.end += irs + + t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) + + // If the ACK was set (it is acceptable), update our unacknowledgement + // tracking. + if ackPresent { + // Advance the "una" and "end" indices of the outbound stream. + if t.outbound.una.LessThan(ack) { + t.outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { + t.outbound.end = end + } + } + + // Update handlers so that new calls will be handled by new state. + t.handlerInbound = allOtherInbound + t.handlerOutbound = allOtherOutbound + + return ResultAlive +} + +// synSentStateOutbound is the state handler for outbound segments when the +// connection is in SYN-SENT state. +func synSentStateOutbound(t *TCB, tcp header.TCP) Result { + // Drop outbound segments that aren't retransmits of the original one. + if tcp.Flags() != header.TCPFlagSyn || + tcp.SequenceNumber() != uint32(t.outbound.una) { + return ResultDrop + } + + // Update the receive window. We only remember the largest value seen. + if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { + t.inbound.end = wnd + } + + return ResultConnecting +} + +// update updates the state of inbound and outbound streams, given the supplied +// inbound segment. For outbound segments, this same function can be called with +// swapped inbound/outbound streams. +func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { + // Ignore segments out of the window. + s := seqnum.Value(tcp.SequenceNumber()) + if !inbound.acceptable(s, dataLen(tcp)) { + return ResultAlive + } + + flags := tcp.Flags() + if flags&header.TCPFlagRst != 0 { + inbound.rstSeen = true + return ResultReset + } + + // Ignore segments that don't have the ACK flag, and those with the SYN + // flag. 
+ if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { + return ResultAlive + } + + // Ignore segments that acknowledge not yet sent data. + ack := seqnum.Value(tcp.AckNumber()) + if outbound.nxt.LessThan(ack) { + return ResultAlive + } + + // Advance the "una" and "end" indices of the outbound stream. + if outbound.una.LessThan(ack) { + outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { + outbound.end = end + } + + // Advance the "nxt" index of the inbound stream. + end := s.Add(logicalLen(tcp)) + if inbound.nxt.LessThan(end) { + inbound.nxt = end + } + + // Note the index of the FIN segment. And stash away a pointer to the + // first stream to see a FIN. + if flags&header.TCPFlagFin != 0 && !inbound.finSeen { + inbound.finSeen = true + inbound.fin = end - 1 + + if *firstFin == nil { + *firstFin = inbound + } + } + + return ResultAlive +} + +// allOtherInbound is the state handler for inbound segments in all states +// except SYN-SENT. +func allOtherInbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) +} + +// allOtherOutbound is the state handler for outbound segments in all states +// except SYN-SENT. +func allOtherOutbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) +} + +// stream holds the state of a unidirectional TCP stream. +type stream struct { + // The interval [una, end) is the allowed interval as defined by the + // receiver, i.e., anything less than una has already been acknowledged + // and anything greater than or equal to end is beyond the receiver + // window. The interval [una, nxt) is the acknowledgeable range, whose + // right edge indicates the sequence number of the next byte to be sent + // by the sender, i.e., anything greater than or equal to nxt hasn't + // been sent yet. 
+ una seqnum.Value + nxt seqnum.Value + end seqnum.Value + + // finSeen indicates if a FIN has already been sent on this stream. + finSeen bool + + // fin is the sequence number of the FIN. It is only valid after finSeen + // is set to true. + fin seqnum.Value + + // rstSeen indicates if a RST has already been sent on this stream. + rstSeen bool +} + +// acceptable determines if the segment with the given sequence number and data +// length is acceptable, i.e., if it's within the [una, end) window or, in case +// the window is zero, if it's a packet with no payload and sequence number +// equal to una. +func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + wnd := s.una.Size(s.end) + if wnd == 0 { + return segLen == 0 && segSeq == s.una + } + + // Make sure [segSeq, segSeq+segLen) is non-empty. + if segLen == 0 { + segLen = 1 + } + + return seqnum.Overlap(s.una, wnd, segSeq, segLen) +} + +// closed determines if the stream has already been closed. This happens when +// a FIN has been set by the sender and acknowledged by the receiver. +func (s *stream) closed() bool { + return s.finSeen && s.fin.LessThan(s.una) +} + +// dataLen returns the length of the TCP segment payload. NOTE(review): the subtraction is not range-checked against len(tcp). +func dataLen(tcp header.TCP) seqnum.Size { + return seqnum.Size(len(tcp) - int(tcp.DataOffset())) +} + +// logicalLen calculates the logical length of the TCP segment. 
+func logicalLen(tcp header.TCP) seqnum.Size { + l := dataLen(tcp) + flags := tcp.Flags() + if flags&header.TCPFlagSyn != 0 { + l++ + } + if flags&header.TCPFlagFin != 0 { + l++ + } + return l +} diff --git a/netstack/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/netstack/tcpip/transport/tcpconntrack/tcp_conntrack_test.go new file mode 100644 index 0000000..3e90edd --- /dev/null +++ b/netstack/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -0,0 +1,511 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcpconntrack_test + +import ( + "testing" + + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/transport/tcpconntrack" +) + +// connected creates a connection tracker TCB and sets it to a connected state +// by performing a 3-way handshake. +func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: iss, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: irw, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN-ACK. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: irs, + AckNum: iss + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: isw, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: iss + 1, + AckNum: irs + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: irw, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + return &tcb +} + +func TestConnectionRefused(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive RST. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionRefusedInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive RST with no ACK. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionResetInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestRetransmitOnSynSent(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Retransmit the same SYN. + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) + } +} + +func TestRetransmitOnSynRcvd(t *testing.T) { + // Send initial SYN. 
+ tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. This will cause the state to go to SYN-RCVD. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Retransmit the original SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Transmit a SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} + +func TestClosedBySelf(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Send FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1236, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestClosedByPeer(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Receive FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ACK. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: 791, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) + } +} + +func TestSendAndReceiveDataClosedBySelf(t *testing.T) { + sseq := uint32(1234) + rseq := uint32(789) + tcb := connected(t, sseq, rseq, 30000, 50000) + sseq++ + rseq++ + + // Send some data. + tcp := make(header.TCP, header.TCPMinimumSize+1024) + + for i := uint32(0); i < 10; i++ { + // Send some data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + sseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + for i := uint32(0); i < 10; i++ { + // Receive some data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + rseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ack for data. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + // Send FIN. + tcp = tcp[:header.TCPMinimumSize] + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + sseq++ + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + rseq++ + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestIgnoreBadResetOnSynSent(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive a RST with a bad ACK, it should not cause the connection to + // be reset. 
+ acks := []uint32{1234, 1236, 1000, 5000} + flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + for _, a := range acks { + for _, f := range flags { + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: a, + DataOffset: header.TCPMinimumSize, + Flags: f, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) + } + } + } + + // Complete the handshake. + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} diff --git a/netstack/tcpip/transport/udp/endpoint.go b/netstack/tcpip/transport/udp/endpoint.go new file mode 100644 index 0000000..434e409 --- /dev/null +++ b/netstack/tcpip/transport/udp/endpoint.go @@ -0,0 +1,980 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package udp + +import ( + "log" + "math" + "sync" + + "tcpip/netstack/sleep" + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +// +stateify savable +// udpPacket is the UDP datagram record: when a UDP datagram is received, its data is saved in this structure +// and the structure is inserted into the receive list. +type udpPacket struct { + udpPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView + timestamp int64 + hasTimestamp bool + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View +} + +type endpointState int + +// The states a UDP endpoint can be in. +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents a UDP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable +// It is the structure that represents the UDP protocol endpoint. +type endpoint struct { + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex + rcvReady bool + rcvList udpPacketList + rcvBufSizeMax int + rcvBufSize int + rcvClosed bool + rcvTimestamp bool + + // The following fields are protected by the mu mutex. + mu sync.RWMutex + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + regNICID tcpip.NICID + route stack.Route + dstPort uint16 + v6only bool + multicastTTL uint8 + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // multicastMemberships that need to be removed when the endpoint is + // closed. Protected by the mu mutex. 
+ multicastMemberships []multicastMembership + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber +} + +// multicastMembership is one multicast group membership: the multicast address and the NIC ID. +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +// newEndpoint creates a new UDP endpoint. +// The multicast TTL defaults to 1, and the receive and send buffers default to 32 KiB. +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + return &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + // RFC 1075 section 5.4 recommends a TTL of 1 for membership + // requests. + // + // RFC 5135 4.2.1 appears to assume that IGMP messages have a + // TTL of 1. + // + // RFC 5135 Appendix A defines TTL=1: A multicast source that + // wants its traffic to not traverse a router (e.g., leave a + // home network) may find it useful to send traffic with IP + // TTL=1. + // + // Linux defaults to TTL=1. + multicastTTL: 1, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } +} + +// NewConnectedEndpoint creates a new endpoint in the connected state using the +// provided route. +func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(stack, r.NetProto, waiterQueue) + + // Register new endpoint so that packets are routed to it. 
+ if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil { + ep.Close() + return nil, err + } + + ep.id = id + ep.route = r.Clone() + ep.dstPort = id.RemotePort + ep.regNICID = r.NICID() + + ep.state = stateConnected + + return ep, nil +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +// It closes the UDP endpoint and releases the corresponding resources. +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + + switch e.state { + case stateBound, stateConnected: + // Unregister the UDP endpoint from the stack. + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + // Release the reserved port. + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } + + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + // Empty the receive list. + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. 
+// When the receive list is empty it returns ErrWouldBlock (or ErrClosedForReceive once receiving is closed). +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + // The receive list is empty, i.e. there is no data at all. + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + // Take the datagram at the front of the receive list, then remove it from the list + // and reduce the receive buffer usage by its size. + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + ts := e.rcvTimestamp + + e.rcvMu.Unlock() + + if addr != nil { + // Report the sender's address to the caller. + *addr = p.senderAddress + } + + if ts && !p.hasTimestamp { + // Linux uses the current time. + p.timestamp = e.stack.NowNanoseconds() + } + + // Return the contents of the datagram. + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +// Preparation before writing data: an endpoint still in the initial state must be bound first. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return false, err + } + + return true, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. 
+// User-level code ultimately calls this function to send a packet to the peer; it does not block even when the write fails. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + // A payload longer than 65535 bytes exceeds what the UDP length field can express, which is not allowed. + if p.Size() > math.MaxUint16 { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + // If the write side has been shut down, return an error. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + // Repeat the preparation until it no longer asks for a retry. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + var dstPort uint16 + if to == nil { + // No destination address was given: use the endpoint's own route and destination port (set when it was connected). + route = &e.route + dstPort = e.dstPort + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, given that it may need to + // change in Route.Resolve() call below. + // The shared lock is released first and the exclusive lock is taken afterwards. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + // The state may have changed while no lock was held. + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { // A destination address was specified. + nicid := to.NIC + // If the endpoint is bound to a NIC, use that NIC. + if e.bindNICID != 0 { + // If the bound NIC differs from the NIC named in the argument, return an error. + if nicid != 0 && nicid != e.bindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicid = e.bindNICID + } + + // Work on a copy of the destination IP and port. + toCopy := *to + to = &toCopy + netProto, err := e.checkV4Mapped(to, false) + if err != nil { + return 0, nil, err + } + + log.Printf("netProto: 0x%x", netProto) // NOTE(review): debug logging on every addressed send; consider removing. + // Find the route. 
+ // Look up the route information for the destination address and protocol. + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + dstPort = to.Port + } + + // If the route lacks the next hop's link (MAC) address, trigger the mechanism that fills in that information. + // For example, with IPv4: when the MAC for the destination IP is unknown, the ARP cache is consulted and a hit returns directly; + // on a miss an ARP request is sent to obtain the corresponding MAC address. + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + if ch, err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. Resolution was triggered in the background. + // Better luck next time. + route.RemoveWaker(waker) + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + // Fetch the payload bytes to send. + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + + ttl := route.DefaultTTL() + // For a multicast destination address, use the multicast TTL. + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + } + + // Add the UDP header information and send the packet out. + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + return 0, nil, err + } + + return uintptr(len(v)), nil, nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. +// It applies the given option to the UDP socket. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. 
+ if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + + case tcpip.TimestampOption: + e.rcvMu.Lock() + e.rcvTimestamp = v != 0 + e.rcvMu.Unlock() + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.AddMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + + case tcpip.RemoveMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { + if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { + // Only remove the first match, so that each added membership above is + // paired with exactly 1 removal. + e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + break + } + } + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
+// 从 UDP 端获取选项参数 +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + + case *tcpip.TimestampOption: + e.rcvMu.Lock() + *o = 0 + if e.rcvTimestamp { + *o = 1 + } + e.rcvMu.Unlock() + + case *tcpip.MulticastTTLOption: + e.mu.Lock() + *o = tcpip.MulticastTTLOption(e.multicastTTL) + e.mu.Unlock() + return nil + } + + return tcpip.ErrUnknownProtocolOption +} + +// sendUDP sends a UDP segment via the provided network endpoint and under the +// provided identity. +// 增加UDP头部信息,并发送给给网络层 +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { + // Allocate a buffer for the UDP header. + hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + + // Initialize the header. + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + + // 得到报文的长度 + length := uint16(hdr.UsedLength() + data.Size()) + // UDP首部的编码 + udp.Encode(&header.UDPFields{ + SrcPort: localPort, + DstPort: remotePort, + Length: length, + }) + + // Only calculate the checksum if offloading isn't supported. 
+ if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + // 检验和的计算 + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + udp.SetChecksum(^udp.CalculateChecksum(xsum, length)) + } + + // Track count of packets sent. + r.Stats().UDP.PacketsSent.Increment() + + // 将准备好的UDP首部和数据写给网络层 + log.Printf("send udp %d bytes", hdr.UsedLength()+data.Size()) + return r.WritePacket(hdr, data, ProtocolNumber, ttl) +} + +// IPV6于IPV4地址的映射 +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + + // Fail if we are bound to an IPv6 address. + if !allowMismatch && len(e.id.LocalAddress) == 16 { + return 0, tcpip.ErrNetworkUnreachable + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +// 连接UDP对端,当发送数据的时候用这个地址来填充目标地址 +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // 目标端口为0是错误的 + if addr.Port == 0 { + // We don't support connecting to port zero. 
+ return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + var localPort uint16 + // 判断UDP端的状态 + switch e.state { + case stateInitial: + // 如果是初始状态,直接下一步 + case stateBound, stateConnected: + // 如果已经绑定或者已连接状态 + localPort = e.id.LocalPort + if e.bindNICID == 0 { + break + } + + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + // 检查地址的映射,得到相应的协议 + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Find a route to the desired destination. + // 在全局协议栈中查找路由 + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + // 新建一个传输端的标识,包括源IP、源端口、目的IP、目的端口 + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemotePort: addr.Port, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + // 设置网络层协议,IPV4或IPV6,或两者都有 + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } + } + + // 将该UDP端注册到协议栈中 + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + // Remove the old registration. 
+ // 如果源端口不为0,则尝试在传输层端中删除老的UDP端 + if e.id.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + // 赋值UDP端的属性 + e.id = id + e.route = r.Clone() + e.dstPort = addr.Port + e.regNICID = nicid + e.effectiveNetProtos = netProtos + + // 更改该UDP端的状态为已连接 + e.state = stateConnected + + // 标志该UDP端可以接收数据了 + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // A socket in the bound state can still receive multicast messages, + // so we need to notify waiters on shutdown. + if e.state != stateBound && e.state != stateConnected { + return tcpip.ErrNotConnected + } + + e.shutdownFlags |= flags + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. +// UDP是没有Listen的。 +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. 
+// 没有Listen,自然没有Accept。 +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// 在协议栈中注册该UDP端,并且分配源端口 +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, + id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.id.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + if err != nil { + return id, err + } + id.LocalPort = port + } + + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + } + return id, err +} + +// 根据addr参数绑定 +func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + // 不是初始状态的UDP端不让绑定 + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, true) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. 
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + // 绑定的时候传输端ID是源IP+源端口 + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + // 在协议中注册该UDP端 + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + // 如果指定了 commit 函数,执行并处理错误 + if commit != nil { + if err := commit(); err != nil { + // Unregister, the commit failed. + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + return err + } + } + + e.id = id + e.regNICID = addr.NIC + e.effectiveNetProtos = netProtos + + // Mark endpoint as bound. + // 标记状态为已绑定 + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +// Bind 将该UDP端绑定本地的一个IP+端口 +// 例如:绑定本地0.0.0.0的9000端口,那么其他机器给这台机器9000端口发消息,该UDP端就能收到消息了。 +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // 执行绑定IP+端口操作 + err := e.bindLocked(addr, commit) + if err != nil { + return err + } + + // 绑定的网卡ID + e.bindNICID = addr.NIC + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. 
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +// 从网络层接收到UDP数据报时的处理函数 +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + // 错误报文 + e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + return + } + + // 去除UDP首部 + vv.TrimFront(header.UDPMinimumSize) + + e.rcvMu.Lock() + e.stack.Stats().UDP.PacketsReceived.Increment() + + // Drop the packet if our buffer is currently full. + // 如果UDP的接收缓存已经满了,那么丢弃报文。 + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.rcvMu.Unlock() + return + } + + // 接收缓存是否为空 + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. 
+ // 新建一个UDP 数据报结构,插入到接收链表里 + pkt := &udpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + Port: hdr.SourcePort(), + }, + } + // 复制UDP数据的用户数据 + pkt.data = vv.Clone(pkt.views[:]) + // 插入到接收链表里,并增加已使用的缓存 + e.rcvList.PushBack(pkt) + e.rcvBufSize += vv.Size() + + if e.rcvTimestamp { + pkt.timestamp = e.stack.NowNanoseconds() + pkt.hasTimestamp = true + } + + e.rcvMu.Unlock() + + // Notify any waiters that there's data to be read now. + // 通知程序现在可以读取到数据了 + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } + log.Printf("recv udp %d bytes", hdr.Length()) +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +} diff --git a/netstack/tcpip/transport/udp/protocol.go b/netstack/tcpip/transport/udp/protocol.go new file mode 100644 index 0000000..8e832e7 --- /dev/null +++ b/netstack/tcpip/transport/udp/protocol.go @@ -0,0 +1,85 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package udp contains the implementation of the UDP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// transport protocols when calling stack.New(). 
Then endpoints can be created +// by passing udp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package udp + +import ( + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/waiter" +) + +const ( + // ProtocolName is the string representation of the udp protocol name. + ProtocolName = "udp" + + // ProtocolNumber is the udp protocol number. + ProtocolNumber = header.UDPProtocolNumber +) + +// tcpip.Endpoint 接口的UDP协议实现 +type protocol struct{} + +// Number returns the udp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new udp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, + waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// MinimumPacketSize returns the minimum valid udp packet size. +func (*protocol) MinimumPacketSize() int { + return header.UDPMinimumSize +} + +// ParsePorts returns the source and destination ports stored in the given udp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.UDP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. 
+func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{} + }) +} diff --git a/netstack/tcpip/transport/udp/udp_packet_list.go b/netstack/tcpip/transport/udp/udp_packet_list.go new file mode 100644 index 0000000..6eeafff --- /dev/null +++ b/netstack/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,174 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +// udp数据报的双向链表结构 +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. 
+func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + udpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(nil) + udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + a := udpPacketElementMapper{}.linkerFor(b).Next() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + b := udpPacketElementMapper{}.linkerFor(a).Prev() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. 
+func (l *udpPacketList) Remove(e *udpPacket) { + prev := udpPacketElementMapper{}.linkerFor(e).Prev() + next := udpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/netstack/tcpip/transport/udp/udp_test.go b/netstack/tcpip/transport/udp/udp_test.go new file mode 100644 index 0000000..8729fa6 --- /dev/null +++ b/netstack/tcpip/transport/udp/udp_test.go @@ -0,0 +1,928 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_test + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "tcpip/netstack/tcpip" + "tcpip/netstack/tcpip/buffer" + "tcpip/netstack/tcpip/checker" + "tcpip/netstack/tcpip/header" + "tcpip/netstack/tcpip/link/channel" + "tcpip/netstack/tcpip/link/sniffer" + "tcpip/netstack/tcpip/network/ipv4" + "tcpip/netstack/tcpip/network/ipv6" + "tcpip/netstack/tcpip/stack" + "tcpip/netstack/tcpip/transport/udp" + "tcpip/netstack/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr + testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr + multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testAddr = "\x0a\x00\x00\x02" + testPort = 4096 + multicastAddr = "\xe8\x2b\xd3\xea" + multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + multicastPort = 1234 + + // defaultMTU is the MTU, in bytes, used throughout the tests, except + // where another value is explicitly used. It is chosen to match the MTU + // of loopback interfaces on linux systems. 
+ defaultMTU = 65536 +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newDualTestContext(t *testing.T, mtu uint32) *testContext { + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{}) + + id, linkEP := channel.New(256, mtu, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } +} + +func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != protocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, 
p.Header) + copy(b[len(p.Header):], p.Payload) + + var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker) + var srcAddr, dstAddr tcpip.Address + switch protocolNumber { + case ipv4.ProtocolNumber: + checkerFn = checker.IPv4 + srcAddr, dstAddr = stackAddr, testAddr + if multicast { + dstAddr = multicastAddr + } + case ipv6.ProtocolNumber: + checkerFn = checker.IPv6 + srcAddr, dstAddr = stackV6Addr, testV6Addr + if multicast { + dstAddr = multicastV6Addr + } + default: + c.t.Fatalf("unknown protocol %d", protocolNumber) + } + checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.Checksum([]byte(testV6Addr), 0) + xsum = header.Checksum([]byte(stackV6Addr), xsum) + xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum) + + // Calculate the UDP checksum and set it. + length := uint16(header.UDPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum, length)) + + // Inject packet. 
+ c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func (c *testContext) sendPacket(payload []byte, h *headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testAddr, + DstAddr: stackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.Checksum([]byte(testAddr), 0) + xsum = header.Checksum([]byte(stackAddr), xsum) + xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum) + + // Calculate the UDP checksum and set it. + length := uint16(header.UDPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum, length)) + + // Inject packet. + c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func testV4Read(c *testContext) { + // Send a packet. + payload := newPayload() + c.sendPacket(payload, &headers{ + srcPort: testPort, + dstPort: stackPort, + }) + + // Try to receive the data. + we, ch := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&we, waiter.EventIn) + defer c.wq.EventUnregister(&we) + + var addr tcpip.FullAddress + v, _, err := c.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + // Wait for data to become available. 
+ select { + case <-ch: + v, _, err = c.ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for data") + } + } + + // Check the peer address. + if addr.Addr != testAddr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + } + + // Check the payload. + if !bytes.Equal(payload, v) { + c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + } +} + +func TestBindEphemeralPort(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + if err := c.ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } +} + +func TestBindReservedPort(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + addr, err := c.ep.GetLocalAddress() + if err != nil { + t.Fatalf("GetLocalAddress failed: %v", err) + } + + // We can't bind the address reserved by the connected endpoint above. + { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + if got, want := ep.Bind(addr, nil), tcpip.ErrPortInUse; got != want { + t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + } + } + + func() { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + // We can't bind ipv4-any on the port reserved by the connected endpoint + // above, since the endpoint is dual-stack. + if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil), tcpip.ErrPortInUse; got != want { + t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want) + } + // We can bind an ipv4 address on this port, though. 
+ if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + }() + + // Once the connected endpoint releases its port reservation, we are able to + // bind ipv4-any once again. + c.ep.Close() + func() { + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil); err != nil { + t.Fatalf("ep.Bind(...) failed: %v", err) + } + }() +} + +func TestV4ReadOnV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to v4 mapped wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV4ReadOnBoundToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to local address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV6ReadOnV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Send a packet. 
+ payload := newPayload() + c.sendV6Packet(payload, &headers{ + srcPort: testPort, + dstPort: stackPort, + }) + + // Try to receive the data. + we, ch := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&we, waiter.EventIn) + defer c.wq.EventUnregister(&we) + + var addr tcpip.FullAddress + v, _, err := c.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + // Wait for data to become available. + select { + case <-ch: + v, _, err = c.ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for data") + } + } + + // Check the peer address against the expected IPv6 address. + if addr.Addr != testV6Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testV6Addr) + } + + // Check the payload. + if !bytes.Equal(payload, v) { + c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + } +} + +func TestV4ReadOnV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + // Create v4 UDP endpoint. + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func testV4Write(c *testContext) uint16 { + // Write to V4 mapped address. + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. 
+ b := c.getPacket(ipv4.ProtocolNumber, false) + udp := header.UDP(header.IPv4(b).Payload()) + checker.IPv4(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } + + return udp.SourcePort() +} + +func testV6Write(c *testContext) uint16 { + // Write to v6 address. + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getPacket(ipv6.ProtocolNumber, false) + udp := header.UDP(header.IPv6(b).Payload()) + checker.IPv6(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } + + return udp.SourcePort() +} + +func testDualWrite(c *testContext) uint16 { + v4Port := testV4Write(c) + v6Port := testV6Write(c) + if v4Port != v6Port { + c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) + } + + return v4Port +} + +func TestDualWriteUnbound(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + testDualWrite(c) +} + +func TestDualWriteBoundToWildcard(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. 
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + p := testDualWrite(c) + if p != stackPort { + c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) + } +} + +func TestDualWriteConnectedToV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v6 address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + testV6Write(c) + + // Write to V4 mapped address. + payload := buffer.View(newPayload()) + _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != tcpip.ErrNetworkUnreachable { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) + } +} + +func TestDualWriteConnectedToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v4 mapped address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + testV4Write(c) + + // Write to v6 address. + payload := buffer.View(newPayload()) + _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != tcpip.ErrInvalidEndpointState { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) + } +} + +func TestV4WriteOnV6Only(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(true) + + // Write to V4 mapped address. 
+ payload := buffer.View(newPayload()) + _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != tcpip.ErrNoRoute { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) + } +} + +func TestV6WriteOnBoundToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to v4 mapped address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Write to v6 address. + payload := buffer.View(newPayload()) + _, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != tcpip.ErrInvalidEndpointState { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) + } +} + +func TestV6WriteOnConnected(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v6 address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + // Write without destination. + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getPacket(ipv6.ProtocolNumber, false) + udp := header.UDP(header.IPv6(b).Payload()) + checker.IPv6(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. 
+ if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } +} + +func TestV4WriteOnConnected(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v4 mapped address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + // Write without destination. + payload := buffer.View(newPayload()) + n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getPacket(ipv4.ProtocolNumber, false) + udp := header.UDP(header.IPv4(b).Payload()) + checker.IPv4(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } +} + +func TestReadIncrementsPacketsReceived(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + // Create IPv4 UDP endpoint + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. 
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + testV4Read(c) + + var want uint64 = 1 + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { + c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) + } +} + +func TestWriteIncrementsPacketsSent(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + testDualWrite(c) + + var want uint64 = 2 + if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { + c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) + } +} + +func TestTTL(t *testing.T) { + payload := tcpip.SlicePayload(buffer.View(newPayload())) + + for _, name := range []string{"v4", "v6", "dual"} { + t.Run(name, func(t *testing.T) { + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch name { + case "v4": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6", "dual": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } + + var variants []string + switch name { + case "v4": + variants = []string{"v4"} + case "v6": + variants = []string{"v6"} + case "dual": + variants = []string{"v6", "mapped"} + } + + for _, variant := range variants { + t.Run(variant, func(t *testing.T) { + for _, typ := range []string{"unicast", "multicast"} { + t.Run(typ, func(t *testing.T) { + var addr tcpip.Address + var port uint16 + switch typ { + case "unicast": + port = testPort + switch variant { + case "v4": + addr = testAddr + case "mapped": + addr = testV4MappedAddr + case "v6": + addr = testV6Addr + default: + t.Fatal("unknown test variant") + } + case "multicast": + port = multicastPort + switch variant { + case "v4": + addr = multicastAddr + case "mapped": + addr = multicastV4MappedAddr + case "v6": + addr = multicastV6Addr + default: + t.Fatal("unknown test variant") + } + default: + t.Fatal("unknown test variant") + } + + c := 
newDualTestContext(t, defaultMTU) + defer c.cleanup() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + switch name { + case "v4": + case "v6": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + case "dual": + if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + default: + t.Fatal("unknown test variant") + } + + const multicastTTL = 42 + if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } + + n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload)) + } + + checkerFn := checker.IPv4 + switch variant { + case "v4", "mapped": + case "v6": + checkerFn = checker.IPv6 + default: + t.Fatal("unknown test variant") + } + var wantTTL uint8 + var multicast bool + switch typ { + case "unicast": + multicast = false + switch variant { + case "v4", "mapped": + ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + case "v6": + ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + default: + t.Fatal("unknown test variant") + } + case "multicast": + wantTTL = multicastTTL + multicast = true + default: + t.Fatal("unknown test variant") + } + + var networkProtocolNumber tcpip.NetworkProtocolNumber + switch variant { + case "v4", "mapped": + networkProtocolNumber = ipv4.ProtocolNumber + case "v6": + networkProtocolNumber = ipv6.ProtocolNumber + default: + t.Fatal("unknown test variant") + } 
+ + b := c.getPacket(networkProtocolNumber, multicast) + checkerFn(c.t, b, + checker.TTL(wantTTL), + checker.UDP( + checker.DstPort(port), + ), + ) + }) + } + }) + } + }) + } +} diff --git a/netstack/tmutex/tmutex.go b/netstack/tmutex/tmutex.go new file mode 100644 index 0000000..76abdde --- /dev/null +++ b/netstack/tmutex/tmutex.go @@ -0,0 +1,81 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tmutex provides the implementation of a mutex that implements an +// efficient TryLock function in addition to Lock and Unlock. +package tmutex + +import ( + "sync/atomic" +) + +// Mutex is a mutual exclusion primitive that implements TryLock in addition +// to Lock and Unlock. +type Mutex struct { + v int32 + ch chan struct{} +} + +// Init initializes the mutex. +func (m *Mutex) Init() { + m.v = 1 + m.ch = make(chan struct{}, 1) +} + +// Lock acquires the mutex. If it is currently held by another goroutine, Lock +// will wait until it has a chance to acquire it. +func (m *Mutex) Lock() { + // Uncontended case. + if atomic.AddInt32(&m.v, -1) == 0 { + return + } + + for { + // Try to acquire the mutex again, at the same time making sure + // that m.v is negative, which indicates to the owner of the + // lock that it is contended, which will force it to try to wake + // someone up when it releases the mutex. 
+ if v := atomic.LoadInt32(&m.v); v >= 0 && atomic.SwapInt32(&m.v, -1) == 1 { + return + } + + // Wait for the mutex to be released before trying again. + <-m.ch + } +} + +// TryLock attempts to acquire the mutex without blocking. If the mutex is +// currently held by another goroutine, it fails to acquire it and returns +// false. +func (m *Mutex) TryLock() bool { + v := atomic.LoadInt32(&m.v) + if v <= 0 { + return false + } + return atomic.CompareAndSwapInt32(&m.v, 1, 0) +} + +// Unlock releases the mutex. +func (m *Mutex) Unlock() { + if atomic.SwapInt32(&m.v, 1) == 0 { + // There were no pending waiters. + return + } + + // Wake some waiter up. + select { + case m.ch <- struct{}{}: + default: + } +} diff --git a/netstack/tmutex/tmutex_test.go b/netstack/tmutex/tmutex_test.go new file mode 100644 index 0000000..a5814fc --- /dev/null +++ b/netstack/tmutex/tmutex_test.go @@ -0,0 +1,257 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmutex + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestBasicLock(t *testing.T) { + var m Mutex + m.Init() + + m.Lock() + + // Try a blocking Lock of the mutex from a different goroutine. This must + // block because the mutex is held. 
+ ch := make(chan struct{}, 1) + go func() { + m.Lock() + ch <- struct{}{} + m.Unlock() + ch <- struct{}{} + }() + + select { + case <-ch: + t.Fatalf("Lock succeeded on locked mutex") + case <-time.After(100 * time.Millisecond): + } + + // Unlock the mutex and make sure that the goroutine waiting on Lock() + // unblocks and succeeds. + m.Unlock() + + select { + case <-ch: + case <-time.After(100 * time.Millisecond): + t.Fatalf("Lock failed to acquire unlocked mutex") + } + + // Make sure we can lock and unlock again. + m.Lock() + m.Unlock() +} + +func TestTryLock(t *testing.T) { + var m Mutex + m.Init() + + // Try to lock. It should succeed. + if !m.TryLock() { + t.Fatalf("TryLock failed on unlocked mutex") + } + + // Try to lock again, it should now fail. + if m.TryLock() { + t.Fatalf("TryLock succeeded on locked mutex") + } + + // Try a blocking Lock of the mutex from a different goroutine. This must + // block because the mutex is held. + ch := make(chan struct{}, 1) + go func() { + m.Lock() + ch <- struct{}{} + m.Unlock() + }() + + select { + case <-ch: + t.Fatalf("Lock succeeded on locked mutex") + case <-time.After(100 * time.Millisecond): + } + + // Unlock the mutex and make sure that the goroutine waiting on Lock() + // unblocks and succeeds. + m.Unlock() + + select { + case <-ch: + case <-time.After(100 * time.Millisecond): + t.Fatalf("Lock failed to acquire unlocked mutex") + } +} + +func TestMutualExclusion(t *testing.T) { + var m Mutex + m.Init() + + // Test mutual exclusion by running "gr" goroutines concurrently, and + // have each one increment a counter "iters" times within the critical + // section established by the mutex. + // + // If at the end the counter is not gr * iters, then we know that + // goroutines ran concurrently within the critical section. + // + // If one of the goroutines doesn't complete, it's likely a bug that + // causes it to wait forever. 
+ const gr = 1000 + const iters = 100000 + v := 0 + var wg sync.WaitGroup + for i := 0; i < gr; i++ { + wg.Add(1) + go func() { + for j := 0; j < iters; j++ { + m.Lock() + v++ + m.Unlock() + } + wg.Done() + }() + } + + wg.Wait() + + if v != gr*iters { + t.Fatalf("Bad count: got %v, want %v", v, gr*iters) + } +} + +func TestMutualExclusionWithTryLock(t *testing.T) { + var m Mutex + m.Init() + + // Similar to the previous, with the addition of some goroutines that + // only increment the count if TryLock succeeds. + const gr = 1000 + const iters = 100000 + total := int64(gr * iters) + var tryTotal int64 + v := int64(0) + var wg sync.WaitGroup + for i := 0; i < gr; i++ { + wg.Add(2) + go func() { + for j := 0; j < iters; j++ { + m.Lock() + v++ + m.Unlock() + } + wg.Done() + }() + go func() { + local := int64(0) + for j := 0; j < iters; j++ { + if m.TryLock() { + v++ + m.Unlock() + local++ + } + } + atomic.AddInt64(&tryTotal, local) + wg.Done() + }() + } + + wg.Wait() + + t.Logf("tryTotal = %d", tryTotal) + total += tryTotal + + if v != total { + t.Fatalf("Bad count: got %v, want %v", v, total) + } +} + +// BenchmarkTmutex is equivalent to TestMutualExclusion, with the following +// differences: +// +// - The number of goroutines is variable, with the maximum value depending on +// GOMAXPROCS. +// +// - The number of iterations per benchmark is controlled by the benchmarking +// framework. +// +// - Care is taken to ensure that all goroutines participating in the benchmark +// have been created before the benchmark begins. 
+func BenchmarkTmutex(b *testing.B) { + for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var m Mutex + m.Init() + + var ready sync.WaitGroup + begin := make(chan struct{}) + var end sync.WaitGroup + for i := 0; i < n; i++ { + ready.Add(1) + end.Add(1) + go func() { + ready.Done() + <-begin + for j := 0; j < b.N; j++ { + m.Lock() + m.Unlock() + } + end.Done() + }() + } + + ready.Wait() + b.ResetTimer() + close(begin) + end.Wait() + }) + } +} + +// BenchmarkSyncMutex is equivalent to BenchmarkTmutex, but uses sync.Mutex as +// a comparison point. +func BenchmarkSyncMutex(b *testing.B) { + for n, max := 1, 4*runtime.GOMAXPROCS(0); n > 0 && n <= max; n *= 2 { + b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { + var m sync.Mutex + + var ready sync.WaitGroup + begin := make(chan struct{}) + var end sync.WaitGroup + for i := 0; i < n; i++ { + ready.Add(1) + end.Add(1) + go func() { + ready.Done() + <-begin + for j := 0; j < b.N; j++ { + m.Lock() + m.Unlock() + } + end.Done() + }() + } + + ready.Wait() + b.ResetTimer() + close(begin) + end.Wait() + }) + } +} diff --git a/netstack/waiter/waiter.go b/netstack/waiter/waiter.go new file mode 100644 index 0000000..104bee0 --- /dev/null +++ b/netstack/waiter/waiter.go @@ -0,0 +1,240 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package waiter provides the implementation of a wait queue, where waiters can +// be enqueued to be notified when an event of interest happens. +// +// Becoming readable and/or writable are examples of events. Waiters are +// expected to use a pattern similar to this to make a blocking function out of +// a non-blocking one: +// +// func (o *object) blockingRead(...) error { +// err := o.nonBlockingRead(...) +// if err != ErrAgain { +// // Completed with no need to wait! +// return err +// } +// +// e := createOrGetWaiterEntry(...) +// o.EventRegister(&e, waiter.EventIn) +// defer o.EventUnregister(&e) +// +// // We need to try to read again after registration because the +// // object may have become readable between the last attempt to +// // read and read registration. +// err = o.nonBlockingRead(...) +// for err == ErrAgain { +// wait() +// err = o.nonBlockingRead(...) +// } +// +// return err +// } +// +// Another goroutine needs to notify waiters when events happen. For example: +// +// func (o *object) Write(...) ... { +// // Do write work. +// [...] +// +// if oldDataAvailableSize == 0 && dataAvailableSize > 0 { +// // If no data was available and now some data is +// // available, the object became readable, so notify +// // potential waiters about this. +// o.Notify(waiter.EventIn) +// } +// } +package waiter + +import ( + "sync" + + "tcpip/netstack/ilist" +) + +// EventMask represents io events as used in the poll() syscall. +type EventMask uint16 + +// Events that waiters can wait on. The meaning is the same as those in the +// poll() syscall. +const ( + EventIn EventMask = 0x01 // syscall.EPOLLIN + EventPri EventMask = 0x02 // syscall.EPOLLPRI + EventOut EventMask = 0x04 // syscall.EPOLLOUT + EventErr EventMask = 0x08 // syscall.EPOLLERR + EventHUp EventMask = 0x10 // syscall.EPOLLHUP + EventNVal EventMask = 0x20 // Not defined in syscall. +) + +// Waitable contains the methods that need to be implemented by waitable +// objects. 
+type Waitable interface { + // Readiness returns what the object is currently ready for. If it's + // not ready for a desired purpose, the caller may use EventRegister and + // EventUnregister to get notifications once the object becomes ready. + // + // Implementations should allow for events like EventHUp and EventErr + // to be returned regardless of whether they are in the input EventMask. + Readiness(mask EventMask) EventMask + + // EventRegister registers the given waiter entry to receive + // notifications when an event occurs that makes the object ready for + // at least one of the events in mask. + EventRegister(e *Entry, mask EventMask) + + // EventUnregister unregisters a waiter entry previously registered with + // EventRegister(). + EventUnregister(e *Entry) +} + +// EntryCallback provides a notify callback. +type EntryCallback interface { + // Callback is the function to be called when the waiter entry is + // notified. It is responsible for doing whatever is needed to wake up + // the waiter. + // + // The callback is supposed to perform minimal work, and cannot call + // any method on the queue itself because it will be locked while the + // callback is running. + Callback(e *Entry) +} + +// Entry represents a waiter that can be added to a wait queue. It can +// only be in one queue at a time, and is added "intrusively" to the queue with +// no extra memory allocations. +// +// +stateify savable +type Entry struct { + // Context stores any state the waiter may wish to store in the entry + // itself, which may be used at wake up time. + // + // Note that use of this field is optional and state may alternatively be + // stored in the callback itself. + Context interface{} + + Callback EntryCallback + + // The following fields are protected by the queue lock. + mask EventMask + ilist.Entry +} + +type channelCallback struct{} + +// Callback implements EntryCallback.Callback. 
+func (*channelCallback) Callback(e *Entry) { + ch := e.Context.(chan struct{}) + select { + case ch <- struct{}{}: + default: + } +} + +// NewChannelEntry initializes a new Entry that does a non-blocking write to a +// struct{} channel when the callback is called. It returns the new Entry +// instance and the channel being used. +// +// If a channel isn't specified (i.e., if "c" is nil), then NewChannelEntry +// allocates a new channel. +func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { + if c == nil { + c = make(chan struct{}, 1) + } + + return Entry{Context: c, Callback: &channelCallback{}}, c +} + +// Queue represents the wait queue where waiters can be added and +// notifiers can notify them when events happen. +// +// The zero value for waiter.Queue is an empty queue ready for use. +// +// +stateify savable +type Queue struct { + list ilist.List + mu sync.RWMutex +} + +// EventRegister adds a waiter to the wait queue; the waiter will be notified +// when at least one of the events specified in mask happens. +func (q *Queue) EventRegister(e *Entry, mask EventMask) { + q.mu.Lock() + e.mask = mask + q.list.PushBack(e) + q.mu.Unlock() +} + +// EventUnregister removes the given waiter entry from the wait queue. +func (q *Queue) EventUnregister(e *Entry) { + q.mu.Lock() + q.list.Remove(e) + q.mu.Unlock() +} + +// Notify notifies all waiters in the queue whose masks have at least one bit +// in common with the notification mask. +func (q *Queue) Notify(mask EventMask) { + q.mu.RLock() + for it := q.list.Front(); it != nil; it = it.Next() { + e := it.(*Entry) + if mask&e.mask != 0 { + e.Callback.Callback(e) + } + } + q.mu.RUnlock() +} + +// Events returns the set of events being waited on. It is the union of the +// masks of all registered entries. 
+func (q *Queue) Events() EventMask { + ret := EventMask(0) + + q.mu.RLock() + for it := q.list.Front(); it != nil; it = it.Next() { + e := it.(*Entry) + ret |= e.mask + } + q.mu.RUnlock() + + return ret +} + +// IsEmpty reports whether the wait queue is empty; it only reads the list, so a read lock suffices. +func (q *Queue) IsEmpty() bool { + q.mu.RLock() + defer q.mu.RUnlock() + + return q.list.Front() == nil +} + +// AlwaysReady implements the Waitable interface but is always ready. Embedding +// this struct into another struct makes it implement the boilerplate empty +// functions automatically. +type AlwaysReady struct { +} + +// Readiness always returns the input mask because this object is always ready. +func (*AlwaysReady) Readiness(mask EventMask) EventMask { + return mask +} + +// EventRegister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventRegister(*Entry, EventMask) { +} + +// EventUnregister doesn't do anything because this object doesn't need to issue +// notifications because its readiness never changes. +func (*AlwaysReady) EventUnregister(e *Entry) { +} diff --git a/netstack/waiter/waiter_test.go b/netstack/waiter/waiter_test.go new file mode 100644 index 0000000..d0590e9 --- /dev/null +++ b/netstack/waiter/waiter_test.go @@ -0,0 +1,192 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package waiter + +import ( + "sync/atomic" + "testing" +) + +type callbackStub struct { + f func(e *Entry) +} + +// Callback implements EntryCallback.Callback. +func (c *callbackStub) Callback(e *Entry) { + c.f(e) +} + +func TestEmptyQueue(t *testing.T) { + var q Queue + + // Notify the zero-value of a queue. + q.Notify(EventIn) + + // Register then unregister a waiter, then notify the queue. + cnt := 0 + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn) + q.EventUnregister(&e) + q.Notify(EventIn) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestMask(t *testing.T) { + // Register a waiter. + var q Queue + var cnt int + e := Entry{Callback: &callbackStub{func(*Entry) { cnt++ }}} + q.EventRegister(&e, EventIn|EventErr) + + // Notify with an overlapping mask. + cnt = 0 + q.Notify(EventIn | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a subset mask. + cnt = 0 + q.Notify(EventIn) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a superset mask. + cnt = 0 + q.Notify(EventIn | EventErr | EventOut) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with the exact same mask. + cnt = 0 + q.Notify(EventIn | EventErr) + if cnt != 1 { + t.Errorf("Callback wasn't called when it should have been") + } + + // Notify with a disjoint mask. + cnt = 0 + q.Notify(EventOut | EventHUp) + if cnt != 0 { + t.Errorf("Callback was called when it shouldn't have been") + } +} + +func TestConcurrentRegistration(t *testing.T) { + var q Queue + var cnt int + const concurrency = 1000 + + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + ch3 := make(chan struct{}) + + // Create goroutines that will all register/unregister concurrently. 
+ for i := 0; i < concurrency; i++ { + go func() { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + cnt++ + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + // Wait for notification, then register. + <-ch1 + q.EventRegister(&e, EventIn|EventErr) + + // Tell main goroutine that we're done registering. + ch2 <- struct{}{} + + // Wait for notification, then unregister. + <-ch3 + q.EventUnregister(&e) + + // Tell main goroutine that we're done unregistering. + ch2 <- struct{}{} + }() + } + + // Let the goroutines register. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } + + // Let the goroutine unregister. + close(ch3) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Issue a notification. + q.Notify(EventIn) + if cnt != concurrency { + t.Errorf("cnt = %d, want %d", cnt, concurrency) + } +} + +func TestConcurrentNotification(t *testing.T) { + var q Queue + var cnt int32 + const concurrency = 1000 + const waiterCount = 1000 + + // Register waiters. + for i := 0; i < waiterCount; i++ { + var e Entry + e.Callback = &callbackStub{func(entry *Entry) { + atomic.AddInt32(&cnt, 1) + if entry != &e { + t.Errorf("entry = %p, want %p", entry, &e) + } + }} + + q.EventRegister(&e, EventIn|EventErr) + } + + // Launch notifiers. + ch1 := make(chan struct{}) + ch2 := make(chan struct{}) + for i := 0; i < concurrency; i++ { + go func() { + <-ch1 + q.Notify(EventIn) + ch2 <- struct{}{} + }() + } + + // Let notifiers go. + close(ch1) + for i := 0; i < concurrency; i++ { + <-ch2 + } + + // Check the count. 
+ if cnt != concurrency*waiterCount { + t.Errorf("cnt = %d, want %d", cnt, concurrency*waiterCount) + } +} diff --git a/sleep/sleep_unsafe.go b/sleep/sleep_unsafe.go deleted file mode 100644 index 21dcfb0..0000000 --- a/sleep/sleep_unsafe.go +++ /dev/null @@ -1,4 +0,0 @@ -package sleep - -type Waker struct { -} diff --git a/tcpip/buffer/prependable.go b/src_code/tcpip_bak/buffer/prependable.go similarity index 100% rename from tcpip/buffer/prependable.go rename to src_code/tcpip_bak/buffer/prependable.go diff --git a/tcpip/buffer/view.go b/src_code/tcpip_bak/buffer/view.go similarity index 100% rename from tcpip/buffer/view.go rename to src_code/tcpip_bak/buffer/view.go diff --git a/src_code/tcpip_bak/buffer/view_test.go b/src_code/tcpip_bak/buffer/view_test.go new file mode 100644 index 0000000..7bae37a --- /dev/null +++ b/src_code/tcpip_bak/buffer/view_test.go @@ -0,0 +1,235 @@ +// Copyright 2018 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package buffer_test contains tests for the VectorisedView type. +package buffer + +import ( + "reflect" + "testing" +) + +// copy returns a deep-copy of the vectorised view. +func (vv VectorisedView) copy() VectorisedView { + uu := VectorisedView{ + views: make([]View, 0, len(vv.views)), + size: vv.size, + } + for _, v := range vv.views { + uu.views = append(uu.views, append(View(nil), v...)) + } + return uu +} + +// vv is an helper to build VectorisedView from different strings. 
+func vv(size int, pieces ...string) VectorisedView { + views := make([]View, len(pieces)) + for i, p := range pieces { + views[i] = []byte(p) + } + + return NewVectorisedView(size, views) +} + +var capLengthTestCases = []struct { + comment string + in VectorisedView + length int + want VectorisedView +}{ + { + comment: "Simple case", + in: vv(2, "12"), + length: 1, + want: vv(1, "1"), + }, + { + comment: "Case spanning across two Views", + in: vv(4, "123", "4"), + length: 2, + want: vv(2, "12"), + }, + { + comment: "Corner case with negative length", + in: vv(1, "1"), + length: -1, + want: vv(0), + }, + { + comment: "Corner case with length = 0", + in: vv(3, "12", "3"), + length: 0, + want: vv(0), + }, + { + comment: "Corner case with length = size", + in: vv(1, "1"), + length: 1, + want: vv(1, "1"), + }, + { + comment: "Corner case with length > size", + in: vv(1, "1"), + length: 2, + want: vv(1, "1"), + }, +} + +func TestCapLength(t *testing.T) { + for _, c := range capLengthTestCases { + orig := c.in.copy() + c.in.CapLength(c.length) + if !reflect.DeepEqual(c.in, c.want) { + t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. 
Want %v", + c.comment, c.length, orig, c.in, c.want) + } + } +} + +var trimFrontTestCases = []struct { + comment string + in VectorisedView + count int + want VectorisedView +}{ + { + comment: "Simple case", + in: vv(2, "12"), + count: 1, + want: vv(1, "2"), + }, + { + comment: "Case where we trim an entire View", + in: vv(2, "1", "2"), + count: 1, + want: vv(1, "2"), + }, + { + comment: "Case spanning across two Views", + in: vv(3, "1", "23"), + count: 2, + want: vv(1, "3"), + }, + { + comment: "Corner case with negative count", + in: vv(1, "1"), + count: -1, + want: vv(1, "1"), + }, + { + comment: " Corner case with count = 0", + in: vv(1, "1"), + count: 0, + want: vv(1, "1"), + }, + { + comment: "Corner case with count = size", + in: vv(1, "1"), + count: 1, + want: vv(0), + }, + { + comment: "Corner case with count > size", + in: vv(1, "1"), + count: 2, + want: vv(0), + }, +} + +func TestTrimFront(t *testing.T) { + for _, c := range trimFrontTestCases { + orig := c.in.copy() + c.in.TrimFront(c.count) + if !reflect.DeepEqual(c.in, c.want) { + t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v", + c.comment, c.count, orig, c.in, c.want) + } + } +} + +var toViewCases = []struct { + comment string + in VectorisedView + want View +}{ + { + comment: "Simple case", + in: vv(2, "12"), + want: []byte("12"), + }, + { + comment: "Case with multiple views", + in: vv(2, "1", "2"), + want: []byte("12"), + }, + { + comment: "Empty case", + in: vv(0), + want: []byte(""), + }, +} + +func TestToView(t *testing.T) { + for _, c := range toViewCases { + got := c.in.ToView() + if !reflect.DeepEqual(got, c.want) { + t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. 
Want %v", + c.comment, c.in, got, c.want) + } + } +} + +var toCloneCases = []struct { + comment string + inView VectorisedView + inBuffer []View +}{ + { + comment: "Simple case", + inView: vv(1, "1"), + inBuffer: make([]View, 1), + }, + { + comment: "Case with multiple views", + inView: vv(2, "1", "2"), + inBuffer: make([]View, 2), + }, + { + comment: "Case with buffer too small", + inView: vv(2, "1", "2"), + inBuffer: make([]View, 1), + }, + { + comment: "Case with buffer larger than needed", + inView: vv(1, "1"), + inBuffer: make([]View, 2), + }, + { + comment: "Case with nil buffer", + inView: vv(1, "1"), + inBuffer: nil, + }, +} + +func TestToClone(t *testing.T) { + for _, c := range toCloneCases { + t.Run(c.comment, func(t *testing.T) { + got := c.inView.Clone(c.inBuffer) + if !reflect.DeepEqual(got, c.inView) { + t.Fatalf("got (%+v).Clone(%+v) = %+v, want = %+v", + c.inView, c.inBuffer, got, c.inView) + } + }) + } +} diff --git a/tcpip/header/eth.go b/src_code/tcpip_bak/header/eth.go similarity index 100% rename from tcpip/header/eth.go rename to src_code/tcpip_bak/header/eth.go diff --git a/tcpip/ilist/list.go b/src_code/tcpip_bak/ilist/list.go similarity index 100% rename from tcpip/ilist/list.go rename to src_code/tcpip_bak/ilist/list.go diff --git a/tcpip/link/fdbased/endpoint.go b/src_code/tcpip_bak/link/fdbased/endpoint.go similarity index 100% rename from tcpip/link/fdbased/endpoint.go rename to src_code/tcpip_bak/link/fdbased/endpoint.go diff --git a/tcpip/link/rawfile/blockingpoll_unsafe.go b/src_code/tcpip_bak/link/rawfile/blockingpoll_unsafe.go similarity index 100% rename from tcpip/link/rawfile/blockingpoll_unsafe.go rename to src_code/tcpip_bak/link/rawfile/blockingpoll_unsafe.go diff --git a/tcpip/link/rawfile/errors.go b/src_code/tcpip_bak/link/rawfile/errors.go similarity index 100% rename from tcpip/link/rawfile/errors.go rename to src_code/tcpip_bak/link/rawfile/errors.go diff --git a/tcpip/link/rawfile/rawfile_unsafe.go 
b/src_code/tcpip_bak/link/rawfile/rawfile_unsafe.go similarity index 100% rename from tcpip/link/rawfile/rawfile_unsafe.go rename to src_code/tcpip_bak/link/rawfile/rawfile_unsafe.go diff --git a/tcpip/link/tap1/main.go b/src_code/tcpip_bak/link/tap1/main.go similarity index 100% rename from tcpip/link/tap1/main.go rename to src_code/tcpip_bak/link/tap1/main.go diff --git a/tcpip/link/tap2/main.go b/src_code/tcpip_bak/link/tap2/main.go similarity index 100% rename from tcpip/link/tap2/main.go rename to src_code/tcpip_bak/link/tap2/main.go diff --git a/tcpip/link/tuntap/tuntap.go b/src_code/tcpip_bak/link/tuntap/tuntap.go similarity index 100% rename from tcpip/link/tuntap/tuntap.go rename to src_code/tcpip_bak/link/tuntap/tuntap.go diff --git a/tcpip/ports/port.go b/src_code/tcpip_bak/ports/port.go similarity index 100% rename from tcpip/ports/port.go rename to src_code/tcpip_bak/ports/port.go diff --git a/tcpip/seqnum/seqnum.go b/src_code/tcpip_bak/seqnum/seqnum.go similarity index 100% rename from tcpip/seqnum/seqnum.go rename to src_code/tcpip_bak/seqnum/seqnum.go diff --git a/tcpip/stack/linkaddresscache.go b/src_code/tcpip_bak/stack/linkaddresscache.go similarity index 100% rename from tcpip/stack/linkaddresscache.go rename to src_code/tcpip_bak/stack/linkaddresscache.go diff --git a/tcpip/stack/nic.go b/src_code/tcpip_bak/stack/nic.go similarity index 100% rename from tcpip/stack/nic.go rename to src_code/tcpip_bak/stack/nic.go diff --git a/tcpip/stack/registration.go b/src_code/tcpip_bak/stack/registration.go similarity index 100% rename from tcpip/stack/registration.go rename to src_code/tcpip_bak/stack/registration.go diff --git a/tcpip/stack/route.go b/src_code/tcpip_bak/stack/route.go similarity index 100% rename from tcpip/stack/route.go rename to src_code/tcpip_bak/stack/route.go diff --git a/tcpip/stack/stack.go b/src_code/tcpip_bak/stack/stack.go similarity index 100% rename from tcpip/stack/stack.go rename to src_code/tcpip_bak/stack/stack.go diff 
--git a/tcpip/stack/transport_demuxer.go b/src_code/tcpip_bak/stack/transport_demuxer.go similarity index 100% rename from tcpip/stack/transport_demuxer.go rename to src_code/tcpip_bak/stack/transport_demuxer.go diff --git a/tcpip/tcpip.go b/src_code/tcpip_bak/tcpip.go similarity index 100% rename from tcpip/tcpip.go rename to src_code/tcpip_bak/tcpip.go