这个仓库代码量太大了 以读为主 不写了

This commit is contained in:
impact-eintr
2021-10-15 19:17:11 +08:00
parent 14862c4bb0
commit 03c0e5427f
126 changed files with 31369 additions and 7 deletions

View File

@@ -283,7 +283,6 @@ tap0 Link encap:Ethernet HWaddr 22:e2:f2:93:ff:bf
``` ```
删除网卡可以使用如下命令: 删除网卡可以使用如下命令:

4
go.mod
View File

@@ -1,3 +1,3 @@
module github.com/impact-eintr/netstack module tcpip
go 1.16 go 1.14

207
netstack/ilist/list.go Normal file
View File

@@ -0,0 +1,207 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ilist provides the implementation of intrusive linked lists.
package ilist
// Linker is the interface that objects must implement if they want to be added
// to and/or removed from List objects.
//
// N.B. When substituted in a template instantiation, Linker doesn't need to
// be an interface, and in most cases won't be.
type Linker interface {
Next() Element
Prev() Element
SetNext(Element)
SetPrev(Element)
}
// Element the item that is used at the API level.
//
// N.B. Like Linker, this is unlikely to be an interface in most cases.
type Element interface {
Linker
}
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type ElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (ElementMapper) linkerFor(elem Element) Linker { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type List struct {
head Element
tail Element
}
// Reset resets list l to the empty state.
func (l *List) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
func (l *List) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
func (l *List) Front() Element {
return l.head
}
// Back returns the last element of list l or nil.
func (l *List) Back() Element {
return l.tail
}
// PushFront inserts the element e at the front of list l.
func (l *List) PushFront(e Element) {
ElementMapper{}.linkerFor(e).SetNext(l.head)
ElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
ElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushBack inserts the element e at the back of list l.
func (l *List) PushBack(e Element) {
ElementMapper{}.linkerFor(e).SetNext(nil)
ElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
func (l *List) PushBackList(m *List) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
ElementMapper{}.linkerFor(l.tail).SetNext(m.head)
ElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *List) InsertAfter(b, e Element) {
a := ElementMapper{}.linkerFor(b).Next()
ElementMapper{}.linkerFor(e).SetNext(a)
ElementMapper{}.linkerFor(e).SetPrev(b)
ElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
ElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
func (l *List) InsertBefore(a, e Element) {
b := ElementMapper{}.linkerFor(a).Prev()
ElementMapper{}.linkerFor(e).SetNext(a)
ElementMapper{}.linkerFor(e).SetPrev(b)
ElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
ElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
func (l *List) Remove(e Element) {
prev := ElementMapper{}.linkerFor(e).Prev()
next := ElementMapper{}.linkerFor(e).Next()
if prev != nil {
ElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
ElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type Entry struct {
next Element
prev Element
}
// Next returns the entry that follows e in the list.
func (e *Entry) Next() Element {
return e.next
}
// Prev returns the entry that precedes e in the list.
func (e *Entry) Prev() Element {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
func (e *Entry) SetNext(elem Element) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *Entry) SetPrev(elem Element) {
e.prev = elem
}

27
netstack/rand/rand.go Normal file
View File

@@ -0,0 +1,27 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package rand implements a cryptographically secure pseudorandom number
// generator.
package rand
import "crypto/rand"
// Reader is the default reader.
var Reader = rand.Reader
// Read implements io.Reader.Read.
func Read(b []byte) (int, error) {
return rand.Read(b)
}

View File

@@ -0,0 +1,35 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "textflag.h"
#define preparingG 1
// See commit_noasm.go for a description of commitSleep.
//
// func commitSleep(g uintptr, waitingG *uintptr) bool
TEXT ·commitSleep(SB),NOSPLIT,$0-24
MOVQ waitingG+8(FP), CX
MOVQ g+0(FP), DX
// Store the G in waitingG if it's still preparingG. If it's anything
// else it means a waker has aborted the sleep.
MOVQ $preparingG, AX
LOCK
CMPXCHGQ DX, 0(CX)
SETEQ AX
MOVB AX, ret+16(FP)
RET

View File

@@ -0,0 +1,20 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build amd64
package sleep
// See commit_noasm.go for a description of commitSleep.
func commitSleep(g uintptr, waitingG *uintptr) bool

View File

@@ -0,0 +1,42 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !race
// +build !amd64
package sleep
import "sync/atomic"
// commitSleep signals to wakers that the given g is now sleeping. Wakers can
// then fetch it and wake it.
//
// The commit may fail if wakers have been asserted after our last check, in
// which case they will have set s.waitingG to zero.
//
// It is written in assembly because it is called from g0, so it doesn't have
// a race context.
func commitSleep(g uintptr, waitingG *uintptr) bool {
for {
// Check if the wait was aborted.
if atomic.LoadUintptr(waitingG) == 0 {
return false
}
// Try to store the G so that wakers know who to wake.
if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
return true
}
}
}

15
netstack/sleep/empty.s Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Empty assembly file so empty func definitions work.

View File

@@ -0,0 +1,542 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sleep
import (
"math/rand"
"runtime"
"testing"
"time"
)
// ZeroWakerNotAsserted tests that a zero-value waker is in non-asserted state.
func ZeroWakerNotAsserted(t *testing.T) {
var w Waker
if w.IsAsserted() {
t.Fatalf("Zero waker is asserted")
}
if w.Clear() {
t.Fatalf("Zero waker is asserted")
}
}
// AssertedWakerAfterAssert tests that a waker properly reports its state as
// asserted once its Assert() method is called.
func AssertedWakerAfterAssert(t *testing.T) {
var w Waker
w.Assert()
if !w.IsAsserted() {
t.Fatalf("Asserted waker is not reported as such")
}
if !w.Clear() {
t.Fatalf("Asserted waker is not reported as such")
}
}
// AssertedWakerAfterTwoAsserts tests that a waker properly reports its state as
// asserted once its Assert() method is called twice.
func AssertedWakerAfterTwoAsserts(t *testing.T) {
var w Waker
w.Assert()
w.Assert()
if !w.IsAsserted() {
t.Fatalf("Asserted waker is not reported as such")
}
if !w.Clear() {
t.Fatalf("Asserted waker is not reported as such")
}
}
// NotAssertedWakerWithSleeper tests that a waker properly reports its state as
// not asserted after a sleeper is associated with it.
func NotAssertedWakerWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
s.AddWaker(&w, 0)
if w.IsAsserted() {
t.Fatalf("Non-asserted waker is reported as asserted")
}
if w.Clear() {
t.Fatalf("Non-asserted waker is reported as asserted")
}
}
// NotAssertedWakerAfterWake tests that a waker properly reports its state as
// not asserted after a previous assert is consumed by a sleeper. That is, tests
// the "edge-triggered" behavior.
func NotAssertedWakerAfterWake(t *testing.T) {
var w Waker
var s Sleeper
s.AddWaker(&w, 0)
w.Assert()
s.Fetch(true)
if w.IsAsserted() {
t.Fatalf("Consumed waker is reported as asserted")
}
if w.Clear() {
t.Fatalf("Consumed waker is reported as asserted")
}
}
// AssertedWakerBeforeAdd tests that a waker causes a sleeper to not sleep if
// it's already asserted before being added.
func AssertedWakerBeforeAdd(t *testing.T) {
var w Waker
var s Sleeper
w.Assert()
s.AddWaker(&w, 0)
if _, ok := s.Fetch(false); !ok {
t.Fatalf("Fetch failed even though asserted waker was added")
}
}
// ClearedWaker tests that a waker properly reports its state as not asserted
// after it is cleared.
func ClearedWaker(t *testing.T) {
var w Waker
w.Assert()
w.Clear()
if w.IsAsserted() {
t.Fatalf("Cleared waker is reported as asserted")
}
if w.Clear() {
t.Fatalf("Cleared waker is reported as asserted")
}
}
// ClearedWakerWithSleeper tests that a waker properly reports its state as
// not asserted when it is cleared while it has a sleeper associated with it.
func ClearedWakerWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
s.AddWaker(&w, 0)
w.Clear()
if w.IsAsserted() {
t.Fatalf("Cleared waker is reported as asserted")
}
if w.Clear() {
t.Fatalf("Cleared waker is reported as asserted")
}
}
// ClearedWakerAssertedWithSleeper tests that a waker properly reports its state
// as not asserted when it is cleared while it has a sleeper associated with it
// and has been asserted.
func ClearedWakerAssertedWithSleeper(t *testing.T) {
var w Waker
var s Sleeper
s.AddWaker(&w, 0)
w.Assert()
w.Clear()
if w.IsAsserted() {
t.Fatalf("Cleared waker is reported as asserted")
}
if w.Clear() {
t.Fatalf("Cleared waker is reported as asserted")
}
}
// TestBlock tests that a sleeper actually blocks waiting for the waker to
// assert its state.
func TestBlock(t *testing.T) {
var w Waker
var s Sleeper
s.AddWaker(&w, 0)
// Assert waker after one second.
before := time.Now()
go func() {
time.Sleep(1 * time.Second)
w.Assert()
}()
// Fetch the result and make sure it took at least 500ms.
if _, ok := s.Fetch(true); !ok {
t.Fatalf("Fetch failed unexpectedly")
}
if d := time.Now().Sub(before); d < 500*time.Millisecond {
t.Fatalf("Duration was too short: %v", d)
}
// Check that already-asserted waker completes inline.
w.Assert()
if _, ok := s.Fetch(true); !ok {
t.Fatalf("Fetch failed unexpectedly")
}
// Check that fetch sleeps if waker had been asserted but was reset
// before Fetch is called.
w.Assert()
w.Clear()
before = time.Now()
go func() {
time.Sleep(1 * time.Second)
w.Assert()
}()
if _, ok := s.Fetch(true); !ok {
t.Fatalf("Fetch failed unexpectedly")
}
if d := time.Now().Sub(before); d < 500*time.Millisecond {
t.Fatalf("Duration was too short: %v", d)
}
}
// TestNonBlock checks that a sleeper won't block if waker isn't asserted.
func TestNonBlock(t *testing.T) {
var w Waker
var s Sleeper
// Don't block when there's no waker.
if _, ok := s.Fetch(false); ok {
t.Fatalf("Fetch succeeded when there is no waker")
}
// Don't block when waker isn't asserted.
s.AddWaker(&w, 0)
if _, ok := s.Fetch(false); ok {
t.Fatalf("Fetch succeeded when waker was not asserted")
}
// Don't block when waker was asserted, but isn't anymore.
w.Assert()
w.Clear()
if _, ok := s.Fetch(false); ok {
t.Fatalf("Fetch succeeded when waker was not asserted anymore")
}
// Don't block when waker was consumed by previous Fetch().
w.Assert()
if _, ok := s.Fetch(false); !ok {
t.Fatalf("Fetch failed even though waker was asserted")
}
if _, ok := s.Fetch(false); ok {
t.Fatalf("Fetch succeeded when waker had been consumed")
}
}
// TestMultiple checks that a sleeper can wait for and receives notifications
// from multiple wakers.
func TestMultiple(t *testing.T) {
s := Sleeper{}
w1 := Waker{}
w2 := Waker{}
s.AddWaker(&w1, 0)
s.AddWaker(&w2, 1)
w1.Assert()
w2.Assert()
v, ok := s.Fetch(false)
if !ok {
t.Fatalf("Fetch failed when there are asserted wakers")
}
if v != 0 && v != 1 {
t.Fatalf("Unexpected waker id: %v", v)
}
want := 1 - v
v, ok = s.Fetch(false)
if !ok {
t.Fatalf("Fetch failed when there is an asserted waker")
}
if v != want {
t.Fatalf("Unexpected waker id, got %v, want %v", v, want)
}
}
// TestDoneFunction tests if calling Done() on a sleeper works properly.
func TestDoneFunction(t *testing.T) {
// Trivial case of no waker.
s := Sleeper{}
s.Done()
// Cases when the sleeper has n wakers, but none are asserted.
for n := 1; n < 20; n++ {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
s.AddWaker(&w[j], j)
}
s.Done()
}
// Cases when the sleeper has n wakers, and only the i-th one is
// asserted.
for n := 1; n < 20; n++ {
for i := 0; i < n; i++ {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
s.AddWaker(&w[j], j)
}
w[i].Assert()
s.Done()
}
}
// Cases when the sleeper has n wakers, and the i-th one is asserted
// and cleared.
for n := 1; n < 20; n++ {
for i := 0; i < n; i++ {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
s.AddWaker(&w[j], j)
}
w[i].Assert()
w[i].Clear()
s.Done()
}
}
// Cases when the sleeper has n wakers, with a random number of them
// asserted.
for n := 1; n < 20; n++ {
for iters := 0; iters < 1000; iters++ {
s := Sleeper{}
w := make([]Waker, n)
for j := 0; j < n; j++ {
s.AddWaker(&w[j], j)
}
// Pick the number of asserted elements, then assert
// random wakers.
asserted := rand.Int() % (n + 1)
for j := 0; j < asserted; j++ {
w[rand.Int()%n].Assert()
}
s.Done()
}
}
}
// TestRace tests that multiple wakers can continuously send wake requests to
// the sleeper.
func TestRace(t *testing.T) {
const wakers = 100
const wakeRequests = 10000
counts := make([]int, wakers)
w := make([]Waker, wakers)
s := Sleeper{}
// Associate each waker and start goroutines that will assert them.
for i := range w {
s.AddWaker(&w[i], i)
go func(w *Waker) {
n := 0
for n < wakeRequests {
if !w.IsAsserted() {
w.Assert()
n++
} else {
runtime.Gosched()
}
}
}(&w[i])
}
// Wait for all wake up notifications from all wakers.
for i := 0; i < wakers*wakeRequests; i++ {
v, _ := s.Fetch(true)
counts[v]++
}
// Check that we got the right number for each.
for i, v := range counts {
if v != wakeRequests {
t.Errorf("Waker %v only got %v wakes", i, v)
}
}
}
// BenchmarkSleeperMultiSelect measures how long it takes to fetch a wake up
// from 4 wakers when at least one is already asserted.
func BenchmarkSleeperMultiSelect(b *testing.B) {
const count = 4
s := Sleeper{}
w := make([]Waker, count)
for i := range w {
s.AddWaker(&w[i], i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
w[count-1].Assert()
s.Fetch(true)
}
}
// BenchmarkGoMultiSelect measures how long it takes to fetch a zero-length
// struct from one of 4 channels when at least one is ready.
func BenchmarkGoMultiSelect(b *testing.B) {
const count = 4
ch := make([]chan struct{}, count)
for i := range ch {
ch[i] = make(chan struct{}, 1)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch[count-1] <- struct{}{}
select {
case <-ch[0]:
case <-ch[1]:
case <-ch[2]:
case <-ch[3]:
}
}
}
// BenchmarkSleeperSingleSelect measures how long it takes to fetch a wake up
// from one waker that is already asserted.
func BenchmarkSleeperSingleSelect(b *testing.B) {
s := Sleeper{}
w := Waker{}
s.AddWaker(&w, 0)
b.ResetTimer()
for i := 0; i < b.N; i++ {
w.Assert()
s.Fetch(true)
}
}
// BenchmarkGoSingleSelect measures how long it takes to fetch a zero-length
// struct from a channel that already has it buffered.
func BenchmarkGoSingleSelect(b *testing.B) {
ch := make(chan struct{}, 1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
ch <- struct{}{}
<-ch
}
}
// BenchmarkSleeperAssertNonWaiting measures how long it takes to assert a
// channel that is already asserted.
func BenchmarkSleeperAssertNonWaiting(b *testing.B) {
w := Waker{}
w.Assert()
for i := 0; i < b.N; i++ {
w.Assert()
}
}
// BenchmarkGoAssertNonWaiting measures how long it takes to write to a channel
// that has already something written to it.
func BenchmarkGoAssertNonWaiting(b *testing.B) {
ch := make(chan struct{}, 1)
ch <- struct{}{}
for i := 0; i < b.N; i++ {
select {
case ch <- struct{}{}:
default:
}
}
}
// BenchmarkSleeperWaitOnSingleSelect measures how long it takes to wait on one
// waker channel while another goroutine wakes up the sleeper. This assumes that
// a new goroutine doesn't run immediately (i.e., the creator of a new goroutine
// is allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkSleeperWaitOnSingleSelect(b *testing.B) {
s := Sleeper{}
w := Waker{}
s.AddWaker(&w, 0)
for i := 0; i < b.N; i++ {
go func() {
w.Assert()
}()
s.Fetch(true)
}
}
// BenchmarkGoWaitOnSingleSelect measures how long it takes to wait on one
// channel while another goroutine wakes up the sleeper. This assumes that a new
// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
// allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkGoWaitOnSingleSelect(b *testing.B) {
ch := make(chan struct{}, 1)
for i := 0; i < b.N; i++ {
go func() {
ch <- struct{}{}
}()
<-ch
}
}
// BenchmarkSleeperWaitOnMultiSelect measures how long it takes to wait on 4
// wakers while another goroutine wakes up the sleeper. This assumes that a new
// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
// allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkSleeperWaitOnMultiSelect(b *testing.B) {
const count = 4
s := Sleeper{}
w := make([]Waker, count)
for i := range w {
s.AddWaker(&w[i], i)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
go func() {
w[count-1].Assert()
}()
s.Fetch(true)
}
}
// BenchmarkGoWaitOnMultiSelect measures how long it takes to wait on 4 channels
// while another goroutine wakes up the sleeper. This assumes that a new
// goroutine doesn't run immediately (i.e., the creator of a new goroutine is
// allowed to go to sleep before the new goroutine has a chance to run).
func BenchmarkGoWaitOnMultiSelect(b *testing.B) {
const count = 4
ch := make([]chan struct{}, count)
for i := range ch {
ch[i] = make(chan struct{}, 1)
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
go func() {
ch[count-1] <- struct{}{}
}()
select {
case <-ch[0]:
case <-ch[1]:
case <-ch[2]:
case <-ch[3]:
}
}
}

View File

@@ -0,0 +1,395 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package sleep allows goroutines to efficiently sleep on multiple sources of
// notifications (wakers). It offers O(1) complexity, which is different from
// multi-channel selects which have O(n) complexity (where n is the number of
// channels) and a considerable constant factor.
//
// It is similar to edge-triggered epoll waits, where the user registers each
// object of interest once, and then can repeatedly wait on all of them.
//
// A Waker object is used to wake a sleeping goroutine (G) up, or prevent it
// from going to sleep next. A Sleeper object is used to receive notifications
// from wakers, and if no notifications are available, to optionally sleep until
// one becomes available.
//
// A Waker can be associated with at most one Sleeper, but a Sleeper can be
// associated with multiple Wakers. A Sleeper has a list of asserted (ready)
// wakers; when Fetch() is called repeatedly, elements from this list are
// returned until the list becomes empty in which case the goroutine goes to
// sleep. When Assert() is called on a Waker, it adds itself to the Sleeper's
// asserted list and wakes the G up from its sleep if needed.
//
// Sleeper objects are expected to be used as follows, with just one goroutine
// executing this code:
//
// // One time set-up.
// s := sleep.Sleeper{}
// s.AddWaker(&w1, constant1)
// s.AddWaker(&w2, constant2)
//
// // Called repeatedly.
// for {
// switch id, _ := s.Fetch(true); id {
// case constant1:
// // Do work triggered by w1 being asserted.
// case constant2:
// // Do work triggered by w2 being asserted.
// }
// }
//
// And Waker objects are expected to call w.Assert() when they want the sleeper
// to wake up and perform work.
//
// The notifications are edge-triggered, which means that if a Waker calls
// Assert() several times before the sleeper has the chance to wake up, it will
// only be notified once and should perform all pending work (alternatively, it
// can also call Assert() on the waker, to ensure that it will wake up again).
//
// The "unsafeness" here is in the casts to/from unsafe.Pointer, which is safe
// when only one type is used for each unsafe.Pointer (which is the case here),
// we should just make sure that this remains the case in the future. The usage
// of unsafe package could be confined to sharedWaker and sharedSleeper types
// that would hold pointers in atomic.Pointers, but the go compiler currently
// can't optimize these as well (it won't inline their method calls), which
// reduces performance.
package sleep
import (
"sync/atomic"
"unsafe"
)
const (
// preparingG is stored in sleepers to indicate that they're preparing
// to sleep.
preparingG = 1
)
var (
// assertedSleeper is a sentinel sleeper. A pointer to it is stored in
// wakers that are asserted.
assertedSleeper Sleeper
)
//go:linkname gopark runtime.gopark
func gopark(unlockf func(uintptr, *uintptr) bool, wg *uintptr, reason string, traceEv byte, traceskip int)
//go:linkname goready runtime.goready
func goready(g uintptr, traceskip int)
// Sleeper allows a goroutine to sleep and receive wake up notifications from
// Wakers in an efficient way.
//
// This is similar to edge-triggered epoll in that wakers are added to the
// sleeper once and the sleeper can then repeatedly sleep in O(1) time while
// waiting on all wakers.
//
// None of the methods in a Sleeper can be called concurrently. Wakers that have
// been added to a sleeper A can only be added to another sleeper after A.Done()
// returns. These restrictions allow this to be implemented lock-free.
//
// This struct is thread-compatible.
type Sleeper struct {
// sharedList is a "stack" of asserted wakers. They atomically add
// themselves to the front of this list as they become asserted.
sharedList unsafe.Pointer
// localList is a list of asserted wakers that is only accessible to the
// waiter, and thus doesn't have to be accessed atomically. When
// fetching more wakers, the waiter will first go through this list, and
// only when it's empty will it atomically fetch wakers from
// sharedList.
localList *Waker
// allWakers is a list with all wakers that have been added to this
// sleeper. It is used during cleanup to remove associations.
allWakers *Waker
// waitingG holds the G that is sleeping, if any. It is used by wakers
// to determine which G, if any, they should wake.
waitingG uintptr
}
// AddWaker associates the given waker to the sleeper. id is the value to be
// returned when the sleeper is woken by the given waker.
func (s *Sleeper) AddWaker(w *Waker, id int) {
// Add the waker to the list of all wakers.
w.allWakersNext = s.allWakers
s.allWakers = w
w.id = id
// Try to associate the waker with the sleeper. If it's already
// asserted, we simply enqueue it in the "ready" list.
for {
p := (*Sleeper)(atomic.LoadPointer(&w.s))
if p == &assertedSleeper {
s.enqueueAssertedWaker(w)
return
}
if atomic.CompareAndSwapPointer(&w.s, usleeper(p), usleeper(s)) {
return
}
}
}
// nextWaker returns the next waker in the notification list, blocking if
// needed.
func (s *Sleeper) nextWaker(block bool) *Waker {
// Attempt to replenish the local list if it's currently empty.
if s.localList == nil {
for atomic.LoadPointer(&s.sharedList) == nil {
// Fail request if caller requested that we
// don't block.
if !block {
return nil
}
// Indicate to wakers that we're about to sleep,
// this allows them to abort the wait by setting
// waitingG back to zero (which we'll notice
// before committing the sleep).
atomic.StoreUintptr(&s.waitingG, preparingG)
// Check if something was queued while we were
// preparing to sleep. We need this interleaving
// to avoid missing wake ups.
if atomic.LoadPointer(&s.sharedList) != nil {
atomic.StoreUintptr(&s.waitingG, 0)
break
}
// Try to commit the sleep and report it to the
// tracer as a select.
//
// gopark puts the caller to sleep and calls
// commitSleep to decide whether to immediately
// wake the caller up or to leave it sleeping.
const traceEvGoBlockSelect = 24
gopark(commitSleep, &s.waitingG, "sleeper", traceEvGoBlockSelect, 0)
}
// Pull the shared list out and reverse it in the local
// list. Given that wakers push themselves in reverse
// order, we fix things here.
v := (*Waker)(atomic.SwapPointer(&s.sharedList, nil))
for v != nil {
cur := v
v = v.next
cur.next = s.localList
s.localList = cur
}
}
// Remove the waker in the front of the list.
w := s.localList
s.localList = w.next
return w
}
// Fetch fetches the next wake-up notification. If a notification is immediately
// available, it is returned right away. Otherwise, the behavior depends on the
// value of 'block': if true, the current goroutine blocks until a notification
// arrives, then returns it; if false, returns 'ok' as false.
//
// When 'ok' is true, the value of 'id' corresponds to the id associated with
// the waker; when 'ok' is false, 'id' is undefined.
//
// N.B. This method is *not* thread-safe. Only one goroutine at a time is
// allowed to call this method.
func (s *Sleeper) Fetch(block bool) (id int, ok bool) {
for {
w := s.nextWaker(block)
if w == nil {
return -1, false
}
// Reassociate the waker with the sleeper. If the waker was
// still asserted we can return it, otherwise try the next one.
old := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(s)))
if old == &assertedSleeper {
return w.id, true
}
}
}
// Done is used to indicate that the caller won't use this Sleeper anymore. It
// removes the association with all wakers so that they can be safely reused
// by another sleeper after Done() returns.
func (s *Sleeper) Done() {
// Remove all associations that we can, and build a list of the ones
// we could not. An association can be removed right away from waker w
// if w.s has a pointer to the sleeper, that is, the waker is not
// asserted yet. By atomically switching w.s to nil, we guarantee that
// subsequent calls to Assert() on the waker will not result in it being
// queued to this sleeper.
var pending *Waker
w := s.allWakers
for w != nil {
next := w.allWakersNext
for {
t := atomic.LoadPointer(&w.s)
if t != usleeper(s) {
w.allWakersNext = pending
pending = w
break
}
if atomic.CompareAndSwapPointer(&w.s, t, nil) {
break
}
}
w = next
}
// The associations that we could not remove are either asserted, or in
// the process of being asserted, or have been asserted and cleared
// before being pulled from the sleeper lists. We must wait for them all
// to make it to the sleeper lists, so that we know that the wakers
// won't do any more work towards waking this sleeper up.
for pending != nil {
pulled := s.nextWaker(true)
// Remove the waker we just pulled from the list of associated
// wakers.
prev := &pending
for w := *prev; w != nil; w = *prev {
if pulled == w {
*prev = w.allWakersNext
break
}
prev = &w.allWakersNext
}
}
s.allWakers = nil
}
// enqueueAssertedWaker enqueues an asserted waker to the "ready" circular list
// of wakers that want to notify the sleeper.
func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
// Add the new waker to the front of the list.
for {
v := (*Waker)(atomic.LoadPointer(&s.sharedList))
w.next = v
if atomic.CompareAndSwapPointer(&s.sharedList, uwaker(v), uwaker(w)) {
break
}
}
for {
// Nothing to do if there isn't a G waiting.
g := atomic.LoadUintptr(&s.waitingG)
if g == 0 {
return
}
// Signal to the sleeper that a waker has been asserted.
if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
if g != preparingG {
// We managed to get a G. Wake it up.
goready(g, 0)
}
}
}
}
// Waker represents a source of wake-up notifications to be sent to sleepers. A
// waker can be associated with at most one sleeper at a time, and at any given
// time is either in asserted or non-asserted state.
//
// Once asserted, the waker remains so until it is manually cleared or a sleeper
// consumes its assertion (i.e., a sleeper wakes up or is prevented from going
// to sleep due to the waker).
//
// This struct is thread-safe, that is, its methods can be called concurrently
// by multiple goroutines.
type Waker struct {
// s is the sleeper that this waker can wake up. Only one sleeper at a
// time is allowed. This field can have three classes of values:
// nil -- the waker is not asserted: it either is not associated with
// a sleeper, or is queued to a sleeper due to being previously
// asserted. This is the zero value.
// &assertedSleeper -- the waker is asserted.
// otherwise -- the waker is not asserted, and is associated with the
// given sleeper. Once it transitions to asserted state, the
// associated sleeper will be woken.
s unsafe.Pointer
// next is used to form a linked list of asserted wakers in a sleeper.
next *Waker
// allWakersNext is used to form a linked list of all wakers associated
// to a given sleeper.
allWakersNext *Waker
// id is the value to be returned to sleepers when they wake up due to
// this waker being asserted.
id int
}
// Assert moves the waker to an asserted state, if it isn't asserted yet. When
// asserted, the waker will cause its matching sleeper to wake up.
func (w *Waker) Assert() {
// Nothing to do if the waker is already asserted. This check allows us
// to complete this case (already asserted) without any interlocked
// operations on x86.
if atomic.LoadPointer(&w.s) == usleeper(&assertedSleeper) {
return
}
// Mark the waker as asserted, and wake up a sleeper if there is one.
switch s := (*Sleeper)(atomic.SwapPointer(&w.s, usleeper(&assertedSleeper))); s {
case nil:
case &assertedSleeper:
default:
s.enqueueAssertedWaker(w)
}
}
// Clear moves the waker to then non-asserted state and returns whether it was
// asserted before being cleared.
//
// N.B. The waker isn't removed from the "ready" list of a sleeper (if it
// happens to be in one), but the sleeper will notice that it is not asserted
// anymore and won't return it to the caller.
func (w *Waker) Clear() bool {
// Nothing to do if the waker is not asserted. This check allows us to
// complete this case (already not asserted) without any interlocked
// operations on x86.
if atomic.LoadPointer(&w.s) != usleeper(&assertedSleeper) {
return false
}
// Try to store nil in the sleeper, which indicates that the waker is
// not asserted.
return atomic.CompareAndSwapPointer(&w.s, usleeper(&assertedSleeper), nil)
}
// IsAsserted returns whether the waker is currently asserted (i.e., if it's
// currently in a state that would cause its matching sleeper to wake up).
func (w *Waker) IsAsserted() bool {
return (*Sleeper)(atomic.LoadPointer(&w.s)) == &assertedSleeper
}
func usleeper(s *Sleeper) unsafe.Pointer {
return unsafe.Pointer(s)
}
func uwaker(w *Waker) unsafe.Pointer {
return unsafe.Pointer(w)
}

View File

@@ -0,0 +1,64 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package buffer
// Prependable is a buffer that grows backwards, that is, more data can be
// prepended to it. It is useful when building networking packets, where each
// protocol adds its own headers to the front of the higher-level protocol
// header and payload; for example, TCP would prepend its header to the payload,
// then IP would prepend its own, then ethernet.
type Prependable struct {
// Buf is the buffer backing the prependable buffer.
buf View
// usedIdx is the index where the used part of the buffer begins.
usedIdx int
}
// NewPrependable allocates a new prependable buffer with the given size.
func NewPrependable(size int) Prependable {
return Prependable{buf: NewView(size), usedIdx: size}
}
// NewPrependableFromView creates an entirely-used Prependable from a View.
//
// NewPrependableFromView takes ownership of v. Note that since the entire
// prependable is used, further attempts to call Prepend will note that size >
// p.usedIdx and return nil.
func NewPrependableFromView(v View) Prependable {
return Prependable{buf: v, usedIdx: 0}
}
// View returns a View of the backing buffer that contains all prepended
// data so far.
func (p Prependable) View() View {
return p.buf[p.usedIdx:]
}
// UsedLength returns the number of bytes used so far.
func (p Prependable) UsedLength() int {
return len(p.buf) - p.usedIdx
}
// Prepend reserves the requested space in front of the buffer, returning a
// slice that represents the reserved space.
func (p *Prependable) Prepend(size int) []byte {
if size > p.usedIdx {
return nil
}
p.usedIdx -= size
return p.View()[:size:size]
}

View File

@@ -0,0 +1,146 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package buffer provides the implementation of a buffer view.
package buffer
// View is a slice of a buffer, with convenience methods.
type View []byte
// NewView allocates a new buffer and returns an initialized view that covers
// the whole buffer.
func NewView(size int) View {
return make(View, size)
}
// NewViewFromBytes allocates a new buffer and copies in the given bytes.
func NewViewFromBytes(b []byte) View {
return append(View(nil), b...)
}
// TrimFront removes the first "count" bytes from the visible section of the
// buffer.
func (v *View) TrimFront(count int) {
*v = (*v)[count:]
}
// CapLength irreversibly reduces the length of the visible section of the
// buffer to the value specified.
func (v *View) CapLength(length int) {
// We also set the slice cap because if we don't, one would be able to
// expand the view back to include the region just excluded. We want to
// prevent that to avoid potential data leak if we have uninitialized
// data in excluded region.
*v = (*v)[:length:length]
}
// ToVectorisedView returns a VectorisedView containing the receiver.
func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v})
}
// VectorisedView is a vectorised version of View using non contigous memory.
// It supports all the convenience methods supported by View.
//
// +stateify savable
type VectorisedView struct {
views []View
size int
}
// NewVectorisedView creates a new vectorised view from an already-allocated slice
// of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
return VectorisedView{views: views, size: size}
}
// TrimFront removes the first "count" bytes of the vectorised view.
func (vv *VectorisedView) TrimFront(count int) {
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
vv.views[0].TrimFront(count)
return
}
count -= len(vv.views[0])
vv.RemoveFirst()
}
}
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
length = 0
}
if vv.size < length {
return
}
vv.size = length
for i := range vv.views {
v := &vv.views[i]
if len(*v) >= length {
if length == 0 {
vv.views = vv.views[:i]
} else {
v.CapLength(length)
vv.views = vv.views[:i+1]
}
return
}
length -= len(*v)
}
}
// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this VectorisedView,
// the method will avoid allocations and use the buffer to store the Views of the clone.
func (vv VectorisedView) Clone(buffer []View) VectorisedView {
return VectorisedView{views: append(buffer[:0], vv.views...), size: vv.size}
}
// First returns the first view of the vectorised view.
func (vv VectorisedView) First() View {
if len(vv.views) == 0 {
return nil
}
return vv.views[0]
}
// RemoveFirst removes the first view of the vectorised view.
func (vv *VectorisedView) RemoveFirst() {
if len(vv.views) == 0 {
return
}
vv.size -= len(vv.views[0])
vv.views = vv.views[1:]
}
// Size returns the size in bytes of the entire content stored in the vectorised view.
func (vv VectorisedView) Size() int {
return vv.size
}
// ToView returns a single view containing the content of the vectorised view.
func (vv VectorisedView) ToView() View {
u := make([]byte, 0, vv.size)
for _, v := range vv.views {
u = append(u, v...)
}
return u
}
// Views returns the slice containing the all views.
func (vv VectorisedView) Views() []View {
return vv.views
}

View File

@@ -0,0 +1,588 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package checker provides helper functions to check networking packets for
// validity.
package checker
import (
"encoding/binary"
"reflect"
"testing"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
)
// NetworkChecker is a function to check a property of a network packet.
type NetworkChecker func(*testing.T, []header.Network)
// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
// would call:
//
// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
t.Helper()
ipv4 := header.IPv4(b)
if !ipv4.IsValid(len(b)) {
t.Error("Not a valid IPv4 packet")
}
xsum := ipv4.CalculateChecksum()
if xsum != 0 && xsum != 0xffff {
t.Errorf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
}
for _, f := range checkers {
f(t, []header.Network{ipv4})
}
if t.Failed() {
t.FailNow()
}
}
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
t.Helper()
ipv6 := header.IPv6(b)
if !ipv6.IsValid(len(b)) {
t.Error("Not a valid IPv6 packet")
}
for _, f := range checkers {
f(t, []header.Network{ipv6})
}
if t.Failed() {
t.FailNow()
}
}
// SrcAddr creates a checker that checks the source address.
func SrcAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if a := h[0].SourceAddress(); a != addr {
t.Errorf("Bad source address, got %v, want %v", a, addr)
}
}
}
// DstAddr creates a checker that checks the destination address.
func DstAddr(addr tcpip.Address) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if a := h[0].DestinationAddress(); a != addr {
t.Errorf("Bad destination address, got %v, want %v", a, addr)
}
}
}
// TTL creates a checker that checks the TTL (ipv4) or HopLimit (ipv6).
func TTL(ttl uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
var v uint8
switch ip := h[0].(type) {
case header.IPv4:
v = ip.TTL()
case header.IPv6:
v = ip.HopLimit()
}
if v != ttl {
t.Fatalf("Bad TTL, got %v, want %v", v, ttl)
}
}
}
// PayloadLen creates a checker that checks the payload length.
func PayloadLen(plen int) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if l := len(h[0].Payload()); l != plen {
t.Errorf("Bad payload length, got %v, want %v", l, plen)
}
}
}
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.FragmentOffset(); v != offset {
t.Errorf("Bad fragment offset, got %v, want %v", v, offset)
}
}
}
}
// FragmentFlags creates a checker that checks the fragment flags field.
func FragmentFlags(flags uint8) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
// We only do this of IPv4 for now.
switch ip := h[0].(type) {
case header.IPv4:
if v := ip.Flags(); v != flags {
t.Errorf("Bad fragment offset, got %v, want %v", v, flags)
}
}
}
}
// TOS creates a checker that checks the TOS field.
func TOS(tos uint8, label uint32) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if v, l := h[0].TOS(); v != tos || l != label {
t.Errorf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
}
}
}
// Raw creates a checker that checks the bytes of payload.
// The checker always checks the payload of the last network header.
// For instance, in case of IPv6 fragments, the payload that will be checked
// is the one containing the actual data that the packet is carrying, without
// the bytes added by the IPv6 fragmentation.
func Raw(want []byte) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}
// IPv6Fragment creates a checker that validates an IPv6 fragment.
func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
}
ipv6Frag := header.IPv6Fragment(h[0].Payload())
if !ipv6Frag.IsValid() {
t.Error("Not a valid IPv6 fragment")
}
for _, f := range checkers {
f(t, []header.Network{h[0], ipv6Frag})
}
if t.Failed() {
t.FailNow()
}
}
}
// TCP creates a checker that checks that the transport protocol is TCP and
// potentially additional transport header fields.
func TCP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
first := h[0]
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
t.Errorf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
}
// Verify the checksum.
tcp := header.TCP(last.Payload())
l := uint16(len(tcp))
xsum := header.Checksum([]byte(first.SourceAddress()), 0)
xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
xsum = header.Checksum(tcp, xsum)
if xsum != 0 && xsum != 0xffff {
t.Errorf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
}
// Run the transport checkers.
for _, f := range checkers {
f(t, tcp)
}
if t.Failed() {
t.FailNow()
}
}
}
// UDP creates a checker that checks that the transport protocol is UDP and
// potentially additional transport header fields.
func UDP(checkers ...TransportChecker) NetworkChecker {
return func(t *testing.T, h []header.Network) {
t.Helper()
last := h[len(h)-1]
if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
t.Errorf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
}
udp := header.UDP(last.Payload())
for _, f := range checkers {
f(t, udp)
}
if t.Failed() {
t.FailNow()
}
}
}
// SrcPort creates a checker that checks the source port.
func SrcPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
if p := h.SourcePort(); p != port {
t.Errorf("Bad source port, got %v, want %v", p, port)
}
}
}
// DstPort creates a checker that checks the destination port.
func DstPort(port uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
if p := h.DestinationPort(); p != port {
t.Errorf("Bad destination port, got %v, want %v", p, port)
}
}
}
// SeqNum creates a checker that checks the sequence number.
func SeqNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
return
}
if s := tcp.SequenceNumber(); s != seq {
t.Errorf("Bad sequence number, got %v, want %v", s, seq)
}
}
}
// AckNum creates a checker that checks the ack number.
func AckNum(seq uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
return
}
if s := tcp.AckNumber(); s != seq {
t.Errorf("Bad ack number, got %v, want %v", s, seq)
}
}
}
// Window creates a checker that checks the tcp window.
func Window(window uint16) TransportChecker {
return func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
if w := tcp.WindowSize(); w != window {
t.Errorf("Bad window, got 0x%x, want 0x%x", w, window)
}
}
}
// TCPFlags creates a checker that checks the tcp flags.
func TCPFlags(flags uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
return
}
if f := tcp.Flags(); f != flags {
t.Errorf("Bad flags, got 0x%x, want 0x%x", f, flags)
}
}
}
// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
// given mask, match the supplied flags.
func TCPFlagsMatch(flags, mask uint8) TransportChecker {
return func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
if f := tcp.Flags(); (f & mask) != (flags & mask) {
t.Errorf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
}
}
}
// TCPSynOptions creates a checker that checks the presence of TCP options in
// SYN segments.
//
// If wndscale is negative, the window scale option must not be present.
func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
return func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
opts := tcp.Options()
limit := len(opts)
foundMSS := false
foundWS := false
foundTS := false
foundSACKPermitted := false
tsVal := uint32(0)
tsEcr := uint32(0)
for i := 0; i < limit; {
switch opts[i] {
case header.TCPOptionEOL:
i = limit
case header.TCPOptionNOP:
i++
case header.TCPOptionMSS:
v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if wantOpts.MSS != v {
t.Errorf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
}
foundMSS = true
i += 4
case header.TCPOptionWS:
if wantOpts.WS < 0 {
t.Error("WS present when it shouldn't be")
}
v := int(opts[i+2])
if v != wantOpts.WS {
t.Errorf("Bad WS: got %v, want %v", v, wantOpts.WS)
}
foundWS = true
i += 3
case header.TCPOptionTS:
if i+9 >= limit {
t.Errorf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
}
if opts[i+1] != 10 {
t.Errorf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = uint32(0)
if tcp.Flags()&header.TCPFlagAck != 0 {
// If the syn is an SYN-ACK then read
// the tsEcr value as well.
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
}
foundTS = true
i += 10
case header.TCPOptionSACKPermitted:
if i+1 >= limit {
t.Errorf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
}
if opts[i+1] != 2 {
t.Errorf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
}
foundSACKPermitted = true
i += 2
default:
i += int(opts[i+1])
}
}
if !foundMSS {
t.Errorf("MSS option not found. Options: %x", opts)
}
if !foundWS && wantOpts.WS >= 0 {
t.Errorf("WS option not found. Options: %x", opts)
}
if wantOpts.TS && !foundTS {
t.Errorf("TS option not found. Options: %x", opts)
}
if foundTS && tsVal == 0 {
t.Error("TS option specified but the timestamp value is zero")
}
if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
t.Errorf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
}
if wantOpts.SACKPermitted && !foundSACKPermitted {
t.Errorf("SACKPermitted option not found. Options: %x", opts)
}
}
}
// TCPTimestampChecker creates a checker that validates that a TCP segment has a
// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
// wantTSEcr values with those in the TCP segment (if present).
//
// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
// skipped.
func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
return func(t *testing.T, h header.Transport) {
tcp, ok := h.(header.TCP)
if !ok {
return
}
opts := []byte(tcp.Options())
limit := len(opts)
foundTS := false
tsVal := uint32(0)
tsEcr := uint32(0)
for i := 0; i < limit; {
switch opts[i] {
case header.TCPOptionEOL:
i = limit
case header.TCPOptionNOP:
i++
case header.TCPOptionTS:
if i+9 >= limit {
t.Errorf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
}
if opts[i+1] != 10 {
t.Errorf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
}
tsVal = binary.BigEndian.Uint32(opts[i+2:])
tsEcr = binary.BigEndian.Uint32(opts[i+6:])
foundTS = true
i += 10
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
return
}
l := int(opts[i+1])
if i < 2 || i+l > limit {
return
}
i += l
}
}
if wantTS != foundTS {
t.Errorf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
}
if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
t.Errorf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
}
if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
t.Errorf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
}
}
}
// TCPNoSACKBlockChecker creates a checker that verifies that the segment does not
// contain any SACK blocks in the TCP options.
func TCPNoSACKBlockChecker() TransportChecker {
return TCPSACKBlockChecker(nil)
}
// TCPSACKBlockChecker creates a checker that verifies that the segment does
// contain the specified SACK blocks in the TCP options.
func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
return func(t *testing.T, h header.Transport) {
t.Helper()
tcp, ok := h.(header.TCP)
if !ok {
return
}
var gotSACKBlocks []header.SACKBlock
opts := []byte(tcp.Options())
limit := len(opts)
for i := 0; i < limit; {
switch opts[i] {
case header.TCPOptionEOL:
i = limit
case header.TCPOptionNOP:
i++
case header.TCPOptionSACK:
if i+2 > limit {
// Malformed SACK block.
t.Errorf("malformed SACK option in options: %v", opts)
}
sackOptionLen := int(opts[i+1])
if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
// Malformed SACK block.
t.Errorf("malformed SACK option length in options: %v", opts)
}
numBlocks := sackOptionLen / 8
for j := 0; j < numBlocks; j++ {
start := binary.BigEndian.Uint32(opts[i+2+j*8:])
end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
Start: seqnum.Value(start),
End: seqnum.Value(end),
})
}
i += sackOptionLen
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
break
}
l := int(opts[i+1])
if l < 2 || i+l > limit {
break
}
i += l
}
}
if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
t.Errorf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
}
}
}
// Payload creates a checker that checks the payload.
func Payload(want []byte) TransportChecker {
return func(t *testing.T, h header.Transport) {
if got := h.Payload(); !reflect.DeepEqual(got, want) {
t.Errorf("Wrong payload, got %v, want %v", got, want)
}
}
}

View File

@@ -0,0 +1,107 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import "tcpip/netstack/tcpip"
const (
// ARPProtocolNumber是ARP协议号为0x0806
ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806
// ARPSize是ARP报文在IPV4网络下的长度
ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
)
// ARPOp表示ARP的操作码
type ARPOp uint16
// RFC 826定义的操作码.
const (
// arp请求
ARPRequest ARPOp = 1
// arp应答
ARPReply ARPOp = 2
)
// ARP报文的封装
type ARP []byte
// 从报文中得到硬件类型
func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
// 从报文中得到协议类型
func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
// 从报文中得到硬件地址的长度
func (a ARP) hardwareAddressSize() int { return int(a[4]) }
// 从报文中得到协议的地址长度
func (a ARP) protocolAddressSize() int { return int(a[5]) }
// Op从报文中得到arp操作码.
func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }
// SetOp设置arp操作码.
func (a ARP) SetOp(op ARPOp) {
a[6] = uint8(op >> 8)
a[7] = uint8(op)
}
// SetIPv4OverEthernet设置IPV4网络在以太网中arp报文的硬件和协议信息.
func (a ARP) SetIPv4OverEthernet() {
a[0], a[1] = 0, 1 // htypeEthernet
a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
a[4] = 6 // macSize
a[5] = uint8(IPv4AddressSize)
}
// HardwareAddressSender从报文中得到arp发送方的硬件地址
func (a ARP) HardwareAddressSender() []byte {
const s = 8
return a[s : s+6]
}
// ProtocolAddressSender从报文中得到arp发送方的协议地址为ipv4地址
func (a ARP) ProtocolAddressSender() []byte {
const s = 8 + 6
return a[s : s+4]
}
// HardwareAddressTarget从报文中得到arp目的方的硬件地址
func (a ARP) HardwareAddressTarget() []byte {
const s = 8 + 6 + 4
return a[s : s+6]
}
// ProtocolAddressTarget从报文中得到arp目的方的协议地址为ipv4地址
func (a ARP) ProtocolAddressTarget() []byte {
const s = 8 + 6 + 4 + 6
return a[s : s+4]
}
// IsValid检查arp报文是否有效
func (a ARP) IsValid() bool {
// 比arp报文的长度小返回无效
if len(a) < ARPSize {
return false
}
const htypeEthernet = 1
const macSize = 6
// 是否以太网、ipv4、硬件和协议长度都对
return a.hardwareAddressSpace() == htypeEthernet &&
a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
a.hardwareAddressSize() == macSize &&
a.protocolAddressSize() == IPv4AddressSize
}

View File

@@ -0,0 +1,57 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package header provides the implementation of the encoding and decoding of
// network protocol headers.
package header
import (
"tcpip/netstack/tcpip"
)
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
// given byte array.
// 校验和的计算
func Checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
l := len(buf)
if l&1 != 0 {
l--
v += uint32(buf[l]) << 8
}
for i := 0; i < l; i += 2 {
v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
}
return ChecksumCombine(uint16(v), uint16(v>>16))
}
// ChecksumCombine combines the two uint16 to form their checksum. This is done
// by adding them and the carry.
func ChecksumCombine(a, b uint16) uint16 {
v := uint32(a) + uint32(b)
return uint16(v + v>>16)
}
// PseudoHeaderChecksum calculates the pseudo-header checksum for the
// given destination protocol and network address, ignoring the length
// field. Pseudo-headers are needed by transport layers when calculating
// their own checksum.
func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 {
xsum := Checksum([]byte(srcAddr), 0)
xsum = Checksum([]byte(dstAddr), xsum)
return Checksum([]byte{0, uint8(protocol)}, xsum)
}

View File

@@ -0,0 +1,72 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
const (
dstMAC = 0
srcMAC = 6
ethType = 12
)
// EthernetFields表示链路层以太网帧的头部
type EthernetFields struct {
// 源地址
SrcAddr tcpip.LinkAddress
// 目的地址
DstAddr tcpip.LinkAddress
// 协议类型
Type tcpip.NetworkProtocolNumber
}
// Ethernet以太网数据包的封装
type Ethernet []byte
const (
// EthernetMinimumSize以太网帧最小的长度
EthernetMinimumSize = 14
// EthernetAddressSize以太网地址的长度
EthernetAddressSize = 6
)
// SourceAddress从帧头部中得到源地址
func (b Ethernet) SourceAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
}
// DestinationAddress从帧头部中得到目的地址
func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
}
// Type从帧头部中得到协议类型
func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
}
// Encode根据传入的帧头部信息编码成Ethernet二进制形式注意Ethernet应先分配好内存
func (b Ethernet) Encode(e *EthernetFields) {
binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
}

View File

@@ -0,0 +1,73 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
const (
typeHLen = 0
encapProto = 1
)
// GUEFields contains the fields of a GUE packet. It is used to describe the
// fields of a packet that needs to be encoded.
type GUEFields struct {
// Type is the "type" field of the GUE header.
Type uint8
// Control is the "control" field of the GUE header.
Control bool
// HeaderLength is the "header length" field of the GUE header. It must
// be at least 4 octets, and a multiple of 4 as well.
HeaderLength uint8
// Protocol is the "protocol" field of the GUE header. This is one of
// the IPPROTO_* values.
Protocol uint8
}
// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
type GUE []byte
const (
// GUEMinimumSize is the minimum size of a valid GUE packet.
GUEMinimumSize = 4
)
// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
// which includes the control bit).
func (b GUE) TypeAndControl() uint8 {
return b[typeHLen] >> 5
}
// HeaderLength returns the total length of the GUE header.
func (b GUE) HeaderLength() uint8 {
return 4 + 4*(b[typeHLen]&0x1f)
}
// Protocol returns the protocol field of the GUE header.
func (b GUE) Protocol() uint8 {
return b[encapProto]
}
// Encode encodes all the fields of the GUE header.
func (b GUE) Encode(i *GUEFields) {
ctl := uint8(0)
if i.Control {
ctl = 1 << 5
}
b[typeHLen] = ctl | i.Type<<6 | (i.HeaderLength-4)/4
b[encapProto] = i.Protocol
}

View File

@@ -0,0 +1,108 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
// ICMPv4 represents an ICMPv4 header stored in a byte array.
type ICMPv4 []byte
const (
// ICMPv4MinimumSize is the minimum size of a valid ICMP packet.
ICMPv4MinimumSize = 4
// ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv4EchoMinimumSize = 6
// ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4
// ICMPv4ProtocolNumber is the ICMP transport protocol number.
ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
)
// ICMPv4Type is the ICMP type field described in RFC 792.
type ICMPv4Type byte
// Typical values of ICMPv4Type defined in RFC 792.
const (
ICMPv4EchoReply ICMPv4Type = 0
ICMPv4DstUnreachable ICMPv4Type = 3
ICMPv4SrcQuench ICMPv4Type = 4
ICMPv4Redirect ICMPv4Type = 5
ICMPv4Echo ICMPv4Type = 8
ICMPv4TimeExceeded ICMPv4Type = 11
ICMPv4ParamProblem ICMPv4Type = 12
ICMPv4Timestamp ICMPv4Type = 13
ICMPv4TimestampReply ICMPv4Type = 14
ICMPv4InfoRequest ICMPv4Type = 15
ICMPv4InfoReply ICMPv4Type = 16
)
// Values for ICMP code as defined in RFC 792.
const (
ICMPv4PortUnreachable = 3
ICMPv4FragmentationNeeded = 4
)
// Type is the ICMP type field.
func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
// SetType sets the ICMP type field.
func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
func (b ICMPv4) Code() byte { return b[1] }
// SetCode sets the ICMP code field.
func (b ICMPv4) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[2:])
}
// SetChecksum sets the ICMP checksum field.
func (b ICMPv4) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[2:], checksum)
}
// SourcePort implements Transport.SourcePort.
func (ICMPv4) SourcePort() uint16 {
return 0
}
// DestinationPort implements Transport.DestinationPort.
func (ICMPv4) DestinationPort() uint16 {
return 0
}
// SetSourcePort implements Transport.SetSourcePort.
func (ICMPv4) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
func (ICMPv4) SetDestinationPort(uint16) {
}
// Payload implements Transport.Payload.
func (b ICMPv4) Payload() []byte {
return b[ICMPv4MinimumSize:]
}

View File

@@ -0,0 +1,121 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
// ICMPv6 represents an ICMPv6 header stored in a byte array.
type ICMPv6 []byte
const (
// ICMPv6MinimumSize is the minimum size of a valid ICMP packet.
ICMPv6MinimumSize = 4
// ICMPv6ProtocolNumber is the ICMP transport protocol number.
ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58
// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
// neighbor solicitation packet.
ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16
// ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
ICMPv6NeighborAdvertSize = 32
// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
ICMPv6EchoMinimumSize = 8
// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
// destination unreachable packet.
ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4
// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
// packet-too-big packet.
ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
)
// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
type ICMPv6Type byte
// Typical values of ICMPv6Type defined in RFC 4443.
const (
ICMPv6DstUnreachable ICMPv6Type = 1
ICMPv6PacketTooBig ICMPv6Type = 2
ICMPv6TimeExceeded ICMPv6Type = 3
ICMPv6ParamProblem ICMPv6Type = 4
ICMPv6EchoRequest ICMPv6Type = 128
ICMPv6EchoReply ICMPv6Type = 129
// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.
ICMPv6RouterSolicit ICMPv6Type = 133
ICMPv6RouterAdvert ICMPv6Type = 134
ICMPv6NeighborSolicit ICMPv6Type = 135
ICMPv6NeighborAdvert ICMPv6Type = 136
ICMPv6RedirectMsg ICMPv6Type = 137
)
// Values for ICMP code as defined in RFC 4443.
const (
ICMPv6PortUnreachable = 4
)
// Type is the ICMP type field.
func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
// SetType sets the ICMP type field.
func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
// Code is the ICMP code field. Its meaning depends on the value of Type.
func (b ICMPv6) Code() byte { return b[1] }
// SetCode sets the ICMP code field.
func (b ICMPv6) SetCode(c byte) { b[1] = c }
// Checksum is the ICMP checksum field.
func (b ICMPv6) Checksum() uint16 {
return binary.BigEndian.Uint16(b[2:])
}
// SetChecksum calculates and sets the ICMP checksum field.
func (b ICMPv6) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[2:], checksum)
}
// SourcePort implements Transport.SourcePort.
func (ICMPv6) SourcePort() uint16 {
return 0
}
// DestinationPort implements Transport.DestinationPort.
func (ICMPv6) DestinationPort() uint16 {
return 0
}
// SetSourcePort implements Transport.SetSourcePort.
func (ICMPv6) SetSourcePort(uint16) {
}
// SetDestinationPort implements Transport.SetDestinationPort.
func (ICMPv6) SetDestinationPort(uint16) {
}
// Payload implements Transport.Payload.
func (b ICMPv6) Payload() []byte {
return b[ICMPv6MinimumSize:]
}

View File

@@ -0,0 +1,92 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"tcpip/netstack/tcpip"
)
const (
// MaxIPPacketSize is the maximum supported IP packet size, excluding
// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
// m is the minimum IPv6 header size; we leave room for some potential
// IP options.
MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
// SourcePort returns the value of the "source port" field.
SourcePort() uint16
// Destination returns the value of the "destination port" field.
DestinationPort() uint16
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourcePort sets the value of the "source port" field.
SetSourcePort(uint16)
// SetDestinationPort sets the value of the "destination port" field.
SetDestinationPort(uint16)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// Payload returns the data carried in the transport buffer.
Payload() []byte
}
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
// SourceAddress returns the value of the "source address" field.
SourceAddress() tcpip.Address
// DestinationAddress returns the value of the "destination address"
// field.
DestinationAddress() tcpip.Address
// Checksum returns the value of the "checksum" field.
Checksum() uint16
// SetSourceAddress sets the value of the "source address" field.
SetSourceAddress(tcpip.Address)
// SetDestinationAddress sets the value of the "destination address"
// field.
SetDestinationAddress(tcpip.Address)
// SetChecksum sets the value of the "checksum" field.
SetChecksum(uint16)
// TransportProtocol returns the number of the transport protocol
// stored in the payload.
TransportProtocol() tcpip.TransportProtocolNumber
// Payload returns a byte slice containing the payload of the network
// packet.
Payload() []byte
// TOS returns the values of the "type of service" and "flow label" fields.
TOS() (uint8, uint32)
// SetTOS sets the values of the "type of service" and "flow label" fields.
SetTOS(t uint8, l uint32)
}

View File

@@ -0,0 +1,289 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
const (
versIHL = 0
tos = 1
totalLen = 2
id = 4
flagsFO = 6
ttl = 8
protocol = 9
checksum = 10
srcAddr = 12
dstAddr = 16
)
// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
// 表示IPv4头部信息的结构体
type IPv4Fields struct {
// IHL is the "internet header length" field of an IPv4 packet.
// 头部长度
IHL uint8
// TOS is the "type of service" field of an IPv4 packet.
// 服务区分的表示
TOS uint8
// TotalLength is the "total length" field of an IPv4 packet.
// 数据报文总长
TotalLength uint16
// ID is the "identification" field of an IPv4 packet.
// 标识符
ID uint16
// Flags is the "flags" field of an IPv4 packet.
// 标签
Flags uint8
// FragmentOffset is the "fragment offset" field of an IPv4 packet.
// 分片偏移
FragmentOffset uint16
// TTL is the "time to live" field of an IPv4 packet.
// 存活时间
TTL uint8
// Protocol is the "protocol" field of an IPv4 packet.
// 表示的传输层协议
Protocol uint8
// Checksum is the "checksum" field of an IPv4 packet.
// 首部校验和
Checksum uint16
// SrcAddr is the "source ip address" of an IPv4 packet.
// 源IP地址
SrcAddr tcpip.Address
// DstAddr is the "destination ip address" of an IPv4 packet.
// 目的IP地址
DstAddr tcpip.Address
}
// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other methods.
type IPv4 []byte
const (
// IPv4MinimumSize is the minimum size of a valid IPv4 packet.
IPv4MinimumSize = 20
// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
// that there are only 4 bits to represents the header length in 32-bit
// units, the header cannot exceed 15*4 = 60 bytes.
IPv4MaximumHeaderSize = 60
// IPv4AddressSize is the size, in bytes, of an IPv4 address.
IPv4AddressSize = 4
// IPv4ProtocolNumber is IPv4's network protocol number.
IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800
// IPv4Version is the version of the ipv4 protocol.
IPv4Version = 4
// IPv4Broadcast is the broadcast address of the IPv4 procotol.
IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff"
// IPv4Any is the non-routable IPv4 "any" meta address.
IPv4Any tcpip.Address = "\x00\x00\x00\x00"
)
// Flags that may be set in an IPv4 packet.
const (
IPv4FlagMoreFragments = 1 << iota
IPv4FlagDontFragment
)
// IPVersion returns the version of IP used in the given packet. It returns -1
// if the packet is not large enough to contain the version field.
func IPVersion(b []byte) int {
// Length must be at least offset+length of version field.
if len(b) < versIHL+1 {
return -1
}
return int(b[versIHL] >> 4)
}
// HeaderLength returns the value of the "header length" field of the ipv4
// header.
func (b IPv4) HeaderLength() uint8 {
return (b[versIHL] & 0xf) * 4
}
// ID returns the value of the identifier field of the ipv4 header.
func (b IPv4) ID() uint16 {
return binary.BigEndian.Uint16(b[id:])
}
// Protocol returns the value of the protocol field of the ipv4 header.
func (b IPv4) Protocol() uint8 {
return b[protocol]
}
// Flags returns the "flags" field of the ipv4 header.
func (b IPv4) Flags() uint8 {
return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}
// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
return b[ttl]
}
// FragmentOffset returns the "fragment offset" field of the ipv4 header.
func (b IPv4) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}
// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
return binary.BigEndian.Uint16(b[totalLen:])
}
// Checksum returns the checksum field of the ipv4 header.
func (b IPv4) Checksum() uint16 {
return binary.BigEndian.Uint16(b[checksum:])
}
// SourceAddress returns the "source address" field of the ipv4 header.
func (b IPv4) SourceAddress() tcpip.Address {
return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv4
// header.
func (b IPv4) DestinationAddress() tcpip.Address {
return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.Protocol())
}
// Payload implements Network.Payload.
func (b IPv4) Payload() []byte {
return b[b.HeaderLength():][:b.PayloadLength()]
}
// PayloadLength returns the length of the payload portion of the ipv4 packet.
func (b IPv4) PayloadLength() uint16 {
return b.TotalLength() - uint16(b.HeaderLength())
}
// TOS returns the "type of service" field of the ipv4 header.
func (b IPv4) TOS() (uint8, uint32) {
return b[tos], 0
}
// SetTOS sets the "type of service" field of the ipv4 header.
func (b IPv4) SetTOS(v uint8, _ uint32) {
b[tos] = v
}
// SetTotalLength sets the "total length" field of the ipv4 header.
func (b IPv4) SetTotalLength(totalLength uint16) {
binary.BigEndian.PutUint16(b[totalLen:], totalLength)
}
// SetChecksum sets the checksum field of the ipv4 header.
func (b IPv4) SetChecksum(v uint16) {
binary.BigEndian.PutUint16(b[checksum:], v)
}
// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
// ipv4 header.
func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
v := (uint16(flags) << 13) | (offset >> 3)
binary.BigEndian.PutUint16(b[flagsFO:], v)
}
// SetSourceAddress sets the "source address" field of the ipv4 header.
func (b IPv4) SetSourceAddress(addr tcpip.Address) {
copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv4
// header.
func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
}
// CalculateChecksum calculates the checksum of the ipv4 header.
func (b IPv4) CalculateChecksum() uint16 {
return Checksum(b[:b.HeaderLength()], 0)
}
// Encode encodes all the fields of the ipv4 header.
func (b IPv4) Encode(i *IPv4Fields) {
b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
b[tos] = i.TOS
b.SetTotalLength(i.TotalLength)
binary.BigEndian.PutUint16(b[id:], i.ID)
b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
b[ttl] = i.TTL
b[protocol] = i.Protocol
b.SetChecksum(i.Checksum)
copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
}
// EncodePartial updates the total length and checksum fields of ipv4 header,
// taking in the partial checksum, which is the checksum of the header without
// the total length and checksum fields. It is useful in cases when similar
// packets are produced.
func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
b.SetTotalLength(totalLength)
checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
b.SetChecksum(^checksum)
}
// IsValid performs basic validation on the packet.
func (b IPv4) IsValid(pktSize int) bool {
if len(b) < IPv4MinimumSize {
return false
}
hlen := int(b.HeaderLength())
tlen := int(b.TotalLength())
if hlen > tlen || tlen > pktSize {
return false
}
return true
}
// IsV4MulticastAddress determines if the provided address is an IPv4 multicast
// address (range 224.0.0.0 to 239.255.255.255). The four most significant bits
// will be 1110 = 0xe0.
func IsV4MulticastAddress(addr tcpip.Address) bool {
if len(addr) != IPv4AddressSize {
return false
}
return (addr[0] & 0xf0) == 0xe0
}

View File

@@ -0,0 +1,236 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"strings"
"tcpip/netstack/tcpip"
)
const (
versTCFL = 0
payloadLen = 4
nextHdr = 6
hopLimit = 7
v6SrcAddr = 8
v6DstAddr = 24
)
// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv6Fields struct {
// TrafficClass is the "traffic class" field of an IPv6 packet.
TrafficClass uint8
// FlowLabel is the "flow label" field of an IPv6 packet.
FlowLabel uint32
// PayloadLength is the "payload length" field of an IPv6 packet.
PayloadLength uint16
// NextHeader is the "next header" field of an IPv6 packet.
NextHeader uint8
// HopLimit is the "hop limit" field of an IPv6 packet.
HopLimit uint8
// SrcAddr is the "source ip address" of an IPv6 packet.
SrcAddr tcpip.Address
// DstAddr is the "destination ip address" of an IPv6 packet.
DstAddr tcpip.Address
}
// IPv6 represents an ipv6 header stored in a byte array.
// Most of the methods of IPv6 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6 before using other methods.
type IPv6 []byte
const (
// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
IPv6MinimumSize = 40
// IPv6AddressSize is the size, in bytes, of an IPv6 address.
IPv6AddressSize = 16
// IPv6ProtocolNumber is IPv6's network protocol number.
IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd
// IPv6Version is the version of the ipv6 protocol.
IPv6Version = 6
// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
// section 5.
IPv6MinimumMTU = 1280
)
// PayloadLength returns the value of the "payload length" field of the ipv6
// header.
func (b IPv6) PayloadLength() uint16 {
return binary.BigEndian.Uint16(b[payloadLen:])
}
// HopLimit returns the value of the "hop limit" field of the ipv6 header.
func (b IPv6) HopLimit() uint8 {
return b[hopLimit]
}
// NextHeader returns the value of the "next header" field of the ipv6 header.
func (b IPv6) NextHeader() uint8 {
return b[nextHdr]
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.NextHeader())
}
// Payload implements Network.Payload.
func (b IPv6) Payload() []byte {
return b[IPv6MinimumSize:][:b.PayloadLength()]
}
// SourceAddress returns the "source address" field of the ipv6 header.
func (b IPv6) SourceAddress() tcpip.Address {
return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
}
// DestinationAddress returns the "destination address" field of the ipv6
// header.
func (b IPv6) DestinationAddress() tcpip.Address {
return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
}
// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
// checksum, it just returns 0.
func (IPv6) Checksum() uint16 {
return 0
}
// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
func (b IPv6) TOS() (uint8, uint32) {
v := binary.BigEndian.Uint32(b[versTCFL:])
return uint8(v >> 20), v & 0xfffff
}
// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
func (b IPv6) SetTOS(t uint8, l uint32) {
vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
binary.BigEndian.PutUint32(b[versTCFL:], vtf)
}
// SetPayloadLength sets the "payload length" field of the ipv6 header.
func (b IPv6) SetPayloadLength(payloadLength uint16) {
binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
}
// SetSourceAddress sets the "source address" field of the ipv6 header.
func (b IPv6) SetSourceAddress(addr tcpip.Address) {
copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
}
// SetDestinationAddress sets the "destination address" field of the ipv6
// header.
func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
}
// SetNextHeader sets the value of the "next header" field of the ipv6 header.
func (b IPv6) SetNextHeader(v uint8) {
b[nextHdr] = v
}
// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
// checksum, it is empty.
func (IPv6) SetChecksum(uint16) {
}
// Encode encodes all the fields of the ipv6 header.
func (b IPv6) Encode(i *IPv6Fields) {
b.SetTOS(i.TrafficClass, i.FlowLabel)
b.SetPayloadLength(i.PayloadLength)
b[nextHdr] = i.NextHeader
b[hopLimit] = i.HopLimit
copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
}
// IsValid performs basic validation on the packet.
func (b IPv6) IsValid(pktSize int) bool {
if len(b) < IPv6MinimumSize {
return false
}
dlen := int(b.PayloadLength())
if dlen > pktSize-IPv6MinimumSize {
return false
}
return true
}
// IsV4MappedAddress determines if the provided address is an IPv4 mapped
// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
func IsV4MappedAddress(addr tcpip.Address) bool {
if len(addr) != IPv6AddressSize {
return false
}
return strings.HasPrefix(string(addr), "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff")
}
// IsV6MulticastAddress determines if the provided address is an IPv6
// multicast address (anything starting with FF).
func IsV6MulticastAddress(addr tcpip.Address) bool {
if len(addr) != IPv6AddressSize {
return false
}
return addr[0] == 0xff
}
// SolicitedNodeAddr computes the solicited-node multicast address. This is
// used for NDP. Described in RFC 4291. The argument must be a full-length IPv6
// address.
func SolicitedNodeAddr(addr tcpip.Address) tcpip.Address {
const solicitedNodeMulticastPrefix = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\xff"
return solicitedNodeMulticastPrefix + addr[len(addr)-3:]
}
// LinkLocalAddr computes the default IPv6 link-local address from a link-layer
// (MAC) address.
func LinkLocalAddr(linkAddr tcpip.LinkAddress) tcpip.Address {
// Convert a 48-bit MAC to an EUI-64 and then prepend the link-local
// header, FE80::.
//
// The conversion is very nearly:
// aa:bb:cc:dd:ee:ff => FE80::Aabb:ccFF:FEdd:eeff
// Note the capital A. The conversion aa->Aa involves a bit flip.
lladdrb := [16]byte{
0: 0xFE,
1: 0x80,
8: linkAddr[0] ^ 2,
9: linkAddr[1],
10: linkAddr[2],
11: 0xFF,
12: 0xFE,
13: linkAddr[3],
14: linkAddr[4],
15: linkAddr[5],
}
return tcpip.Address(lladdrb[:])
}

View File

@@ -0,0 +1,146 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
const (
nextHdrFrag = 0
fragOff = 2
more = 3
idV6 = 4
)
// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv6FragmentFields struct {
// NextHeader is the "next header" field of an IPv6 fragment.
NextHeader uint8
// FragmentOffset is the "fragment offset" field of an IPv6 fragment.
FragmentOffset uint16
// M is the "more" field of an IPv6 fragment.
M bool
// Identification is the "identification" field of an IPv6 fragment.
Identification uint32
}
// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
type IPv6Fragment []byte
const (
// IPv6FragmentHeader header is the number used to specify that the next
// header is a fragment header, per RFC 2460.
IPv6FragmentHeader = 44
// IPv6FragmentHeaderSize is the size of the fragment header.
IPv6FragmentHeaderSize = 8
)
// Encode encodes all the fields of the ipv6 fragment.
func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
b[nextHdrFrag] = i.NextHeader
binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
if i.M {
b[more] |= 1
}
binary.BigEndian.PutUint32(b[idV6:], i.Identification)
}
// IsValid performs basic validation on the fragment header.
func (b IPv6Fragment) IsValid() bool {
return len(b) >= IPv6FragmentHeaderSize
}
// NextHeader returns the value of the "next header" field of the ipv6 fragment.
func (b IPv6Fragment) NextHeader() uint8 {
return b[nextHdrFrag]
}
// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
func (b IPv6Fragment) FragmentOffset() uint16 {
return binary.BigEndian.Uint16(b[fragOff:]) >> 3
}
// More returns the "more" field of the ipv6 fragment.
func (b IPv6Fragment) More() bool {
return b[more]&1 > 0
}
// Payload implements Network.Payload.
func (b IPv6Fragment) Payload() []byte {
return b[IPv6FragmentHeaderSize:]
}
// ID returns the value of the identifier field of the ipv6 fragment.
func (b IPv6Fragment) ID() uint32 {
return binary.BigEndian.Uint32(b[idV6:])
}
// TransportProtocol implements Network.TransportProtocol.
func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
return tcpip.TransportProtocolNumber(b.NextHeader())
}
// The functions below have been added only to satisfy the Network interface.
// Checksum is not supported by IPv6Fragment.
func (b IPv6Fragment) Checksum() uint16 {
panic("not supported")
}
// SourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SourceAddress() tcpip.Address {
panic("not supported")
}
// DestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) DestinationAddress() tcpip.Address {
panic("not supported")
}
// SetSourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
panic("not supported")
}
// SetDestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
panic("not supported")
}
// SetChecksum is not supported by IPv6Fragment.
func (b IPv6Fragment) SetChecksum(uint16) {
panic("not supported")
}
// TOS is not supported by IPv6Fragment.
func (b IPv6Fragment) TOS() (uint8, uint32) {
panic("not supported")
}
// SetTOS is not supported by IPv6Fragment.
func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
panic("not supported")
}

View File

@@ -0,0 +1,67 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header_test
import (
"testing"
"tcpip/netstack/tcpip/header"
)
func TestIPv4(t *testing.T) {
b := header.IPv4(make([]byte, header.IPv4MinimumSize))
b.Encode(&header.IPv4Fields{})
const want = header.IPv4Version
if v := header.IPVersion(b); v != want {
t.Fatalf("Bad version, want %v, got %v", want, v)
}
}
func TestIPv6(t *testing.T) {
b := header.IPv6(make([]byte, header.IPv6MinimumSize))
b.Encode(&header.IPv6Fields{})
const want = header.IPv6Version
if v := header.IPVersion(b); v != want {
t.Fatalf("Bad version, want %v, got %v", want, v)
}
}
func TestOtherVersion(t *testing.T) {
const want = header.IPv4Version + header.IPv6Version
b := make([]byte, 1)
b[0] = want << 4
if v := header.IPVersion(b); v != want {
t.Fatalf("Bad version, want %v, got %v", want, v)
}
}
func TestTooShort(t *testing.T) {
b := make([]byte, 1)
b[0] = (header.IPv4Version + header.IPv6Version) << 4
// Get the version of a zero-length slice.
const want = -1
if v := header.IPVersion(b[:0]); v != want {
t.Fatalf("Bad version, want %v, got %v", want, v)
}
// Get the version of a nil slice.
if v := header.IPVersion(nil); v != want {
t.Fatalf("Bad version, want %v, got %v", want, v)
}
}

View File

@@ -0,0 +1,536 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/seqnum"
)
const (
srcPort = 0
dstPort = 2
seqNum = 4
ackNum = 8
dataOffset = 12
tcpFlags = 13
winSize = 14
tcpChecksum = 16
urgentPtr = 18
)
const (
// MaxWndScale is maximum allowed window scaling, as described in
// RFC 1323, section 2.3, page 11.
MaxWndScale = 14
// TCPMaxSACKBlocks is the maximum number of SACK blocks that can
// be encoded in a TCP option field.
TCPMaxSACKBlocks = 4
)
// Flags that may be set in a TCP segment.
const (
TCPFlagFin = 1 << iota
TCPFlagSyn
TCPFlagRst
TCPFlagPsh
TCPFlagAck
TCPFlagUrg
)
// Options that may be present in a TCP segment.
const (
TCPOptionEOL = 0
TCPOptionNOP = 1
TCPOptionMSS = 2
TCPOptionWS = 3
TCPOptionTS = 8
TCPOptionSACKPermitted = 4
TCPOptionSACK = 5
)
// TCPFields contains the fields of a TCP packet. It is used to describe the
// fields of a packet that needs to be encoded.
// tcp首部字段
type TCPFields struct {
// SrcPort is the "source port" field of a TCP packet.
SrcPort uint16
// DstPort is the "destination port" field of a TCP packet.
DstPort uint16
// SeqNum is the "sequence number" field of a TCP packet.
SeqNum uint32
// AckNum is the "acknowledgement number" field of a TCP packet.
AckNum uint32
// DataOffset is the "data offset" field of a TCP packet.
DataOffset uint8
// Flags is the "flags" field of a TCP packet.
Flags uint8
// WindowSize is the "window size" field of a TCP packet.
WindowSize uint16
// Checksum is the "checksum" field of a TCP packet.
Checksum uint16
// UrgentPointer is the "urgent pointer" field of a TCP packet.
UrgentPointer uint16
}
// TCPSynOptions is used to return the parsed TCP Options in a syn
// segment.
// syn 报文的选项
type TCPSynOptions struct {
// MSS is the maximum segment size provided by the peer in the SYN.
MSS uint16
// WS is the window scale option provided by the peer in the SYN.
//
// Set to -1 if no window scale option was provided.
WS int
// TS is true if the timestamp option was provided in the syn/syn-ack.
TS bool
// TSVal is the value of the TSVal field in the timestamp option.
TSVal uint32
// TSEcr is the value of the TSEcr field in the timestamp option.
TSEcr uint32
// SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
SACKPermitted bool
}
// SACKBlock represents a single contiguous SACK block.
//
// +stateify savable
// 表示 sack 块的结构体
type SACKBlock struct {
// Start indicates the lowest sequence number in the block.
Start seqnum.Value
// End indicates the sequence number immediately following the last
// sequence number of this block.
End seqnum.Value
}
// TCPOptions are used to parse and cache the TCP segment options for a non
// syn/syn-ack segment.
//
// +stateify savable
// tcp选项结构这个结构不表示 syn/syn-ack 报文
type TCPOptions struct {
// TS is true if the TimeStamp option is enabled.
TS bool
// TSVal is the value in the TSVal field of the segment.
TSVal uint32
// TSEcr is the value in the TSEcr field of the segment.
TSEcr uint32
// SACKBlocks are the SACK blocks specified in the segment.
SACKBlocks []SACKBlock
}
// TCP represents a TCP header stored in a byte array.
type TCP []byte
const (
// TCPMinimumSize is the minimum size of a valid TCP packet.
TCPMinimumSize = 20
// TCPProtocolNumber is TCP's transport protocol number.
TCPProtocolNumber tcpip.TransportProtocolNumber = 6
)
// SourcePort returns the "source port" field of the tcp header.
func (b TCP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[srcPort:])
}
// DestinationPort returns the "destination port" field of the tcp header.
func (b TCP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[dstPort:])
}
// SequenceNumber returns the "sequence number" field of the tcp header.
func (b TCP) SequenceNumber() uint32 {
return binary.BigEndian.Uint32(b[seqNum:])
}
// AckNumber returns the "ack number" field of the tcp header.
func (b TCP) AckNumber() uint32 {
return binary.BigEndian.Uint32(b[ackNum:])
}
// DataOffset returns the "data offset" field of the tcp header.
func (b TCP) DataOffset() uint8 {
return (b[dataOffset] >> 4) * 4
}
// Payload returns the data in the tcp packet.
func (b TCP) Payload() []byte {
return b[b.DataOffset():]
}
// Flags returns the flags field of the tcp header.
func (b TCP) Flags() uint8 {
return b[tcpFlags]
}
// WindowSize returns the "window size" field of the tcp header.
func (b TCP) WindowSize() uint16 {
return binary.BigEndian.Uint16(b[winSize:])
}
// Checksum returns the "checksum" field of the tcp header.
func (b TCP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[tcpChecksum:])
}
// SetSourcePort sets the "source port" field of the tcp header.
func (b TCP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[srcPort:], port)
}
// SetDestinationPort sets the "destination port" field of the tcp header.
func (b TCP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[dstPort:], port)
}
// SetChecksum sets the checksum field of the tcp header.
func (b TCP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[tcpChecksum:], checksum)
}
// CalculateChecksum calculates the checksum of the tcp segment given
// the totalLen and partialChecksum(descriptions below)
// totalLen is the total length of the segment
// partialChecksum is the checksum of the network-layer pseudo-header
// (excluding the total length) and the checksum of the segment data.
func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
// Add the length portion of the checksum to the pseudo-checksum.
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
checksum := Checksum(tmp, partialChecksum)
// Calculate the rest of the checksum.
return Checksum(b[:b.DataOffset()], checksum)
}
// Options returns a slice that holds the unparsed TCP options in the segment.
func (b TCP) Options() []byte {
return b[TCPMinimumSize:b.DataOffset()]
}
// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
// option values in the TCP segment. NOTE: Invoking this function repeatedly is
// expensive as it reparses the options on each invocation.
func (b TCP) ParsedOptions() TCPOptions {
return ParseTCPOptions(b.Options())
}
func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
binary.BigEndian.PutUint32(b[seqNum:], seq)
binary.BigEndian.PutUint32(b[ackNum:], ack)
b[tcpFlags] = flags
binary.BigEndian.PutUint16(b[winSize:], rcvwnd)
}
// Encode encodes all the fields of the tcp header.
func (b TCP) Encode(t *TCPFields) {
b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
binary.BigEndian.PutUint16(b[srcPort:], t.SrcPort)
binary.BigEndian.PutUint16(b[dstPort:], t.DstPort)
b[dataOffset] = (t.DataOffset / 4) << 4
binary.BigEndian.PutUint16(b[tcpChecksum:], t.Checksum)
binary.BigEndian.PutUint16(b[urgentPtr:], t.UrgentPointer)
}
// EncodePartial updates a subset of the fields of the tcp header. It is useful
// in cases when similar segments are produced.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
// Add the total length and "flags" field contributions to the checksum.
// We don't use the flags field directly from the header because it's a
// one-byte field with an odd offset, so it would be accounted for
// incorrectly by the Checksum routine.
tmp := make([]byte, 4)
binary.BigEndian.PutUint16(tmp, length)
binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
checksum := Checksum(tmp, partialChecksum)
// Encode the passed-in fields.
b.encodeSubset(seqnum, acknum, flags, rcvwnd)
// Add the contributions of the passed-in fields to the checksum.
checksum = Checksum(b[seqNum:seqNum+8], checksum)
checksum = Checksum(b[winSize:winSize+2], checksum)
// Encode the checksum.
b.SetChecksum(^checksum)
}
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP Header.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
limit := len(opts)
synOpts := TCPSynOptions{
// Per RFC 1122, page 85: "If an MSS option is not received at
// connection setup, TCP MUST assume a default send MSS of 536."
MSS: 536,
// If no window scale option is specified, WS in options is
// returned as -1; this is because the absence of the option
// indicates that the we cannot use window scaling on the
// receive end either.
WS: -1,
}
for i := 0; i < limit; {
switch opts[i] {
case TCPOptionEOL:
i = limit
case TCPOptionNOP:
i++
case TCPOptionMSS:
if i+4 > limit || opts[i+1] != 4 {
return synOpts
}
mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
if mss == 0 {
return synOpts
}
synOpts.MSS = mss
i += 4
case TCPOptionWS:
if i+3 > limit || opts[i+1] != 3 {
return synOpts
}
ws := int(opts[i+2])
if ws > MaxWndScale {
ws = MaxWndScale
}
synOpts.WS = ws
i += 3
case TCPOptionTS:
if i+10 > limit || opts[i+1] != 10 {
return synOpts
}
synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
if isAck {
// If the segment is a SYN-ACK then store the Timestamp Echo Reply
// in the segment.
synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
}
synOpts.TS = true
i += 10
case TCPOptionSACKPermitted:
if i+2 > limit || opts[i+1] != 2 {
return synOpts
}
synOpts.SACKPermitted = true
i += 2
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
return synOpts
}
l := int(opts[i+1])
// If the length is incorrect or if l+i overflows the
// total options length then return false.
if l < 2 || i+l > limit {
return synOpts
}
i += l
}
}
return synOpts
}
// ParseTCPOptions extracts and stores all known options in the provided byte
// slice in a TCPOptions structure.
func ParseTCPOptions(b []byte) TCPOptions {
opts := TCPOptions{}
limit := len(b)
for i := 0; i < limit; {
switch b[i] {
case TCPOptionEOL:
i = limit
case TCPOptionNOP:
i++
case TCPOptionTS:
if i+10 > limit || (b[i+1] != 10) {
return opts
}
opts.TS = true
opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
i += 10
case TCPOptionSACK:
if i+2 > limit {
// Malformed SACK block, just return and stop parsing.
return opts
}
sackOptionLen := int(b[i+1])
if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
// Malformed SACK block, just return and stop parsing.
return opts
}
numBlocks := (sackOptionLen - 2) / 8
opts.SACKBlocks = []SACKBlock{}
for j := 0; j < numBlocks; j++ {
start := binary.BigEndian.Uint32(b[i+2+j*8:])
end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
Start: seqnum.Value(start),
End: seqnum.Value(end),
})
}
i += sackOptionLen
default:
// We don't recognize this option, just skip over it.
if i+2 > limit {
return opts
}
l := int(b[i+1])
// If the length is incorrect or if l+i overflows the
// total options length then return false.
if l < 2 || i+l > limit {
return opts
}
i += l
}
}
return opts
}
// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
// the supplied buffer. If the provided buffer is not large enough then it just
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeMSSOption(mss uint32, b []byte) int {
// mssOptionSize is the number of bytes in a valid MSS option.
const mssOptionSize = 4
if len(b) < mssOptionSize {
return 0
}
b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
return mssOptionSize
}
// EncodeWSOption encodes the WS TCP option with the WS value in the
// provided buffer. If the provided buffer is not large enough then it just
// returns without encoding anything. It returns the number of bytes written to
// the provided buffer.
func EncodeWSOption(ws int, b []byte) int {
if len(b) < 3 {
return 0
}
b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
return int(b[1])
}
// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
// option into the provided buffer. If the buffer is smaller than expected it
// just returns without encoding anything. It returns the number of bytes
// written to the provided buffer.
func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
if len(b) < 10 {
return 0
}
b[0], b[1] = TCPOptionTS, 10
binary.BigEndian.PutUint32(b[2:], tsVal)
binary.BigEndian.PutUint32(b[6:], tsEcr)
return int(b[1])
}
// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
// buffer. If the buffer is smaller than required it just returns without
// encoding anything. It returns the number of bytes written to the provided
// buffer.
func EncodeSACKPermittedOption(b []byte) int {
if len(b) < 2 {
return 0
}
b[0], b[1] = TCPOptionSACKPermitted, 2
return int(b[1])
}
// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
// in the provided slice. It tries to fit in as many blocks as possible based on
// number of bytes available in the provided buffer. It returns the number of
// bytes written to the provided buffer.
func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
if len(sackBlocks) == 0 {
return 0
}
l := len(sackBlocks)
if l > TCPMaxSACKBlocks {
l = TCPMaxSACKBlocks
}
if ll := (len(b) - 2) / 8; ll < l {
l = ll
}
if l == 0 {
// There is not enough space in the provided buffer to add
// any SACK blocks.
return 0
}
b[0] = TCPOptionSACK
b[1] = byte(l*8 + 2)
for i := 0; i < l; i++ {
binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
}
return int(b[1])
}
// EncodeNOP adds an explicit NOP to the option list.
func EncodeNOP(b []byte) int {
if len(b) == 0 {
return 0
}
b[0] = TCPOptionNOP
return 1
}
// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
// the option buffer. It adds padding bytes after the offset specified and
// returns the number of padding bytes added. The passed in options slice
// must have space for the padding bytes.
func AddTCPOptionPadding(options []byte, offset int) int {
paddingToAdd := -offset & 3
// Now add any padding bytes that might be required to quad align the
// options.
for i := offset; i < offset+paddingToAdd; i++ {
options[i] = TCPOptionNOP
}
return paddingToAdd
}

View File

@@ -0,0 +1,148 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header_test
import (
"reflect"
"testing"
"tcpip/netstack/tcpip/header"
)
func TestEncodeSACKBlocks(t *testing.T) {
testCases := []struct {
sackBlocks []header.SACKBlock
want []header.SACKBlock
bufSize int
}{
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
40,
},
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
30,
},
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
[]header.SACKBlock{{10, 20}, {22, 30}},
20,
},
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
[]header.SACKBlock{{10, 20}},
10,
},
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
nil,
8,
},
{
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
60,
},
}
for _, tc := range testCases {
b := make([]byte, tc.bufSize)
t.Logf("testing: %v", tc)
header.EncodeSACKBlocks(tc.sackBlocks, b)
opts := header.ParseTCPOptions(b)
if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
}
}
}
func TestTCPParseOptions(t *testing.T) {
type tsOption struct {
tsVal uint32
tsEcr uint32
}
generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
l := 0
if tsOpt != nil {
l += 10
}
if len(sackBlocks) != 0 {
l += len(sackBlocks)*8 + 2
}
b := make([]byte, l)
offset := 0
if tsOpt != nil {
offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
}
header.EncodeSACKBlocks(sackBlocks, b[offset:])
return b
}
testCases := []struct {
b []byte
want header.TCPOptions
}{
// Trivial cases.
{nil, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
// Test timestamp parsing.
{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
// Test malformed timestamp option.
{[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
// Test SACKBlock parsing.
{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
{[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},
// Test malformed SACK option.
{[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},
// Test Timestamp + SACK block parsing.
{generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
{generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
{generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},
// Test valid timestamp + malformed SACK block parsing.
{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
{[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
{[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
{[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
}
for _, tc := range testCases {
if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
}
}
}

View File

@@ -0,0 +1,117 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package header
import (
"encoding/binary"
"tcpip/netstack/tcpip"
)
const (
udpSrcPort = 0
udpDstPort = 2
udpLength = 4
udpChecksum = 6
)
// UDPFields contains the fields of a UDP packet. It is used to describe the
// fields of a packet that needs to be encoded.
// udp 首部字段
type UDPFields struct {
// SrcPort is the "source port" field of a UDP packet.
SrcPort uint16
// DstPort is the "destination port" field of a UDP packet.
DstPort uint16
// Length is the "length" field of a UDP packet.
Length uint16
// Checksum is the "checksum" field of a UDP packet.
Checksum uint16
}
// UDP represents a UDP header stored in a byte array.
type UDP []byte
const (
// UDPMinimumSize is the minimum size of a valid UDP packet.
UDPMinimumSize = 8
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
// CalculateChecksum calculates the checksum of the udp packet, given the total
// length of the packet and the checksum of the network-layer pseudo-header
// (excluding the total length) and the checksum of the payload.
func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
// Add the length portion of the checksum to the pseudo-checksum.
tmp := make([]byte, 2)
binary.BigEndian.PutUint16(tmp, totalLen)
checksum := Checksum(tmp, partialChecksum)
// Calculate the rest of the checksum.
return Checksum(b[:UDPMinimumSize], checksum)
}
// Encode encodes all the fields of the udp header.
func (b UDP) Encode(u *UDPFields) {
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
}

View File

@@ -0,0 +1,125 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package channel provides the implemention of channel-based data-link layer
// endpoints. Such endpoints allow injection of inbound packets and store
// outbound packets in a channel.
package channel
import (
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/stack"
)
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
Header buffer.View
Payload buffer.View
Proto tcpip.NetworkProtocolNumber
}
// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
dispatcher stack.NetworkDispatcher
mtu uint32
linkAddr tcpip.LinkAddress
// C is where outbound packets are queued.
C chan PacketInfo
}
// New creates a new channel endpoint.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
e := &Endpoint{
C: make(chan PacketInfo, size),
mtu: mtu,
linkAddr: linkAddr,
}
return stack.RegisterLinkEndpoint(e), e
}
// Drain removes all outbound packets from the channel and counts them.
func (e *Endpoint) Drain() int {
c := 0
for {
select {
case <-e.C:
c++
default:
return c
}
}
}
// Inject injects an inbound packet.
func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
e.InjectLinkAddr(protocol, "", vv)
}
// InjectLinkAddr injects an inbound packet with a remote link address.
func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) {
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil))
}
// Attach saves the stack network-layer dispatcher for use later when packets
// are injected.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *Endpoint) IsAttached() bool {
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *Endpoint) MTU() uint32 {
return e.mtu
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
return 0
}
// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
return 0
}
// LinkAddress returns the link address of this endpoint.
func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
return e.linkAddr
}
// WritePacket stores outbound packets into the channel.
func (e *Endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
p := PacketInfo{
Header: hdr.View(),
Proto: protocol,
Payload: payload.ToView(),
}
select {
case e.C <- p:
default:
}
return nil
}

View File

@@ -0,0 +1,305 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
// Package fdbased provides the implemention of data-link layer endpoints
// backed by boundary-preserving file descriptors (e.g., TUN devices,
// seqpacket/datagram sockets).
//
// FD based endpoints can be used in the networking stack by calling New() to
// create a new endpoint, and then passing it as an argument to
// Stack.CreateNIC().
package fdbased
import (
"syscall"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/rawfile"
"tcpip/netstack/tcpip/stack"
)
// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
// 从NIC读取数据的多级缓存配置
var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
// 负责底层网卡的io读写以及数据分发
type endpoint struct {
// fd is the file descriptor used to send and receive packets.
// 发送和接收数据的文件描述符
fd int
// mtu (maximum transmission unit) is the maximum size of a packet.
// 单个帧的最大长度
mtu uint32
// hdrSize specifies the link-layer header size. If set to 0, no header
// is added/removed; otherwise an ethernet header is used.
// 以太网头部长度
hdrSize int
// addr is the address of the endpoint.
// 网卡地址
addr tcpip.LinkAddress
// caps holds the endpoint capabilities.
// 网卡的能力
caps stack.LinkEndpointCapabilities
// closed is a function to be called when the FD's peer (if any) closes
// its end of the communication pipe.
closed func(*tcpip.Error)
iovecs []syscall.Iovec
views []buffer.View
dispatcher stack.NetworkDispatcher
// handleLocal indicates whether packets destined to itself should be
// handled by the netstack internally (true) or be forwarded to the FD
// endpoint (false).
// handleLocal指示发往自身的数据包是由内部netstack处理true还是转发到FD端点false
// Resend packets back to netstack if destined to itself
// Add option to redirect packet back to netstack if it's destined to itself.
// This fixes the problem where connecting to the local NIC address would
// not work, e.g.:
// echo bar | nc -l -p 8080 &
// echo foo | nc 192.168.0.2 8080
handleLocal bool
}
// Options specify the details about the fd-based endpoint to be created.
// 创建fdbase端的一些选项参数
type Options struct {
FD int
MTU uint32
ClosedFunc func(*tcpip.Error)
Address tcpip.LinkAddress
ResolutionRequired bool
SaveRestore bool
ChecksumOffload bool
DisconnectOk bool
HandleLocal bool
TestLossPacket func(data []byte) bool
}
// New creates a new fd-based endpoint.
//
// Makes fd non-blocking, but does not take ownership of fd, which must remain
// open for the lifetime of the returned endpoint.
// 根据选项参数创建一个链路层的endpoint并返回该endpoint的id
func New(opts *Options) tcpip.LinkEndpointID {
syscall.SetNonblock(opts.FD, true)
caps := stack.LinkEndpointCapabilities(0)
if opts.ResolutionRequired {
caps |= stack.CapabilityResolutionRequired
}
if opts.ChecksumOffload {
caps |= stack.CapabilityChecksumOffload
}
if opts.SaveRestore {
caps |= stack.CapabilitySaveRestore
}
if opts.DisconnectOk {
caps |= stack.CapabilityDisconnectOk
}
e := &endpoint{
fd: opts.FD,
mtu: opts.MTU,
caps: caps,
closed: opts.ClosedFunc,
addr: opts.Address,
hdrSize: header.EthernetMinimumSize,
views: make([]buffer.View, len(BufConfig)),
iovecs: make([]syscall.Iovec, len(BufConfig)),
handleLocal: opts.HandleLocal,
}
// 全局注册链路层设备
return stack.RegisterLinkEndpoint(e)
}
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher.
// Attach 启动从文件描述符中读取数据包的goroutine并通过提供的分发函数来分发数据报。
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
// Link endpoints are not savable. When transportation endpoints are
// saved, they stop sending outgoing packets and all incoming packets
// are rejected.
// 链接端点不可靠。保存传输端点后,它们将停止发送传出数据包,并拒绝所有传入数据包。
go e.dispatchLoop()
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
// 判断是否Attach了
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
// 返回当前MTU值
func (e *endpoint) MTU() uint32 {
return e.mtu
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
// 返回当前网卡的能力
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.caps
}
// MaxHeaderLength returns the maximum size of the link-layer header.
// 返回当前以太网头部信息长度
func (e *endpoint) MaxHeaderLength() uint16 {
return uint16(e.hdrSize)
}
// LinkAddress returns the link address of this endpoint.
// 返回当前MAC地址
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.addr
}
// WritePacket writes outbound packets to the file descriptor. If it is not
// currently writable, the packet is dropped.
// 将上层的报文经过链路层封装,写入网卡中,如果写入失败则丢弃该报文
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView,
protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
// 如果目标地址就是设备本身自己,那么将报文重新返回给协议栈
if e.handleLocal && r.LocalAddress != "" && r.LocalAddress == r.RemoteAddress {
views := make([]buffer.View, 1, 1+len(payload.Views()))
views[0] = hdr.View()
views = append(views, payload.Views()...)
vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
e.dispatcher.DeliverNetworkPacket(e, r.RemoteLinkAddress, r.LocalLinkAddress, protocol, vv)
return nil
}
// 封装增加以太网头部
eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
ethHdr := &header.EthernetFields{
DstAddr: r.RemoteLinkAddress,
Type: protocol,
}
// Preserve the src address if it's set in the route.
// 如果路由信息中有配置源MAC地址那么使用该地址
// 如果没有,则使用本网卡的地址
if r.LocalLinkAddress != "" {
ethHdr.SrcAddr = r.LocalLinkAddress
} else {
ethHdr.SrcAddr = e.addr
}
eth.Encode(ethHdr)
// 写入网卡中
// log.Printf("write to nic %d bytes", hdr.UsedLength()+payload.Size())
if payload.Size() == 0 {
return rawfile.NonBlockingWrite(e.fd, hdr.View())
}
return rawfile.NonBlockingWrite2(e.fd, hdr.View(), payload.ToView())
}
func (e *endpoint) capViews(n int, buffers []int) int {
c := 0
for i, s := range buffers {
c += s
if c >= n {
e.views[i].CapLength(s - (c - n))
return i + 1
}
}
return len(buffers)
}
// 按bufConfig的长度分配内存大小
// 注意e.views和e.iovecs共用相同的内存块
func (e *endpoint) allocateViews(bufConfig []int) {
for i, v := range e.views {
if v != nil {
break
}
b := buffer.NewView(bufConfig[i])
e.views[i] = b
e.iovecs[i] = syscall.Iovec{
Base: &b[0],
Len: uint64(len(b)),
}
}
}
// dispatch reads one packet from the file descriptor and dispatches it.
// 从网卡中读取一个数据报,
func (e *endpoint) dispatch() (bool, *tcpip.Error) {
// 读取数据缓存的分配
e.allocateViews(BufConfig)
// 从网卡中读取数据
n, err := rawfile.BlockingReadv(e.fd, e.iovecs)
if err != nil {
return false, err
}
// 如果比头部长度还小,直接丢弃
if n <= e.hdrSize {
return false, nil
}
var (
p tcpip.NetworkProtocolNumber
remoteLinkAddr, localLinkAddr tcpip.LinkAddress
)
// 获取以太网头部信息
eth := header.Ethernet(e.views[0])
p = eth.Type()
remoteLinkAddr = eth.SourceAddress()
localLinkAddr = eth.DestinationAddress()
used := e.capViews(n, BufConfig)
vv := buffer.NewVectorisedView(n, e.views[:used])
// 将数据内容删除以太网头部信息,也就是将数据指针指向网络层的第一个字节
vv.TrimFront(e.hdrSize)
// 调用nic.DeliverNetworkPacket来分发网络层数据
// log.Printf("read from nic %d bytes", e.hdrSize+vv.Size())
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
// Prepare e.views for another packet: release used views.
for i := 0; i < used; i++ {
e.views[i] = nil
}
return true, nil
}
// dispatchLoop reads packets from the file descriptor in a loop and dispatches
// them to the network stack.
// 循环的从fd读取数据然后将数据包分发给协议栈
func (e *endpoint) dispatchLoop() *tcpip.Error {
for {
cont, err := e.dispatch()
if err != nil || !cont {
if e.closed != nil {
e.closed(err)
}
return err
}
}
}

View File

@@ -0,0 +1,341 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
package fdbased
import (
"bytes"
"fmt"
"math/rand"
"reflect"
"syscall"
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
const (
mtu = 1500
laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
proto = 10
)
type packetInfo struct {
raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber
contents buffer.View
}
type context struct {
t *testing.T
fds [2]int
ep stack.LinkEndpoint
ch chan packetInfo
done chan struct{}
}
func newContext(t *testing.T, opt *Options) *context {
fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
if err != nil {
t.Fatalf("Socketpair failed: %v", err)
}
done := make(chan struct{}, 1)
opt.ClosedFunc = func(*tcpip.Error) {
done <- struct{}{}
}
opt.FD = fds[1]
ep := stack.FindLinkEndpoint(New(opt)).(*endpoint)
c := &context{
t: t,
fds: fds,
ep: ep,
ch: make(chan packetInfo, 100),
done: done,
}
ep.Attach(c)
return c
}
func (c *context) cleanup() {
syscall.Close(c.fds[0])
<-c.done
syscall.Close(c.fds[1])
}
func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()}
}
func TestEthernetProperties(t *testing.T) {
c := newContext(t, &Options{MTU: mtu})
defer c.cleanup()
if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
}
if want, v := uint32(mtu), c.ep.MTU(); want != v {
t.Fatalf("MTU() = %v, want %v", v, want)
}
}
func TestAddress(t *testing.T) {
addrs := []tcpip.LinkAddress{"", "abc", "def"}
for _, a := range addrs {
t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
c := newContext(t, &Options{Address: a, MTU: mtu})
defer c.cleanup()
if want, v := a, c.ep.LinkAddress(); want != v {
t.Fatalf("LinkAddress() = %v, want %v", v, want)
}
})
}
}
func TestWritePacket(t *testing.T) {
lengths := []int{0, 100, 1000}
for _, plen := range lengths {
t.Run(fmt.Sprintf("PayloadLen=%v", plen), func(t *testing.T) {
c := newContext(t, &Options{Address: laddr, MTU: mtu})
defer c.cleanup()
r := &stack.Route{
RemoteLinkAddress: raddr,
}
// Build header.
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
b := hdr.Prepend(100)
for i := range b {
b[i] = uint8(rand.Intn(256))
}
// Build payload and write.
payload := make(buffer.View, plen)
for i := range payload {
payload[i] = uint8(rand.Intn(256))
}
want := append(hdr.View(), payload...)
if err := c.ep.WritePacket(r, hdr, payload.ToVectorisedView(), proto); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
// Read from fd, then compare with what we wrote.
b = make([]byte, mtu)
n, err := syscall.Read(c.fds[0], b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
b = b[:n]
h := header.Ethernet(b)
b = b[header.EthernetMinimumSize:]
if a := h.SourceAddress(); a != laddr {
t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
}
if a := h.DestinationAddress(); a != raddr {
t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
}
if et := h.Type(); et != proto {
t.Fatalf("Type() = %v, want %v", et, proto)
}
if len(b) != len(want) {
t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
}
if !bytes.Equal(b, want) {
t.Fatalf("Read returned %x, want %x", b, want)
}
})
}
}
func TestPreserveSrcAddress(t *testing.T) {
baddr := tcpip.LinkAddress("\xcc\xbb\xaa\x77\x88\x99")
c := newContext(t, &Options{Address: laddr, MTU: mtu})
defer c.cleanup()
// Set LocalLinkAddress in route to the value of the bridged address.
r := &stack.Route{
RemoteLinkAddress: raddr,
LocalLinkAddress: baddr,
}
// WritePacket panics given a prependable with anything less than
// the minimum size of the ethernet header.
hdr := buffer.NewPrependable(header.EthernetMinimumSize)
if err := c.ep.WritePacket(r, hdr, buffer.VectorisedView{}, proto); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
// Read from the FD, then compare with what we wrote.
b := make([]byte, mtu)
n, err := syscall.Read(c.fds[0], b)
if err != nil {
t.Fatalf("Read failed: %v", err)
}
b = b[:n]
h := header.Ethernet(b)
if a := h.SourceAddress(); a != baddr {
t.Fatalf("SourceAddress() = %v, want %v", a, baddr)
}
}
func TestDeliverPacket(t *testing.T) {
lengths := []int{100, 1000}
for _, plen := range lengths {
t.Run(fmt.Sprintf("PayloadLen=%v", plen), func(t *testing.T) {
c := newContext(t, &Options{Address: laddr, MTU: mtu})
defer c.cleanup()
// Build packet.
b := make([]byte, plen)
all := b
for i := range b {
b[i] = uint8(rand.Intn(256))
}
hdr := make(header.Ethernet, header.EthernetMinimumSize)
hdr.Encode(&header.EthernetFields{
SrcAddr: raddr,
DstAddr: laddr,
Type: proto,
})
all = append(hdr, b...)
// Write packet via the file descriptor.
if _, err := syscall.Write(c.fds[0], all); err != nil {
t.Fatalf("Write failed: %v", err)
}
// Receive packet through the endpoint.
select {
case pi := <-c.ch:
want := packetInfo{
raddr: raddr,
proto: proto,
contents: b,
}
if !reflect.DeepEqual(want, pi) {
t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
}
case <-time.After(10 * time.Second):
t.Fatalf("Timed out waiting for packet")
}
})
}
}
func TestBufConfigMaxLength(t *testing.T) {
got := 0
for _, i := range BufConfig {
got += i
}
want := header.MaxIPPacketSize // maximum TCP packet size
if got < want {
t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
}
}
func TestBufConfigFirst(t *testing.T) {
// The stack assumes that the TCP/IP header is enterily contained in the first view.
// Therefore, the first view needs to be large enough to contain the maximum TCP/IP
// header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
want := 120
got := BufConfig[0]
if got < want {
t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
}
}
func build(bufConfig []int) *endpoint {
e := &endpoint{
views: make([]buffer.View, len(bufConfig)),
iovecs: make([]syscall.Iovec, len(bufConfig)),
}
e.allocateViews(bufConfig)
return e
}
var capLengthTestCases = []struct {
comment string
config []int
n int
wantUsed int
wantLengths []int
}{
{
comment: "Single slice",
config: []int{2},
n: 1,
wantUsed: 1,
wantLengths: []int{1},
},
{
comment: "Multiple slices",
config: []int{1, 2},
n: 2,
wantUsed: 2,
wantLengths: []int{1, 1},
},
{
comment: "Entire buffer",
config: []int{1, 2},
n: 3,
wantUsed: 2,
wantLengths: []int{1, 2},
},
{
comment: "Entire buffer but not on the last slice",
config: []int{1, 2, 3},
n: 3,
wantUsed: 2,
wantLengths: []int{1, 2, 3},
},
}
func TestCapLength(t *testing.T) {
for _, c := range capLengthTestCases {
e := build(c.config)
used := e.capViews(c.n, c.config)
if used != c.wantUsed {
t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
}
lengths := make([]int, len(e.views))
for i, v := range e.views {
lengths[i] = len(v)
}
if !reflect.DeepEqual(lengths, c.wantLengths) {
t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
}
}
}

View File

@@ -0,0 +1,87 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package loopback provides the implemention of loopback data-link layer
// endpoints. Such endpoints just turn outbound packets into inbound ones.
//
// Loopback endpoints can be used in the networking stack by calling New() to
// create a new endpoint, and then passing it as an argument to
// Stack.CreateNIC().
package loopback
import (
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/stack"
)
type endpoint struct {
dispatcher stack.NetworkDispatcher
}
// New creates a new loopback endpoint. This link-layer endpoint just turns
// outbound packets into inbound packets.
func New() tcpip.LinkEndpointID {
return stack.RegisterLinkEndpoint(&endpoint{})
}
// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
// layer dispatcher for later use when packets need to be dispatched.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
// linux loopback interface.
func (*endpoint) MTU() uint32 {
return 65536
}
// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
// itself as supporting checksum offload, but in reality it's just omitted.
func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback
}
// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
// loopback interface doesn't have a header, it just returns 0.
func (*endpoint) MaxHeaderLength() uint16 {
return 0
}
// LinkAddress returns the link address of this endpoint.
func (*endpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
// packets to the network-layer dispatcher.
func (e *endpoint) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
views := make([]buffer.View, 1, 1+len(payload.Views()))
views[0] = hdr.View()
views = append(views, payload.Views()...)
vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views)
// Because we're immediately turning around and writing the packet back to the
// rx path, we intentionally don't preserve the remote and local link
// addresses from the stack.Route we're passed.
e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv)
return nil
}

View File

@@ -0,0 +1,27 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
package rawfile
import (
"syscall"
"unsafe"
)
func blockingPoll(fds *pollEvent, nfds int, timeout int64) (int, syscall.Errno) {
n, _, e := syscall.Syscall(syscall.SYS_POLL, uintptr(unsafe.Pointer(fds)), uintptr(nfds), uintptr(timeout))
return int(n), e
}

View File

@@ -0,0 +1,70 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
package rawfile
import (
"fmt"
"syscall"
"tcpip/netstack/tcpip"
)
const maxErrno = 134
var translations [maxErrno]*tcpip.Error
// TranslateErrno translate an errno from the syscall package into a
// *tcpip.Error.
//
// Valid, but unreconigized errnos will be translated to
// tcpip.ErrInvalidEndpointState (EINVAL). Panics on invalid errnos.
func TranslateErrno(e syscall.Errno) *tcpip.Error {
if err := translations[e]; err != nil {
return err
}
return tcpip.ErrInvalidEndpointState
}
func addTranslation(host syscall.Errno, trans *tcpip.Error) {
if translations[host] != nil {
panic(fmt.Sprintf("duplicate translation for host errno %q (%d)", host.Error(), host))
}
translations[host] = trans
}
func init() {
addTranslation(syscall.EEXIST, tcpip.ErrDuplicateAddress)
addTranslation(syscall.ENETUNREACH, tcpip.ErrNoRoute)
addTranslation(syscall.EINVAL, tcpip.ErrInvalidEndpointState)
addTranslation(syscall.EALREADY, tcpip.ErrAlreadyConnecting)
addTranslation(syscall.EISCONN, tcpip.ErrAlreadyConnected)
addTranslation(syscall.EADDRINUSE, tcpip.ErrPortInUse)
addTranslation(syscall.EADDRNOTAVAIL, tcpip.ErrBadLocalAddress)
addTranslation(syscall.EPIPE, tcpip.ErrClosedForSend)
addTranslation(syscall.EWOULDBLOCK, tcpip.ErrWouldBlock)
addTranslation(syscall.ECONNREFUSED, tcpip.ErrConnectionRefused)
addTranslation(syscall.ETIMEDOUT, tcpip.ErrTimeout)
addTranslation(syscall.EINPROGRESS, tcpip.ErrConnectStarted)
addTranslation(syscall.EDESTADDRREQ, tcpip.ErrDestinationRequired)
addTranslation(syscall.ENOTSUP, tcpip.ErrNotSupported)
addTranslation(syscall.ENOTTY, tcpip.ErrQueueSizeNotSupported)
addTranslation(syscall.ENOTCONN, tcpip.ErrNotConnected)
addTranslation(syscall.ECONNRESET, tcpip.ErrConnectionReset)
addTranslation(syscall.ECONNABORTED, tcpip.ErrConnectionAborted)
addTranslation(syscall.EMSGSIZE, tcpip.ErrMessageTooLong)
addTranslation(syscall.ENOBUFS, tcpip.ErrNoBufferSpace)
}

View File

@@ -0,0 +1,145 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build linux
// Package rawfile contains utilities for using the netstack with raw host
// files on Linux hosts.
package rawfile
import (
"syscall"
"unsafe"
"tcpip/netstack/tcpip"
)
// GetMTU determines the MTU of a network interface device.
func GetMTU(name string) (uint32, error) {
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
if err != nil {
return 0, err
}
defer syscall.Close(fd)
var ifreq struct {
name [16]byte
mtu int32
_ [20]byte
}
copy(ifreq.name[:], name)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
if errno != 0 {
return 0, errno
}
return uint32(ifreq.mtu), nil
}
// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
// partial data is written.
func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
var ptr unsafe.Pointer
if len(buf) > 0 {
ptr = unsafe.Pointer(&buf[0])
}
_, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
if e != 0 {
return TranslateErrno(e)
}
return nil
}
// NonBlockingWrite2 writes up to two byte slices to a file descriptor in a
// single syscall. It fails if partial data is written.
func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error {
// If the is no second buffer, issue a regular write.
if len(b2) == 0 {
return NonBlockingWrite(fd, b1)
}
// We have two buffers. Build the iovec that represents them and issue
// a writev syscall.
iovec := [...]syscall.Iovec{
{
Base: &b1[0],
Len: uint64(len(b1)),
},
{
Base: &b2[0],
Len: uint64(len(b2)),
},
}
_, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
if e != 0 {
return TranslateErrno(e)
}
return nil
}
type pollEvent struct {
fd int32
events int16
revents int16
}
// BlockingRead reads from a file descriptor that is set up as non-blocking. If
// no data is available, it will block in a poll() syscall until the file
// descirptor becomes readable.
func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
if e == 0 {
return int(n), nil
}
event := pollEvent{
fd: int32(fd),
events: 1, // POLLIN
}
_, e = blockingPoll(&event, 1, -1)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
}
}
// BlockingReadv reads from a file descriptor that is set up as non-blocking and
// stores the data in a list of iovecs buffers. If no data is available, it will
// block in a poll() syscall until the file descirptor becomes readable.
func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
for {
n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
return int(n), nil
}
event := pollEvent{
fd: int32(fd),
events: 1, // POLLIN
}
_, e = blockingPoll(&event, 1, -1)
if e != 0 && e != syscall.EINTR {
return 0, TranslateErrno(e)
}
}
}

View File

@@ -0,0 +1,66 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sniffer
import "time"
type pcapHeader struct {
// MagicNumber is the file magic number.
MagicNumber uint32
// VersionMajor is the major version number.
VersionMajor uint16
// VersionMinor is the minor version number.
VersionMinor uint16
// Thiszone is the GMT to local correction.
Thiszone int32
// Sigfigs is the accuracy of timestamps.
Sigfigs uint32
// Snaplen is the max length of captured packets, in octets.
Snaplen uint32
// Network is the data link type.
Network uint32
}
const pcapPacketHeaderLen = 16
type pcapPacketHeader struct {
// Seconds is the timestamp seconds.
Seconds uint32
// Microseconds is the timestamp microseconds.
Microseconds uint32
// IncludedLength is the number of octets of packet saved in file.
IncludedLength uint32
// OriginalLength is the actual length of packet.
OriginalLength uint32
}
func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
now := time.Now()
return pcapPacketHeader{
Seconds: uint32(now.Unix()),
Microseconds: uint32(now.Nanosecond() / 1000),
IncludedLength: incLen,
OriginalLength: orgLen,
}
}

View File

@@ -0,0 +1,390 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package sniffer provides the implementation of data-link layer endpoints that
// wrap another endpoint and logs inbound and outbound packets.
//
// Sniffer endpoints can be used in the networking stack by calling New(eID) to
// create a new endpoint, where eID is the ID of the endpoint being wrapped,
// and then passing it as an argument to Stack.CreateNIC().
package sniffer
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"log"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
// LogPackets is a flag used to enable or disable packet logging via the log
// package. Valid values are 0 or 1.
//
// LogPackets must be accessed atomically.
var LogPackets uint32 = 1
// LogPacketsToFile is a flag used to enable or disable logging packets to a
// pcap file. Valid values are 0 or 1. A file must have been specified when the
// sniffer was created for this flag to have effect.
//
// LogPacketsToFile must be accessed atomically.
var LogPacketsToFile uint32 = 1
type endpoint struct {
dispatcher stack.NetworkDispatcher
lower stack.LinkEndpoint
file *os.File
maxPCAPLen uint32
}
// New creates a new sniffer link-layer endpoint. It wraps around another
// endpoint and logs packets and they traverse the endpoint.
func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
return stack.RegisterLinkEndpoint(&endpoint{
lower: stack.FindLinkEndpoint(lower),
})
}
func zoneOffset() (int32, error) {
loc, err := time.LoadLocation("Local")
if err != nil {
return 0, err
}
date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
_, offset := date.Zone()
return int32(offset), nil
}
func writePCAPHeader(w io.Writer, maxLen uint32) error {
offset, err := zoneOffset()
if err != nil {
return err
}
return binary.Write(w, binary.BigEndian, pcapHeader{
// From https://wiki.wireshark.org/Development/LibpcapFileFormat
MagicNumber: 0xa1b2c3d4,
VersionMajor: 2,
VersionMinor: 4,
Thiszone: offset,
Sigfigs: 0,
Snaplen: maxLen,
Network: 101, // LINKTYPE_RAW
})
}
// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
// another endpoint and logs packets and they traverse the endpoint.
//
// Packets can be logged to file in the pcap format. A sniffer created
// with this function will not emit packets using the standard log
// package.
//
// snapLen is the maximum amount of a packet to be saved. Packets with a length
// less than or equal too snapLen will be saved in their entirety. Longer
// packets will be truncated to snapLen.
func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
if err := writePCAPHeader(file, snapLen); err != nil {
return 0, err
}
return stack.RegisterLinkEndpoint(&endpoint{
lower: stack.FindLinkEndpoint(lower),
file: file,
maxPCAPLen: snapLen,
}), nil
}
// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
// called by the link-layer endpoint being wrapped when a packet arrives, and
// logs the packet before forwarding to the actual dispatcher.
func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("recv", protocol, vv.First())
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
vs := vv.Views()
length := vv.Size()
if length > int(e.maxPCAPLen) {
length = int(e.maxPCAPLen)
}
buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size()))); err != nil {
panic(err)
}
for _, v := range vs {
if length == 0 {
break
}
if len(v) > length {
v = v[:length]
}
if _, err := buf.Write([]byte(v)); err != nil {
panic(err)
}
length -= len(v)
}
if _, err := e.file.Write(buf.Bytes()); err != nil {
panic(err)
}
}
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, protocol, vv)
}
// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
// and registers with the lower endpoint as its dispatcher so that "e" is called
// for inbound packets.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
e.dispatcher = dispatcher
e.lower.Attach(e)
}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (e *endpoint) IsAttached() bool {
return e.dispatcher != nil
}
// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
// lower endpoint.
func (e *endpoint) MTU() uint32 {
return e.lower.MTU()
}
// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
// request to the lower endpoint.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.lower.Capabilities()
}
// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
// the request to the lower endpoint.
func (e *endpoint) MaxHeaderLength() uint16 {
return e.lower.MaxHeaderLength()
}
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
return e.lower.LinkAddress()
}
// WritePacket implements the stack.LinkEndpoint interface. It is called by
// higher-level protocols to write packets; it just logs the packet and forwards
// the request to the lower endpoint.
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
if atomic.LoadUint32(&LogPackets) == 1 && e.file == nil {
logPacket("send", protocol, hdr.View())
}
if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
hdrBuf := hdr.View()
length := len(hdrBuf) + payload.Size()
if length > int(e.maxPCAPLen) {
length = int(e.maxPCAPLen)
}
buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen+length))
if err := binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(len(hdrBuf)+payload.Size()))); err != nil {
panic(err)
}
if len(hdrBuf) > length {
hdrBuf = hdrBuf[:length]
}
if _, err := buf.Write(hdrBuf); err != nil {
panic(err)
}
length -= len(hdrBuf)
if length > 0 {
for _, v := range payload.Views() {
if len(v) > length {
v = v[:length]
}
n, err := buf.Write(v)
if err != nil {
panic(err)
}
length -= n
if length == 0 {
break
}
}
}
if _, err := e.file.Write(buf.Bytes()); err != nil {
panic(err)
}
}
return e.lower.WritePacket(r, hdr, payload, protocol)
}
func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b buffer.View) {
// Figure out the network layer info.
var transProto uint8
src := tcpip.Address("unknown")
dst := tcpip.Address("unknown")
id := 0
size := uint16(0)
switch protocol {
case header.IPv4ProtocolNumber:
ipv4 := header.IPv4(b)
src = ipv4.SourceAddress()
dst = ipv4.DestinationAddress()
transProto = ipv4.Protocol()
size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
b = b[ipv4.HeaderLength():]
id = int(ipv4.ID())
case header.IPv6ProtocolNumber:
ipv6 := header.IPv6(b)
src = ipv6.SourceAddress()
dst = ipv6.DestinationAddress()
transProto = ipv6.NextHeader()
size = ipv6.PayloadLength()
b = b[header.IPv6MinimumSize:]
case header.ARPProtocolNumber:
arp := header.ARP(b)
log.Printf(
"%s arp %v (%v) -> %v (%v) valid:%v",
prefix,
tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
arp.IsValid(),
)
return
default:
log.Printf("%s unknown network protocol", prefix)
return
}
// Figure out the transport layer info.
transName := "unknown"
srcPort := uint16(0)
dstPort := uint16(0)
details := ""
switch tcpip.TransportProtocolNumber(transProto) {
case header.ICMPv4ProtocolNumber:
transName = "icmp"
icmp := header.ICMPv4(b)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv4EchoReply:
icmpType = "echo reply"
case header.ICMPv4DstUnreachable:
icmpType = "destination unreachable"
case header.ICMPv4SrcQuench:
icmpType = "source quench"
case header.ICMPv4Redirect:
icmpType = "redirect"
case header.ICMPv4Echo:
icmpType = "echo"
case header.ICMPv4TimeExceeded:
icmpType = "time exceeded"
case header.ICMPv4ParamProblem:
icmpType = "param problem"
case header.ICMPv4Timestamp:
icmpType = "timestamp"
case header.ICMPv4TimestampReply:
icmpType = "timestamp reply"
case header.ICMPv4InfoRequest:
icmpType = "info request"
case header.ICMPv4InfoReply:
icmpType = "info reply"
}
log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.ICMPv6ProtocolNumber:
transName = "icmp"
icmp := header.ICMPv6(b)
icmpType := "unknown"
switch icmp.Type() {
case header.ICMPv6DstUnreachable:
icmpType = "destination unreachable"
case header.ICMPv6PacketTooBig:
icmpType = "packet too big"
case header.ICMPv6TimeExceeded:
icmpType = "time exceeded"
case header.ICMPv6ParamProblem:
icmpType = "param problem"
case header.ICMPv6EchoRequest:
icmpType = "echo request"
case header.ICMPv6EchoReply:
icmpType = "echo reply"
case header.ICMPv6RouterSolicit:
icmpType = "router solicit"
case header.ICMPv6RouterAdvert:
icmpType = "router advert"
case header.ICMPv6NeighborSolicit:
icmpType = "neighbor solicit"
case header.ICMPv6NeighborAdvert:
icmpType = "neighbor advert"
case header.ICMPv6RedirectMsg:
icmpType = "redirect message"
}
log.Printf("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
return
case header.UDPProtocolNumber:
transName = "udp"
udp := header.UDP(b)
srcPort = udp.SourcePort()
dstPort = udp.DestinationPort()
size -= header.UDPMinimumSize
details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
case header.TCPProtocolNumber:
transName = "tcp"
tcp := header.TCP(b)
offset := int(tcp.DataOffset())
if offset < header.TCPMinimumSize {
details += fmt.Sprintf("invalid packet: tcp data offset too small %d", offset)
break
}
if offset > len(tcp) {
details += fmt.Sprintf("invalid packet: tcp data offset %d larger than packet buffer length %d", offset, len(tcp))
break
}
srcPort = tcp.SourcePort()
dstPort = tcp.DestinationPort()
size -= uint16(offset)
// Initialize the TCP flags.
flags := tcp.Flags()
flagsStr := []byte("FSRPAU")
for i := range flagsStr {
if flags&(1<<uint(i)) == 0 {
flagsStr[i] = ' '
}
}
details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
if flags&header.TCPFlagSyn != 0 {
details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
} else {
details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
}
default:
log.Printf("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
return
}
log.Printf("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
}

View File

@@ -0,0 +1,141 @@
// +build linux
package tuntap
import (
"errors"
"fmt"
"os/exec"
"syscall"
"unsafe"
)
const (
TUN = 1
TAP = 2
)
var (
ErrDeviceMode = errors.New("unsupport device mode")
)
type rawSockaddr struct {
Family uint16
Data [14]byte
}
// 虚拟网卡设置的配置
type Config struct {
Name string // 网卡名
Mode int // 网卡模式TUN or TAP
}
// NewNetDev根据配置返回虚拟网卡的文件描述符
func NewNetDev(c *Config) (fd int, err error) {
switch c.Mode {
case TUN:
fd, err = newTun(c.Name)
case TAP:
fd, err = newTAP(c.Name)
default:
err = ErrDeviceMode
return
}
if err != nil {
return
}
return
}
// SetLinkUp 让系统启动该网卡
func SetLinkUp(name string) (err error) {
// ip link set <device-name> up
out, cmdErr := exec.Command("ip", "link", "set", name, "up").CombinedOutput()
if cmdErr != nil {
err = fmt.Errorf("%v:%v", cmdErr, string(out))
return
}
return
}
// SetRoute 通过ip命令添加路由
func SetRoute(name, cidr string) (err error) {
// ip route add 192.168.1.0/24 dev tap0
out, cmdErr := exec.Command("ip", "route", "add", cidr, "dev", name).CombinedOutput()
if cmdErr != nil {
err = fmt.Errorf("%v:%v", cmdErr, string(out))
return
}
return
}
// AddIP 通过ip命令添加IP地址
func AddIP(name, ip string) (err error) {
// ip addr add 192.168.1.1 dev tap0
out, cmdErr := exec.Command("ip", "addr", "add", ip, "dev", name).CombinedOutput()
if cmdErr != nil {
err = fmt.Errorf("%v:%v", cmdErr, string(out))
return
}
return
}
func GetHardwareAddr(name string) (string, error) {
fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
if err != nil {
return "", err
}
defer syscall.Close(fd)
var ifreq struct {
name [16]byte
addr rawSockaddr
_ [8]byte
}
copy(ifreq.name[:], name)
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFHWADDR, uintptr(unsafe.Pointer(&ifreq)))
if errno != 0 {
return "", errno
}
mac := ifreq.addr.Data[:6]
return string(mac[:]), nil
}
// newTun新建一个tun模式的虚拟网卡然后返回该网卡的文件描述符
// IFF_NO_PI表示不需要包信息
func newTun(name string) (int, error) {
return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI)
}
// newTAP新建一个tap模式的虚拟网卡然后返回该网卡的文件描述符
func newTAP(name string) (int, error) {
return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI)
}
// 先打开一个字符串设备通过系统调用将虚拟网卡和字符串设备fd绑定在一起
func open(name string, flags uint16) (int, error) {
// 打开tuntap的字符设备得到字符设备的文件描述符
fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0)
if err != nil {
return -1, err
}
var ifr struct {
name [16]byte
flags uint16
_ [22]byte
}
copy(ifr.name[:], name)
ifr.flags = flags
// 通过ioctl系统调用将fd和虚拟网卡驱动绑定在一起
_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
if errno != 0 {
syscall.Close(fd)
return -1, errno
}
return fd, nil
}

View File

@@ -0,0 +1,196 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package arp implements the ARP network protocol. It is used to resolve
// IPv4 addresses into link-local MAC addresses, and advertises IPv4
// addresses of its stack with the local network.
//
// To use it in the networking stack, pass arp.ProtocolName as one of the
// network protocols when calling stack.New. Then add an "arp" address to
// every NIC on the stack that should respond to ARP requests. That is:
//
// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
// // handle err
// }
package arp
import (
"log"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
const (
// ProtocolName is the string representation of the ARP protocol name.
ProtocolName = "arp"
// ProtocolNumber is the ARP protocol number.
ProtocolNumber = header.ARPProtocolNumber
// ProtocolAddress is the address expected by the ARP endpoint.
ProtocolAddress = tcpip.Address("arp")
)
// endpoint implements stack.NetworkEndpoint.
// 表示arp端
type endpoint struct {
nicid tcpip.NICID
addr tcpip.Address
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
}
// DefaultTTL is unused for ARP. It implements stack.NetworkEndpoint.
func (e *endpoint) DefaultTTL() uint8 {
return 0
}
func (e *endpoint) MTU() uint32 {
lmtu := e.linkEP.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
func (e *endpoint) NICID() tcpip.NICID {
return e.nicid
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &stack.NetworkEndpointID{ProtocolAddress}
}
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
func (e *endpoint) Close() {}
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return tcpip.ErrNotSupported
}
// arp数据包的处理包括arp请求和响应
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
h := header.ARP(v)
if !h.IsValid() {
return
}
// 判断操作码类型
switch h.Op() {
case header.ARPRequest:
// 如果是arp请求
log.Printf("recv arp request")
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
return // we have no useful answer, ignore the request
}
hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
pkt := header.ARP(hdr.Prepend(header.ARPSize))
pkt.SetIPv4OverEthernet()
pkt.SetOp(header.ARPReply)
copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
log.Printf("send arp reply")
e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
// 注意这里的 fallthrough 表示需要继续执行下面分支的代码
// 当收到 arp 请求需要添加到链路地址缓存中
fallthrough // also fill the cache from requests
case header.ARPReply:
// 这里记录ip和mac对应关系也就是arp表
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
}
}
// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
// 实现了 stack.NetworkProtocol 和 stack.LinkAddressResolver 两个接口
type protocol struct {
}
func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.ARP(v)
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
if addr != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
nicid: nicid,
addr: addr,
linkEP: sender,
linkAddrCache: linkAddrCache,
}, nil
}
// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
h := header.ARP(hdr.Prepend(header.ARPSize))
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
copy(h.HardwareAddressSender(), linkEP.LinkAddress())
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == "\xff\xff\xff\xff" {
return broadcastMAC, true
}
return "", false
}
// SetOption implements NetworkProtocol.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements NetworkProtocol.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
func init() {
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}

View File

@@ -0,0 +1,149 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package arp_test
import (
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/link/sniffer"
"tcpip/netstack/tcpip/network/arp"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/ping"
)
const (
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
)
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
}
func newTestContext(t *testing.T) *testContext {
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ping.ProtocolName4}, stack.Options{})
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
if testing.Verbose() {
id = sniffer.New(id)
}
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: "\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
}})
return &testContext{
t: t,
s: s,
linkEP: linkEP,
}
}
func (c *testContext) cleanup() {
close(c.linkEP.C)
}
func TestDirectRequest(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
const senderIPv4 = "\x0a\x00\x00\x02"
v := make(buffer.View, header.ARPSize)
h := header.ARP(v)
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
copy(h.HardwareAddressSender(), senderMAC)
copy(h.ProtocolAddressSender(), senderIPv4)
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView())
}
inject(stackAddr1)
{
pkt := <-c.linkEP.C
if pkt.Proto != arp.ProtocolNumber {
t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
}
rep := header.ARP(pkt.Header)
if !rep.IsValid() {
t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
}
if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
t.Errorf("stackAddr1: expected sender to be set")
}
if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
}
}
inject(stackAddr2)
{
pkt := <-c.linkEP.C
if pkt.Proto != arp.ProtocolNumber {
t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
}
rep := header.ARP(pkt.Header)
if !rep.IsValid() {
t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
}
if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
t.Errorf("stackAddr2: expected sender to be set")
}
if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
}
}
inject(stackAddrBad)
select {
case pkt := <-c.linkEP.C:
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
case <-time.After(100 * time.Millisecond):
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
}
}

View File

@@ -0,0 +1,77 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"container/heap"
"fmt"
"tcpip/netstack/tcpip/buffer"
)
type fragment struct {
offset uint16
vv buffer.VectorisedView
}
type fragHeap []fragment
func (h *fragHeap) Len() int {
return len(*h)
}
func (h *fragHeap) Less(i, j int) bool {
return (*h)[i].offset < (*h)[j].offset
}
func (h *fragHeap) Swap(i, j int) {
(*h)[i], (*h)[j] = (*h)[j], (*h)[i]
}
func (h *fragHeap) Push(x interface{}) {
*h = append(*h, x.(fragment))
}
func (h *fragHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[:n-1]
return x
}
// reassamble empties the heap and returns a VectorisedView
// containing a reassambled version of the fragments inside the heap.
func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
curr := heap.Pop(h).(fragment)
views := curr.vv.Views()
size := curr.vv.Size()
if curr.offset != 0 {
return buffer.VectorisedView{}, fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
}
for h.Len() > 0 {
curr := heap.Pop(h).(fragment)
if int(curr.offset) < size {
curr.vv.TrimFront(size - int(curr.offset))
} else if int(curr.offset) > size {
return buffer.VectorisedView{}, fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
}
size += curr.vv.Size()
views = append(views, curr.vv.Views()...)
}
return buffer.NewVectorisedView(size, views), nil
}

View File

@@ -0,0 +1,126 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"container/heap"
"reflect"
"testing"
"tcpip/netstack/tcpip/buffer"
)
var reassambleTestCases = []struct {
comment string
in []fragment
want buffer.VectorisedView
}{
{
comment: "Non-overlapping in-order",
in: []fragment{
{offset: 0, vv: vv(1, "0")},
{offset: 1, vv: vv(1, "1")},
},
want: vv(2, "0", "1"),
},
{
comment: "Non-overlapping out-of-order",
in: []fragment{
{offset: 1, vv: vv(1, "1")},
{offset: 0, vv: vv(1, "0")},
},
want: vv(2, "0", "1"),
},
{
comment: "Duplicated packets",
in: []fragment{
{offset: 0, vv: vv(1, "0")},
{offset: 0, vv: vv(1, "0")},
},
want: vv(1, "0"),
},
{
comment: "Overlapping in-order",
in: []fragment{
{offset: 0, vv: vv(2, "01")},
{offset: 1, vv: vv(2, "12")},
},
want: vv(3, "01", "2"),
},
{
comment: "Overlapping out-of-order",
in: []fragment{
{offset: 1, vv: vv(2, "12")},
{offset: 0, vv: vv(2, "01")},
},
want: vv(3, "01", "2"),
},
{
comment: "Overlapping subset in-order",
in: []fragment{
{offset: 0, vv: vv(3, "012")},
{offset: 1, vv: vv(1, "1")},
},
want: vv(3, "012"),
},
{
comment: "Overlapping subset out-of-order",
in: []fragment{
{offset: 1, vv: vv(1, "1")},
{offset: 0, vv: vv(3, "012")},
},
want: vv(3, "012"),
},
}
func TestReassamble(t *testing.T) {
for _, c := range reassambleTestCases {
t.Run(c.comment, func(t *testing.T) {
h := make(fragHeap, 0, 8)
heap.Init(&h)
for _, f := range c.in {
heap.Push(&h, f)
}
got, err := h.reassemble()
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, c.want) {
t.Errorf("got reassemble(%+v) = %v, want = %v", c.in, got, c.want)
}
})
}
}
func TestReassambleFailsForNonZeroOffset(t *testing.T) {
h := make(fragHeap, 0, 8)
heap.Init(&h)
heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
_, err := h.reassemble()
if err == nil {
t.Errorf("reassemble() did not fail when the first packet had offset != 0")
}
}
func TestReassambleFailsForHoles(t *testing.T) {
h := make(fragHeap, 0, 8)
heap.Init(&h)
heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
_, err := h.reassemble()
if err == nil {
t.Errorf("reassemble() did not fail when there was a hole in the packet")
}
}

View File

@@ -0,0 +1,134 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package fragmentation contains the implementation of IP fragmentation.
// It is based on RFC 791 and RFC 815.
package fragmentation
import (
"log"
"sync"
"time"
"tcpip/netstack/tcpip/buffer"
)
// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
const DefaultReassembleTimeout = 30 * time.Second
// HighFragThreshold is the threshold at which we start trimming old
// fragmented packets. Linux uses a default value of 4 MB. See
// net.ipv4.ipfrag_high_thresh for more information.
const HighFragThreshold = 4 << 20 // 4MB
// LowFragThreshold is the threshold we reach to when we start dropping
// older fragmented packets. It's important that we keep enough room for newer
// packets to be re-assembled. Hence, this needs to be lower than
// HighFragThreshold enough. Linux uses a default value of 3 MB. See
// net.ipv4.ipfrag_low_thresh for more information.
const LowFragThreshold = 3 << 20 // 3MB
// Fragmentation is the main structure that other modules
// of the stack should use to implement IP Fragmentation.
type Fragmentation struct {
mu sync.Mutex
highLimit int
lowLimit int
reassemblers map[uint32]*reassembler
rList reassemblerList
size int
timeout time.Duration
}
// NewFragmentation creates a new Fragmentation.
//
// highMemoryLimit specifies the limit on the memory consumed
// by the fragments stored by Fragmentation (overhead of internal data-structures
// is not accounted). Fragments are dropped when the limit is reached.
//
// lowMemoryLimit specifies the limit on which we will reach by dropping
// fragments after reaching highMemoryLimit.
//
// reassemblingTimeout specifes the maximum time allowed to reassemble a packet.
// Fragments are lazily evicted only when a new a packet with an
// already existing fragmentation-id arrives after the timeout.
func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
if lowMemoryLimit >= highMemoryLimit {
lowMemoryLimit = highMemoryLimit
}
if lowMemoryLimit < 0 {
lowMemoryLimit = 0
}
return &Fragmentation{
reassemblers: make(map[uint32]*reassembler),
highLimit: highMemoryLimit,
lowLimit: lowMemoryLimit,
timeout: reassemblingTimeout,
}
}
// Process processes an incoming fragment beloning to an ID
// and returns a complete packet when all the packets belonging to that ID have been received.
func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool) {
f.mu.Lock()
r, ok := f.reassemblers[id]
if ok && r.tooOld(f.timeout) {
// This is very likely to be an id-collision or someone performing a slow-rate attack.
f.release(r)
ok = false
}
if !ok {
r = newReassembler(id)
f.reassemblers[id] = r
f.rList.PushFront(r)
}
f.mu.Unlock()
res, done, consumed := r.process(first, last, more, vv)
f.mu.Lock()
f.size += consumed
if done {
f.release(r)
}
// Evict reassemblers if we are consuming more memory than highLimit until
// we reach lowLimit.
if f.size > f.highLimit {
tail := f.rList.Back()
for f.size > f.lowLimit && tail != nil {
f.release(tail)
tail = tail.Prev()
}
}
f.mu.Unlock()
return res, done
}
func (f *Fragmentation) release(r *reassembler) {
// Before releasing a fragment we need to check if r is already marked as done.
// Otherwise, we would delete it twice.
if r.checkDoneOrMark() {
return
}
delete(f.reassemblers, r.id)
f.rList.Remove(r)
f.size -= r.size
if f.size < 0 {
log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
f.size = 0
}
}

View File

@@ -0,0 +1,159 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"reflect"
"testing"
"time"
"tcpip/netstack/tcpip/buffer"
)
// vv is a helper to build VectorisedView from different strings.
func vv(size int, pieces ...string) buffer.VectorisedView {
views := make([]buffer.View, len(pieces))
for i, p := range pieces {
views[i] = []byte(p)
}
return buffer.NewVectorisedView(size, views)
}
type processInput struct {
id uint32
first uint16
last uint16
more bool
vv buffer.VectorisedView
}
type processOutput struct {
vv buffer.VectorisedView
done bool
}
var processTestCases = []struct {
comment string
in []processInput
out []processOutput
}{
{
comment: "One ID",
in: []processInput{
{id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
{vv: vv(4, "01", "23"), done: true},
},
},
{
comment: "Two IDs",
in: []processInput{
{id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
{id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
{id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
{id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
},
out: []processOutput{
{vv: buffer.VectorisedView{}, done: false},
{vv: buffer.VectorisedView{}, done: false},
{vv: vv(4, "ab", "cd"), done: true},
{vv: vv(4, "01", "23"), done: true},
},
},
}
func TestFragmentationProcess(t *testing.T) {
for _, c := range processTestCases {
t.Run(c.comment, func(t *testing.T) {
f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
for i, in := range c.in {
vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
if !reflect.DeepEqual(vv, c.out[i].vv) {
t.Errorf("got Process(%d) = %+v, want = %+v", i, vv, c.out[i].vv)
}
if done != c.out[i].done {
t.Errorf("got Process(%d) = %+v, want = %+v", i, done, c.out[i].done)
}
if c.out[i].done {
if _, ok := f.reassemblers[in.id]; ok {
t.Errorf("Process(%d) did not remove buffer from reassemblers", i)
}
for n := f.rList.Front(); n != nil; n = n.Next() {
if n.id == in.id {
t.Errorf("Process(%d) did not remove buffer from rList", i)
}
}
}
}
})
}
}
func TestReassemblingTimeout(t *testing.T) {
timeout := time.Millisecond
f := NewFragmentation(1024, 512, timeout)
// Send first fragment with id = 0, first = 0, last = 0, and more = true.
f.Process(0, 0, 0, true, vv(1, "0"))
// Sleep more than the timeout.
time.Sleep(2 * timeout)
// Send another fragment that completes a packet.
// However, no packet should be reassembled because the fragment arrived after the timeout.
_, done := f.Process(0, 1, 1, false, vv(1, "1"))
if done {
t.Errorf("Fragmentation does not respect the reassembling timeout.")
}
}
func TestMemoryLimits(t *testing.T) {
f := NewFragmentation(3, 1, DefaultReassembleTimeout)
// Send first fragment with id = 0.
f.Process(0, 0, 0, true, vv(1, "0"))
// Send first fragment with id = 1.
f.Process(1, 0, 0, true, vv(1, "1"))
// Send first fragment with id = 2.
f.Process(2, 0, 0, true, vv(1, "2"))
// Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
// evicted.
f.Process(3, 0, 0, true, vv(1, "3"))
if _, ok := f.reassemblers[0]; ok {
t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
}
if _, ok := f.reassemblers[1]; ok {
t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
}
if _, ok := f.reassemblers[3]; !ok {
t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
}
}
func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
f := NewFragmentation(1, 0, DefaultReassembleTimeout)
// Send first fragment with id = 0.
f.Process(0, 0, 0, true, vv(1, "0"))
// Send the same packet again.
f.Process(0, 0, 0, true, vv(1, "0"))
got := f.size
want := 1
if got != want {
t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
}
}

View File

@@ -0,0 +1,118 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"container/heap"
"fmt"
"math"
"sync"
"time"
"tcpip/netstack/tcpip/buffer"
)
type hole struct {
first uint16
last uint16
deleted bool
}
type reassembler struct {
reassemblerEntry
id uint32
size int
mu sync.Mutex
holes []hole
deleted int
heap fragHeap
done bool
creationTime time.Time
}
func newReassembler(id uint32) *reassembler {
r := &reassembler{
id: id,
holes: make([]hole, 0, 16),
deleted: 0,
heap: make(fragHeap, 0, 8),
creationTime: time.Now(),
}
r.holes = append(r.holes, hole{
first: 0,
last: math.MaxUint16,
deleted: false})
return r
}
// updateHoles updates the list of holes for an incoming fragment and
// returns true iff the fragment filled at least part of an existing hole.
func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
used := false
for i := range r.holes {
if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
continue
}
used = true
r.deleted++
r.holes[i].deleted = true
if first > r.holes[i].first {
r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
}
if last < r.holes[i].last && more {
r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
}
}
return used
}
func (r *reassembler) process(first, last uint16, more bool, vv buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
r.mu.Lock()
defer r.mu.Unlock()
consumed := 0
if r.done {
// A concurrent goroutine might have already reassembled
// the packet and emptied the heap while this goroutine
// was waiting on the mutex. We don't have to do anything in this case.
return buffer.VectorisedView{}, false, consumed
}
if r.updateHoles(first, last, more) {
// We store the incoming packet only if it filled some holes.
heap.Push(&r.heap, fragment{offset: first, vv: vv.Clone(nil)})
consumed = vv.Size()
r.size += consumed
}
// Check if all the holes have been deleted and we are ready to reassamble.
if r.deleted < len(r.holes) {
return buffer.VectorisedView{}, false, consumed
}
res, err := r.heap.reassemble()
if err != nil {
panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
}
return res, true, consumed
}
func (r *reassembler) tooOld(timeout time.Duration) bool {
return time.Now().Sub(r.creationTime) > timeout
}
func (r *reassembler) checkDoneOrMark() bool {
r.mu.Lock()
prev := r.done
r.done = true
r.mu.Unlock()
return prev
}

View File

@@ -0,0 +1,173 @@
package fragmentation
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type reassemblerElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (reassemblerElementMapper) linkerFor(elem *reassembler) *reassembler { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type reassemblerList struct {
head *reassembler
tail *reassembler
}
// Reset resets list l to the empty state.
func (l *reassemblerList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
func (l *reassemblerList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
func (l *reassemblerList) Front() *reassembler {
return l.head
}
// Back returns the last element of list l or nil.
func (l *reassemblerList) Back() *reassembler {
return l.tail
}
// PushFront inserts the element e at the front of list l.
func (l *reassemblerList) PushFront(e *reassembler) {
reassemblerElementMapper{}.linkerFor(e).SetNext(l.head)
reassemblerElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
reassemblerElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushBack inserts the element e at the back of list l.
func (l *reassemblerList) PushBack(e *reassembler) {
reassemblerElementMapper{}.linkerFor(e).SetNext(nil)
reassemblerElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
reassemblerElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
func (l *reassemblerList) PushBackList(m *reassemblerList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
reassemblerElementMapper{}.linkerFor(l.tail).SetNext(m.head)
reassemblerElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *reassemblerList) InsertAfter(b, e *reassembler) {
a := reassemblerElementMapper{}.linkerFor(b).Next()
reassemblerElementMapper{}.linkerFor(e).SetNext(a)
reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
reassemblerElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
func (l *reassemblerList) InsertBefore(a, e *reassembler) {
b := reassemblerElementMapper{}.linkerFor(a).Prev()
reassemblerElementMapper{}.linkerFor(e).SetNext(a)
reassemblerElementMapper{}.linkerFor(e).SetPrev(b)
reassemblerElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
reassemblerElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
func (l *reassemblerList) Remove(e *reassembler) {
prev := reassemblerElementMapper{}.linkerFor(e).Prev()
next := reassemblerElementMapper{}.linkerFor(e).Next()
if prev != nil {
reassemblerElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
reassemblerElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type reassemblerEntry struct {
next *reassembler
prev *reassembler
}
// Next returns the entry that follows e in the list.
func (e *reassemblerEntry) Next() *reassembler {
return e.next
}
// Prev returns the entry that precedes e in the list.
func (e *reassemblerEntry) Prev() *reassembler {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
func (e *reassemblerEntry) SetNext(elem *reassembler) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *reassemblerEntry) SetPrev(elem *reassembler) {
e.prev = elem
}

View File

@@ -0,0 +1,105 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package fragmentation
import (
"math"
"reflect"
"testing"
)
type updateHolesInput struct {
first uint16
last uint16
more bool
}
var holesTestCases = []struct {
comment string
in []updateHolesInput
want []hole
}{
{
comment: "No fragments. Expected holes: {[0 -> inf]}.",
in: []updateHolesInput{},
want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
},
{
comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
in: []updateHolesInput{{first: 0, last: 1, more: true}},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
{first: 2, last: math.MaxUint16, deleted: false},
},
},
{
comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
in: []updateHolesInput{{first: 1, last: 2, more: true}},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
{first: 0, last: 0, deleted: false},
{first: 3, last: math.MaxUint16, deleted: false},
},
},
{
comment: "One fragment at the end. Expected holes: {[0, 0]}.",
in: []updateHolesInput{{first: 1, last: 2, more: false}},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
{first: 0, last: 0, deleted: false},
},
},
{
comment: "One fragment completing a packet. Expected holes: {}.",
in: []updateHolesInput{{first: 0, last: 1, more: false}},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
},
},
{
comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
in: []updateHolesInput{
{first: 0, last: 1, more: true},
{first: 2, last: 3, more: false},
},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
{first: 2, last: math.MaxUint16, deleted: true},
},
},
{
comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
in: []updateHolesInput{
{first: 0, last: 2, more: true},
{first: 2, last: 3, more: false},
},
want: []hole{
{first: 0, last: math.MaxUint16, deleted: true},
{first: 3, last: math.MaxUint16, deleted: true},
},
},
}
func TestUpdateHoles(t *testing.T) {
for _, c := range holesTestCases {
r := newReassembler(0)
for _, i := range c.in {
r.updateHoles(i.first, i.last, i.more)
}
if !reflect.DeepEqual(r.holes, c.want) {
t.Errorf("Test \"%s\" produced unexepetced holes. Got %v. Want %v", c.comment, r.holes, c.want)
}
}
}

View File

@@ -0,0 +1,94 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package hash contains utility functions for hashing.
package hash
import (
"encoding/binary"
"tcpip/netstack/rand"
"tcpip/netstack/tcpip/header"
)
var hashIV = RandN32(1)[0]
// RandN32 generates a slice of n cryptographic random 32-bit numbers.
func RandN32(n int) []uint32 {
b := make([]byte, 4*n)
if _, err := rand.Read(b); err != nil {
panic("unable to get random numbers: " + err.Error())
}
r := make([]uint32, n)
for i := range r {
r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
}
return r
}
// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
// from linux.
func Hash3Words(a, b, c, initval uint32) uint32 {
const iv = 0xdeadbeef + (3 << 2)
initval += iv
a += initval
b += initval
c += initval
c ^= b
c -= rol32(b, 14)
a ^= c
a -= rol32(c, 11)
b ^= a
b -= rol32(a, 25)
c ^= b
c -= rol32(b, 16)
a ^= c
a -= rol32(c, 4)
b ^= a
b -= rol32(a, 14)
c ^= b
c -= rol32(b, 24)
return c
}
// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
// 根据id源ip目的ip和协议类型得到hash值
func IPv4FragmentHash(h header.IPv4) uint32 {
x := uint32(h.ID())<<16 | uint32(h.Protocol())
t := h.SourceAddress()
y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = h.DestinationAddress()
z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
return Hash3Words(x, y, z, hashIV)
}
// IPv6FragmentHash computes the hash of the ipv6 fragment.
// Unlike IPv4, the protocol is not used to compute the hash.
// RFC 2640 (sec 4.5) is not very sharp on this aspect.
// As a reference, also Linux ignores the protocol to compute
// the hash (inet6_hash_frag).
func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
t := h.SourceAddress()
y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = h.DestinationAddress()
z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
return Hash3Words(f.ID(), y, z, hashIV)
}
func rol32(v, shift uint32) uint32 {
return (v << shift) | (v >> ((-shift) & 31))
}

View File

@@ -0,0 +1,604 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ip_test
import (
"testing"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/loopback"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/network/ipv6"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/tcp"
"tcpip/netstack/tcpip/transport/udp"
)
const (
localIpv4Addr = "\x0a\x00\x00\x01"
remoteIpv4Addr = "\x0a\x00\x00\x02"
ipv4SubnetAddr = "\x0a\x00\x00\x00"
ipv4SubnetMask = "\xff\xff\xff\x00"
ipv4Gateway = "\x0a\x00\x00\x03"
localIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
remoteIpv6Addr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
ipv6SubnetAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
ipv6SubnetMask = "\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"
ipv6Gateway = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x03"
)
// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
// The former is used to pretend that it's a link endpoint so that we can
// inspect packets written by the network endpoints. The latter is used to
// pretend that it's the network stack so that it can inspect incoming packets
// that have been handled by the network endpoints.
//
// Packets are checked by comparing their fields/values against the expected
// values stored in the test object itself.
type testObject struct {
t *testing.T
protocol tcpip.TransportProtocolNumber
contents []byte
srcAddr tcpip.Address
dstAddr tcpip.Address
v4 bool
typ stack.ControlType
extra uint32
dataCalls int
controlCalls int
}
// checkValues verifies that the transport protocol, data contents, src & dst
// addresses of a packet match what's expected. If any field doesn't match, the
// test fails.
func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
v := vv.ToView()
if protocol != t.protocol {
t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
}
if srcAddr != t.srcAddr {
t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
}
if dstAddr != t.dstAddr {
t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
}
if len(v) != len(t.contents) {
t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
}
for i := range t.contents {
if t.contents[i] != v[i] {
t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
}
}
}
// DeliverTransportPacket is called by network endpoints after parsing incoming
// packets. This is used by the test object to verify that the results of the
// parsing are expected.
func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
t.dataCalls++
}
// DeliverTransportControlPacket is called by network endpoints after parsing
// incoming control (ICMP) packets. This is used by the test object to verify
// that the results of the parsing are expected.
func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
t.checkValues(trans, vv, remote, local)
if typ != t.typ {
t.t.Errorf("typ = %v, want %v", typ, t.typ)
}
if extra != t.extra {
t.t.Errorf("extra = %v, want %v", extra, t.extra)
}
t.controlCalls++
}
// Attach is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) Attach(stack.NetworkDispatcher) {}
// IsAttached implements stack.LinkEndpoint.IsAttached.
func (*testObject) IsAttached() bool {
return true
}
// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
// matches the linux loopback MTU.
func (*testObject) MTU() uint32 {
return 65536
}
// Capabilities implements stack.LinkEndpoint.Capabilities.
func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
return 0
}
// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
func (*testObject) MaxHeaderLength() uint16 {
return 0
}
// LinkAddress returns the link address of this endpoint.
func (*testObject) LinkAddress() tcpip.LinkAddress {
return ""
}
// WritePacket is called by network endpoints after producing a packet and
// writing it to the link endpoint. This is used by the test object to verify
// that the produced packet is as expected.
func (t *testObject) WritePacket(_ *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
var prot tcpip.TransportProtocolNumber
var srcAddr tcpip.Address
var dstAddr tcpip.Address
if t.v4 {
h := header.IPv4(hdr.View())
prot = tcpip.TransportProtocolNumber(h.Protocol())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
} else {
h := header.IPv6(hdr.View())
prot = tcpip.TransportProtocolNumber(h.NextHeader())
srcAddr = h.SourceAddress()
dstAddr = h.DestinationAddress()
}
t.checkValues(prot, payload, srcAddr, dstAddr)
return nil
}
func buildIPv4Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv4.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: ipv4SubnetAddr,
Mask: ipv4SubnetMask,
Gateway: ipv4Gateway,
NIC: 1,
}})
return s.FindRoute(1, local, remote, ipv4.ProtocolNumber)
}
func buildIPv6Route(local, remote tcpip.Address) (stack.Route, *tcpip.Error) {
s := stack.New([]string{ipv6.ProtocolName}, []string{udp.ProtocolName, tcp.ProtocolName}, stack.Options{})
s.CreateNIC(1, loopback.New())
s.AddAddress(1, ipv6.ProtocolNumber, local)
s.SetRouteTable([]tcpip.Route{{
Destination: ipv6SubnetAddr,
Mask: ipv6SubnetMask,
Gateway: ipv6Gateway,
NIC: 1,
}})
return s.FindRoute(1, local, remote, ipv6.ProtocolNumber)
}
func TestIPv4Send(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
for i := 0; i < len(payload); i++ {
payload[i] = uint8(i)
}
// Allocate the header buffer.
hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
// Issue the write.
o.protocol = 123
o.srcAddr = localIpv4Addr
o.dstAddr = remoteIpv4Addr
o.contents = payload
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv4Receive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
totalLen := header.IPv4MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv4(view)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
SrcAddr: remoteIpv4Addr,
DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
view[i] = uint8(i)
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
o.srcAddr = remoteIpv4Addr
o.dstAddr = localIpv4Addr
o.contents = view[header.IPv4MinimumSize:totalLen]
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
ep.HandlePacket(&r, view.ToVectorisedView())
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
}
func TestIPv4ReceiveControl(t *testing.T) {
const mtu = 0xbeef - header.IPv4MinimumSize
cases := []struct {
name string
expectedCount int
fragmentOffset uint16
code uint8
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
}{
{"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
{"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
{"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
{"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
{"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8},
{"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8},
}
r := stack.Route{
LocalAddress: localIpv4Addr,
RemoteAddress: "\x0a\x00\x00\xbb",
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv4.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4
view := buffer.NewView(dataOffset + 8)
// Create the outer IPv4 header.
ip := header.IPv4(view)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(len(view) - c.trunc),
TTL: 20,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: "\x0a\x00\x00\xbb",
DstAddr: localIpv4Addr,
})
// Create the ICMP header.
icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
icmp.SetType(header.ICMPv4DstUnreachable)
icmp.SetCode(c.code)
copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
// Create the inner IPv4 header.
ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:])
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: 100,
TTL: 20,
Protocol: 10,
FragmentOffset: c.fragmentOffset,
SrcAddr: localIpv4Addr,
DstAddr: remoteIpv4Addr,
})
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
}
// Give packet to IPv4 endpoint, dispatcher will validate that
// it's ok.
o.protocol = 10
o.srcAddr = remoteIpv4Addr
o.dstAddr = localIpv4Addr
o.contents = view[dataOffset:]
o.typ = c.expectedTyp
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
ep.HandlePacket(&r, vv)
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}
func TestIPv4FragmentationReceive(t *testing.T) {
o := testObject{t: t, v4: true}
proto := ipv4.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv4Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
totalLen := header.IPv4MinimumSize + 24
frag1 := buffer.NewView(totalLen)
ip1 := header.IPv4(frag1)
ip1.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
FragmentOffset: 0,
Flags: header.IPv4FlagMoreFragments,
SrcAddr: remoteIpv4Addr,
DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag1[i] = uint8(i)
}
frag2 := buffer.NewView(totalLen)
ip2 := header.IPv4(frag2)
ip2.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(totalLen),
TTL: 20,
Protocol: 10,
FragmentOffset: 24,
SrcAddr: remoteIpv4Addr,
DstAddr: localIpv4Addr,
})
// Make payload be non-zero.
for i := header.IPv4MinimumSize; i < totalLen; i++ {
frag2[i] = uint8(i)
}
// Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
o.srcAddr = remoteIpv4Addr
o.dstAddr = localIpv4Addr
o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
r, err := buildIPv4Route(localIpv4Addr, remoteIpv4Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
// Send first segment.
ep.HandlePacket(&r, frag1.ToVectorisedView())
if o.dataCalls != 0 {
t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
}
// Send second segment.
ep.HandlePacket(&r, frag2.ToVectorisedView())
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
}
func TestIPv6Send(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, nil, &o)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
// Allocate and initialize the payload view.
payload := buffer.NewView(100)
for i := 0; i < len(payload); i++ {
payload[i] = uint8(i)
}
// Allocate the header buffer.
hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
// Issue the write.
o.protocol = 123
o.srcAddr = localIpv6Addr
o.dstAddr = remoteIpv6Addr
o.contents = payload
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
if err := ep.WritePacket(&r, hdr, payload.ToVectorisedView(), 123, 123); err != nil {
t.Fatalf("WritePacket failed: %v", err)
}
}
func TestIPv6Receive(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
totalLen := header.IPv6MinimumSize + 30
view := buffer.NewView(totalLen)
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
NextHeader: 10,
HopLimit: 20,
SrcAddr: remoteIpv6Addr,
DstAddr: localIpv6Addr,
})
// Make payload be non-zero.
for i := header.IPv6MinimumSize; i < totalLen; i++ {
view[i] = uint8(i)
}
// Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
o.protocol = 10
o.srcAddr = remoteIpv6Addr
o.dstAddr = localIpv6Addr
o.contents = view[header.IPv6MinimumSize:totalLen]
r, err := buildIPv6Route(localIpv6Addr, remoteIpv6Addr)
if err != nil {
t.Fatalf("could not find route: %v", err)
}
ep.HandlePacket(&r, view.ToVectorisedView())
if o.dataCalls != 1 {
t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
}
}
func TestIPv6ReceiveControl(t *testing.T) {
newUint16 := func(v uint16) *uint16 { return &v }
const mtu = 0xffff
cases := []struct {
name string
expectedCount int
fragmentOffset *uint16
typ header.ICMPv6Type
code uint8
expectedTyp stack.ControlType
expectedExtra uint32
trunc int
}{
{"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
{"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
{"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
{"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
{"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
{"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
{"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
{"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
}
r := stack.Route{
LocalAddress: localIpv6Addr,
RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
o := testObject{t: t}
proto := ipv6.NewProtocol()
ep, err := proto.NewEndpoint(1, localIpv6Addr, nil, &o, nil)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
if c.fragmentOffset != nil {
dataOffset += header.IPv6FragmentHeaderSize
}
view := buffer.NewView(dataOffset + 8)
// Create the outer IPv6 header.
ip := header.IPv6(view)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: 20,
SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
DstAddr: localIpv6Addr,
})
// Create the ICMP header.
icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
icmp.SetType(c.typ)
icmp.SetCode(c.code)
copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
// Create the inner IPv6 header.
ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
ip.Encode(&header.IPv6Fields{
PayloadLength: 100,
NextHeader: 10,
HopLimit: 20,
SrcAddr: localIpv6Addr,
DstAddr: remoteIpv6Addr,
})
// Build the fragmentation header if needed.
if c.fragmentOffset != nil {
ip.SetNextHeader(header.IPv6FragmentHeader)
frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
frag.Encode(&header.IPv6FragmentFields{
NextHeader: 10,
FragmentOffset: *c.fragmentOffset,
M: true,
Identification: 0x12345678,
})
}
// Make payload be non-zero.
for i := dataOffset; i < len(view); i++ {
view[i] = uint8(i)
}
// Give packet to IPv6 endpoint, dispatcher will validate that
// it's ok.
o.protocol = 10
o.srcAddr = remoteIpv6Addr
o.dstAddr = localIpv6Addr
o.contents = view[dataOffset:]
o.typ = c.expectedTyp
o.extra = c.expectedExtra
vv := view[:len(view)-c.trunc].ToVectorisedView()
ep.HandlePacket(&r, vv)
if want := c.expectedCount; o.controlCalls != want {
t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
}
})
}
}

View File

@@ -0,0 +1,134 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv4
import (
"encoding/binary"
"log"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
// handleControl处理ICMP数据包包含导致ICMP发送的原始数据包的标头的情况。
// 此信息用于确定必须通知哪个传输端点有关ICMP数据包。
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
h := header.IPv4(vv.First())
// We don't use IsValid() here because ICMP only requires that the IP
// header plus 8 bytes of the transport header be included. So it's
// likely that it is truncated, which would cause IsValid to return
// false.
//
// Drop packet if it doesn't have the basic IPv4 header or if the
// original source address doesn't match the endpoint's address.
if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
return
}
hlen := int(h.HeaderLength())
if vv.Size() < hlen || h.FragmentOffset() != 0 {
// We won't be able to handle this if it doesn't contain the
// full IPv4 header, or if it's a fragment not at offset 0
// (because it won't have the transport header).
return
}
// Skip the ip header, then deliver control message.
vv.TrimFront(hlen)
p := h.TransportProtocol()
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
// 处理ICMP报文
func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv4MinimumSize {
return
}
h := header.ICMPv4(v)
// 更具icmp的类型来进行相应的处理
switch h.Type() {
case header.ICMPv4Echo: // icmp echo请求
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
log.Printf("icmp echo")
vv.TrimFront(header.ICMPv4MinimumSize)
req := echoRequest{r: r.Clone(), v: vv.ToView()}
select {
case e.echoRequests <- req: // 发送给echoReplier处理
default:
req.r.Release()
}
case header.ICMPv4EchoReply: // icmp echo响应
if len(v) < header.ICMPv4EchoMinimumSize {
return
}
e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, vv)
case header.ICMPv4DstUnreachable: // 目标不可达
if len(v) < header.ICMPv4DstUnreachableMinimumSize {
return
}
vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
switch h.Code() {
case header.ICMPv4PortUnreachable: // 端口不可达
e.handleControl(stack.ControlPortUnreachable, 0, vv)
case header.ICMPv4FragmentationNeeded: // 需要进行分片但设置不分片标志
mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
}
}
// TODO: Handle other ICMP types.
}
type echoRequest struct {
r stack.Route
v buffer.View
}
// 处理icmp echo请求的goroutine
func (e *endpoint) echoReplier() {
for req := range e.echoRequests {
sendPing4(&req.r, 0, req.v)
req.r.Release()
}
}
// 根据icmp echo请求封装icmp echo响应报文并传给ip层处理
func sendPing4(r *stack.Route, code byte, data buffer.View) *tcpip.Error {
hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
icmpv4.SetType(header.ICMPv4EchoReply)
icmpv4.SetCode(code)
copy(icmpv4[header.ICMPv4MinimumSize:], data)
data = data[header.ICMPv4EchoMinimumSize-header.ICMPv4MinimumSize:]
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
log.Printf("icmp reply")
// 传给ip层处理
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}

View File

@@ -0,0 +1,280 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ipv4 contains the implementation of the ipv4 network protocol. To use
// it in the networking stack, this package must be added to the project, and
// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
// network protocols when calling stack.New(). Then endpoints can be created
// by passing ipv4.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv4
import (
"log"
"sync/atomic"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/network/fragmentation"
"tcpip/netstack/tcpip/network/hash"
"tcpip/netstack/tcpip/stack"
)
const (
// ProtocolName is the string representation of the ipv4 protocol name.
ProtocolName = "ipv4"
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
// maxTotalSize is maximum size that can be encoded in the 16-bit
// TotalLength field of the ipv4 header.
maxTotalSize = 0xffff
// buckets is the number of identifier buckets.
buckets = 2048
)
// ip端表示
type endpoint struct {
// 网卡id
nicid tcpip.NICID
// 表示该endpoint的id也是ip地址
id stack.NetworkEndpointID
// 链路端的表示
linkEP stack.LinkEndpoint
// 报文分发器
dispatcher stack.TransportDispatcher
// ping请求报文接收队列
echoRequests chan echoRequest
// ip报文分片处理器
fragmentation *fragmentation.Fragmentation
}
// NewEndpoint creates a new ipv4 endpoint.
// 根据参数新建一个ipv4端
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addr},
linkEP: linkEP,
dispatcher: dispatcher,
echoRequests: make(chan echoRequest, 10),
fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold,
fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
}
// 启动一个goroutine给icmp请求服务
go e.echoReplier()
return e, nil
}
// DefaultTTL is the default time-to-live value for this endpoint.
// 默认的TTL值TTL每经过路由转发一次就会减1
func (e *endpoint) DefaultTTL() uint8 {
return 255
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
// 获取去除ipv4头部后的最大报文长度
func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
// Capabilities implements stack.NetworkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
return e.nicid
}
// ID returns the ipv4 endpoint ID.
// 获取该网络层端的id也就是ip地址
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &e.id
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
// 链路层和网络层的头部长度
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
}
// WritePacket writes a packet to the given destination address and protocol.
// 将传输层的数据封装加上IP头并调用网卡的写入接口写入IP报文
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView,
protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
// 预留ip报文的空间
ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
length := uint16(hdr.UsedLength() + payload.Size())
id := uint32(0)
// 如果报文长度大于68
if length > header.IPv4MaximumHeaderSize+8 {
// Packets of 68 bytes or less are required by RFC 791 to not be
// fragmented, so we only assign ids to larger packets.
id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
}
// ip首部编码
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: length,
ID: uint16(id),
TTL: ttl,
Protocol: uint8(protocol),
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
// 计算校验和和设置校验和
ip.SetChecksum(^ip.CalculateChecksum())
r.Stats().IP.PacketsSent.Increment()
// 写入网卡接口
log.Printf("send ipv4 packet %d bytes, proto: 0x%x", hdr.UsedLength()+payload.Size(), protocol)
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
// 收到ip包的处理
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// 得到ip报文
h := header.IPv4(vv.First())
// 检查报文是否有效
if !h.IsValid(vv.Size()) {
return
}
hlen := int(h.HeaderLength())
tlen := int(h.TotalLength())
vv.TrimFront(hlen)
vv.CapLength(tlen - hlen)
// 检查ip报文是否有更多的分片
more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
// 是否需要ip重组
if more || h.FragmentOffset() != 0 {
// The packet is a fragment, let's try to reassemble it.
last := h.FragmentOffset() + uint16(vv.Size()) - 1
var ready bool
// ip分片重组
vv, ready = e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
if !ready {
return
}
}
// 得到传输层的协议
p := h.TransportProtocol()
// 如果时ICMP协议则进入ICMP处理函数
if p == header.ICMPv4ProtocolNumber {
e.handleICMP(r, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
// 根据协议分发到不通处理函数比如协议时TCP会进入tcp.HandlePacket
log.Printf("recv ipv4 packet %d bytes, proto: 0x%x", tlen, p)
e.dispatcher.DeliverTransportPacket(r, p, vv)
}
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {
close(e.echoRequests)
}
// 实现NetworkProtocol接口
type protocol struct{}
// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
// only for tests that short-circuit the stack. Regular use of the protocol is
// done via the stack, which gets a protocol descriptor from the init() function
// below.
func NewProtocol() stack.NetworkProtocol {
return &protocol{}
}
// Number returns the ipv4 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// MinimumPacketSize returns the minimum valid ipv4 packet size.
func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv4(v)
return h.SourceAddress(), h.DestinationAddress()
}
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
if mtu > maxTotalSize {
mtu = maxTotalSize
}
return mtu - header.IPv4MinimumSize
}
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
t := r.LocalAddress
a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
t = r.RemoteAddress
b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}
var (
ids []uint32
hashIV uint32
)
// 初始化的时候初始化buckets用于hash计算
// 并在网络层协议中注册ipv4协议
func init() {
ids = make([]uint32, buckets)
// Randomly initialize hashIV and the ids.
r := hash.RandN32(1 + buckets)
for i := range ids {
ids[i] = r[i]
}
hashIV = r[buckets]
// 注册网络层协议
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}

View File

@@ -0,0 +1,92 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv4_test
import (
"testing"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/link/sniffer"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/udp"
"tcpip/netstack/waiter"
)
func TestExcludeBroadcast(t *testing.T) {
s := stack.New([]string{ipv4.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
const defaultMTU = 65536
id, _ := channel.New(256, defaultMTU, "")
if testing.Verbose() {
id = sniffer.New(id)
}
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Broadcast); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, header.IPv4Any); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: "\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
}})
randomAddr := tcpip.FullAddress{NIC: 1, Addr: "\x0a\x00\x00\x01", Port: 53}
var wq waiter.Queue
t.Run("WithoutPrimaryAddress", func(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatal(err)
}
defer ep.Close()
// Cannot connect using a broadcast address as the source.
if err := ep.Connect(randomAddr); err != tcpip.ErrNoRoute {
t.Errorf("got ep.Connect(...) = %v, want = %v", err, tcpip.ErrNoRoute)
}
// However, we can bind to a broadcast address to listen.
if err := ep.Bind(tcpip.FullAddress{Addr: header.IPv4Broadcast, Port: 53, NIC: 1}, nil); err != nil {
t.Errorf("Bind failed: %v", err)
}
})
t.Run("WithPrimaryAddress", func(t *testing.T) {
ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
if err != nil {
t.Fatal(err)
}
defer ep.Close()
// Add a valid primary endpoint address, now we can connect.
if err := s.AddAddress(1, ipv4.ProtocolNumber, "\x0a\x00\x00\x02"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := ep.Connect(randomAddr); err != nil {
t.Errorf("Connect failed: %v", err)
}
})
}

View File

@@ -0,0 +1,231 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv6
import (
"encoding/binary"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
h := header.IPv6(vv.First())
// We don't use IsValid() here because ICMP only requires that up to
// 1280 bytes of the original packet be included. So it's likely that it
// is truncated, which would cause IsValid to return false.
//
// Drop packet if it doesn't have the basic IPv6 header or if the
// original source address doesn't match the endpoint's address.
if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
return
}
// Skip the IP header, then handle the fragmentation header if there
// is one.
vv.TrimFront(header.IPv6MinimumSize)
p := h.TransportProtocol()
if p == header.IPv6FragmentHeader {
f := header.IPv6Fragment(vv.First())
if !f.IsValid() || f.FragmentOffset() != 0 {
// We can't handle fragments that aren't at offset 0
// because they don't have the transport headers.
return
}
// Skip fragmentation header and find out the actual protocol
// number.
vv.TrimFront(header.IPv6FragmentHeaderSize)
p = f.TransportProtocol()
}
// Deliver the control packet to the transport endpoint.
e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
func (e *endpoint) handleICMP(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
if len(v) < header.ICMPv6MinimumSize {
return
}
h := header.ICMPv6(v)
switch h.Type() {
case header.ICMPv6PacketTooBig:
if len(v) < header.ICMPv6PacketTooBigMinimumSize {
return
}
vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
case header.ICMPv6DstUnreachable:
if len(v) < header.ICMPv6DstUnreachableMinimumSize {
return
}
vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
switch h.Code() {
case header.ICMPv6PortUnreachable:
e.handleControl(stack.ControlPortUnreachable, 0, vv)
}
case header.ICMPv6NeighborSolicit:
if len(v) < header.ICMPv6NeighborSolicitMinimumSize {
return
}
targetAddr := tcpip.Address(v[8 : 8+16])
if e.linkAddrCache.CheckLocalAddress(e.nicid, ProtocolNumber, targetAddr) == 0 {
// We don't have a useful answer; the best we can do is ignore the request.
return
}
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborAdvert)
pkt[icmpV6FlagOffset] = ndpSolicitedFlag | ndpOverrideFlag
copy(pkt[icmpV6OptOffset-len(targetAddr):], targetAddr)
pkt[icmpV6OptOffset] = ndpOptDstLinkAddr
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], r.LocalLinkAddress[:])
// ICMPv6 Neighbor Solicit messages are always sent to
// specially crafted IPv6 multicast addresses. As a result, the
// route we end up with here has as its LocalAddress such a
// multicast address. It would be nonsense to claim that our
// source address is a multicast address, so we manually set
// the source address to the target address requested in the
// solicit message. Since that requires mutating the route, we
// must first clone it.
r := r.Clone()
defer r.Release()
r.LocalAddress = targetAddr
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
r.WritePacket(hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, r.DefaultTTL())
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
case header.ICMPv6NeighborAdvert:
if len(v) < header.ICMPv6NeighborAdvertSize {
return
}
targetAddr := tcpip.Address(v[8 : 8+16])
e.linkAddrCache.AddLinkAddress(e.nicid, targetAddr, r.RemoteLinkAddress)
if targetAddr != r.RemoteAddress {
e.linkAddrCache.AddLinkAddress(e.nicid, r.RemoteAddress, r.RemoteLinkAddress)
}
case header.ICMPv6EchoRequest:
if len(v) < header.ICMPv6EchoMinimumSize {
return
}
vv.TrimFront(header.ICMPv6EchoMinimumSize)
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(pkt, h)
pkt.SetType(header.ICMPv6EchoReply)
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, vv))
r.WritePacket(hdr, vv, header.ICMPv6ProtocolNumber, r.DefaultTTL())
case header.ICMPv6EchoReply:
if len(v) < header.ICMPv6EchoMinimumSize {
return
}
e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, vv)
}
}
const (
ndpSolicitedFlag = 1 << 6
ndpOverrideFlag = 1 << 5
ndpOptSrcLinkAddr = 1
ndpOptDstLinkAddr = 2
icmpV6FlagOffset = 4
icmpV6OptOffset = 24
icmpV6LengthOffset = 25
)
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
var _ stack.LinkAddressResolver = (*protocol)(nil)
// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv6ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
snaddr := header.SolicitedNodeAddr(addr)
r := &stack.Route{
LocalAddress: localAddr,
RemoteAddress: snaddr,
RemoteLinkAddress: broadcastMAC,
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6NeighborAdvertSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6NeighborAdvertSize))
pkt.SetType(header.ICMPv6NeighborSolicit)
copy(pkt[icmpV6OptOffset-len(addr):], addr)
pkt[icmpV6OptOffset] = ndpOptSrcLinkAddr
pkt[icmpV6LengthOffset] = 1
copy(pkt[icmpV6LengthOffset+1:], linkEP.LinkAddress())
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
length := uint16(hdr.UsedLength())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(header.ICMPv6ProtocolNumber),
HopLimit: defaultIPv6HopLimit,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
return "", false
}
func icmpChecksum(h header.ICMPv6, src, dst tcpip.Address, vv buffer.VectorisedView) uint16 {
// Calculate the IPv6 pseudo-header upper-layer checksum.
xsum := header.Checksum([]byte(src), 0)
xsum = header.Checksum([]byte(dst), xsum)
var upperLayerLength [4]byte
binary.BigEndian.PutUint32(upperLayerLength[:], uint32(len(h)+vv.Size()))
xsum = header.Checksum(upperLayerLength[:], xsum)
xsum = header.Checksum([]byte{0, 0, 0, uint8(header.ICMPv6ProtocolNumber)}, xsum)
for _, v := range vv.Views() {
xsum = header.Checksum(v, xsum)
}
// h[2:4] is the checksum itself, set it aside to avoid checksumming the checksum.
h2, h3 := h[2], h[3]
h[2], h[3] = 0, 0
xsum = ^header.Checksum(h, xsum)
h[2], h[3] = h2, h3
return xsum
}

View File

@@ -0,0 +1,226 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ipv6
import (
"context"
"runtime"
"strings"
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/link/sniffer"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/ping"
"tcpip/netstack/waiter"
)
const (
linkAddr0 = tcpip.LinkAddress("\x01\x02\x03\x04\x05\x06")
linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0f")
)
var (
lladdr0 = header.LinkLocalAddr(linkAddr0)
lladdr1 = header.LinkLocalAddr(linkAddr1)
)
type icmpInfo struct {
typ header.ICMPv6Type
src tcpip.Address
}
type testContext struct {
t *testing.T
s0 *stack.Stack
s1 *stack.Stack
linkEP0 *channel.Endpoint
linkEP1 *channel.Endpoint
icmpCh chan icmpInfo
}
type endpointWithResolutionCapability struct {
stack.LinkEndpoint
}
func (e endpointWithResolutionCapability) Capabilities() stack.LinkEndpointCapabilities {
return e.LinkEndpoint.Capabilities() | stack.CapabilityResolutionRequired
}
func newTestContext(t *testing.T) *testContext {
c := &testContext{
t: t,
s0: stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
s1: stack.New([]string{ProtocolName}, []string{ping.ProtocolName6}, stack.Options{}),
icmpCh: make(chan icmpInfo, 10),
}
const defaultMTU = 65536
_, linkEP0 := channel.New(256, defaultMTU, linkAddr0)
c.linkEP0 = linkEP0
wrappedEP0 := endpointWithResolutionCapability{LinkEndpoint: linkEP0}
id0 := stack.RegisterLinkEndpoint(wrappedEP0)
if testing.Verbose() {
id0 = sniffer.New(id0)
}
if err := c.s0.CreateNIC(1, id0); err != nil {
t.Fatalf("CreateNIC s0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, lladdr0); err != nil {
t.Fatalf("AddAddress lladdr0: %v", err)
}
if err := c.s0.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr0)); err != nil {
t.Fatalf("AddAddress sn lladdr0: %v", err)
}
_, linkEP1 := channel.New(256, defaultMTU, linkAddr1)
c.linkEP1 = linkEP1
wrappedEP1 := endpointWithResolutionCapability{LinkEndpoint: linkEP1}
id1 := stack.RegisterLinkEndpoint(wrappedEP1)
if err := c.s1.CreateNIC(1, id1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, lladdr1); err != nil {
t.Fatalf("AddAddress lladdr1: %v", err)
}
if err := c.s1.AddAddress(1, ProtocolNumber, header.SolicitedNodeAddr(lladdr1)); err != nil {
t.Fatalf("AddAddress sn lladdr1: %v", err)
}
c.s0.SetRouteTable(
[]tcpip.Route{{
Destination: lladdr1,
Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
NIC: 1,
}},
)
c.s1.SetRouteTable(
[]tcpip.Route{{
Destination: lladdr0,
Mask: tcpip.AddressMask(strings.Repeat("\xff", 16)),
NIC: 1,
}},
)
go c.routePackets(linkEP0.C, linkEP1)
go c.routePackets(linkEP1.C, linkEP0)
return c
}
func (c *testContext) countPacket(pkt channel.PacketInfo) {
if pkt.Proto != ProtocolNumber {
return
}
ipv6 := header.IPv6(pkt.Header)
transProto := tcpip.TransportProtocolNumber(ipv6.NextHeader())
if transProto != header.ICMPv6ProtocolNumber {
return
}
b := pkt.Header[header.IPv6MinimumSize:]
icmp := header.ICMPv6(b)
c.icmpCh <- icmpInfo{
typ: icmp.Type(),
src: ipv6.SourceAddress(),
}
}
func (c *testContext) routePackets(ch <-chan channel.PacketInfo, ep *channel.Endpoint) {
for pkt := range ch {
c.countPacket(pkt)
views := []buffer.View{pkt.Header, pkt.Payload}
size := len(pkt.Header) + len(pkt.Payload)
vv := buffer.NewVectorisedView(size, views)
ep.InjectLinkAddr(pkt.Proto, ep.LinkAddress(), vv)
}
}
func (c *testContext) cleanup() {
close(c.linkEP0.C)
close(c.linkEP1.C)
}
func TestLinkResolution(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
r, err := c.s0.FindRoute(1, lladdr0, lladdr1, ProtocolNumber)
if err != nil {
t.Fatal(err)
}
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.IPv6MinimumSize + header.ICMPv6EchoMinimumSize)
pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
pkt.SetType(header.ICMPv6EchoRequest)
pkt.SetChecksum(icmpChecksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{}))
payload := tcpip.SlicePayload(hdr.View())
// We can't send our payload directly over the route because that
// doesn't provoke NDP discovery.
var wq waiter.Queue
ep, err := c.s0.NewEndpoint(header.ICMPv6ProtocolNumber, ProtocolNumber, &wq)
if err != nil {
t.Fatal(err)
}
// This actually takes about 10 milliseconds, so no need to wait for
// a multi-minute go test timeout if something is broken.
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
for {
if ctx.Err() != nil {
break
}
if _, _, err := ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{NIC: 1, Addr: lladdr1}}); err == tcpip.ErrNoLinkAddress {
// There's something asynchronous going on; yield to let it do its thing.
runtime.Gosched()
} else if err == nil {
break
} else {
t.Fatal(err)
}
}
stats := make(map[header.ICMPv6Type]int)
for {
select {
case <-ctx.Done():
t.Errorf("timeout waiting for ICMP, got: %#+v", stats)
return
case icmpInfo := <-c.icmpCh:
switch icmpInfo.typ {
case header.ICMPv6NeighborAdvert:
if got, want := icmpInfo.src, lladdr1; got != want {
t.Errorf("got ICMPv6NeighborAdvert.sourceAddress = %v, want = %v", got, want)
}
}
stats[icmpInfo.typ]++
if stats[header.ICMPv6NeighborSolicit] > 0 &&
stats[header.ICMPv6NeighborAdvert] > 0 &&
stats[header.ICMPv6EchoRequest] > 0 &&
stats[header.ICMPv6EchoReply] > 0 {
return
}
}
}
}

View File

@@ -0,0 +1,187 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ipv6 contains the implementation of the ipv6 network protocol. To use
// it in the networking stack, this package must be added to the project, and
// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
// network protocols when calling stack.New(). Then endpoints can be created
// by passing ipv6.ProtocolNumber as the network protocol number when calling
// Stack.NewEndpoint().
package ipv6
import (
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
)
const (
// ProtocolName is the string representation of the ipv6 protocol name.
ProtocolName = "ipv6"
// ProtocolNumber is the ipv6 protocol number.
ProtocolNumber = header.IPv6ProtocolNumber
// maxTotalSize is maximum size that can be encoded in the 16-bit
// PayloadLength field of the ipv6 header.
maxPayloadSize = 0xffff
// defaultIPv6HopLimit is the default hop limit for IPv6 Packets
// egressed by Netstack.
defaultIPv6HopLimit = 255
)
type endpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
linkEP stack.LinkEndpoint
linkAddrCache stack.LinkAddressCache
dispatcher stack.TransportDispatcher
}
// DefaultTTL is the default hop limit for this endpoint.
func (e *endpoint) DefaultTTL() uint8 {
return 255
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
return e.nicid
}
// ID returns the ipv6 endpoint ID.
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &e.id
}
// Capabilities implements stack.NetworkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
}
// WritePacket writes a packet to the given destination address and protocol.
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
length := uint16(hdr.UsedLength() + payload.Size())
ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
ip.Encode(&header.IPv6Fields{
PayloadLength: length,
NextHeader: uint8(protocol),
HopLimit: ttl,
SrcAddr: r.LocalAddress,
DstAddr: r.RemoteAddress,
})
r.Stats().IP.PacketsSent.Increment()
return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
h := header.IPv6(vv.First())
if !h.IsValid(vv.Size()) {
return
}
vv.TrimFront(header.IPv6MinimumSize)
vv.CapLength(int(h.PayloadLength()))
p := h.TransportProtocol()
if p == header.ICMPv6ProtocolNumber {
e.handleICMP(r, vv)
return
}
r.Stats().IP.PacketsDelivered.Increment()
e.dispatcher.DeliverTransportPacket(r, p, vv)
}
// Close cleans up resources associated with the endpoint.
func (*endpoint) Close() {}
type protocol struct{}
// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
// only for tests that short-circuit the stack. Regular use of the protocol is
// done via the stack, which gets a protocol descriptor from the init() function
// below.
func NewProtocol() stack.NetworkProtocol {
return &protocol{}
}
// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// MinimumPacketSize returns the minimum valid ipv6 packet size.
func (p *protocol) MinimumPacketSize() int {
return header.IPv6MinimumSize
}
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.IPv6(v)
return h.SourceAddress(), h.DestinationAddress()
}
// NewEndpoint creates a new ipv6 endpoint.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &endpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addr},
linkEP: linkEP,
linkAddrCache: linkAddrCache,
dispatcher: dispatcher,
}, nil
}
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
mtu -= header.IPv6MinimumSize
if mtu <= maxPayloadSize {
return mtu
}
return maxPayloadSize
}
func init() {
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}

View File

@@ -0,0 +1,185 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ports provides PortManager that manages allocating, reserving and releasing ports.
package ports
import (
"log"
"math"
"math/rand"
"sync"
"tcpip/netstack/tcpip"
)
const (
// 临时端口的最小值
FirstEphemeral = 16000
anyIPAddress tcpip.Address = ""
)
// 端口的唯一标识: 网络层协议-传输层协议-端口号
type portDescriptor struct {
network tcpip.NetworkProtocolNumber
transport tcpip.TransportProtocolNumber
port uint16
}
// 管理端口的对象,由它来保留和释放端口
type PortManager struct {
mu sync.RWMutex
// 用一个map接口来保存端口被占用
allocatedPorts map[portDescriptor]bindAddresses
}
// bindAddresses is a set of IP addresses.
type bindAddresses map[tcpip.Address]struct{}
// isAvailable checks whether an IP address is available to bind to.
func (b bindAddresses) isAvailable(addr tcpip.Address) bool {
if addr == anyIPAddress {
return len(b) == 0
}
// If all addresses for this portDescriptor are already bound, no
// address is available.
if _, ok := b[anyIPAddress]; ok {
return false
}
if _, ok := b[addr]; ok {
return false
}
return true
}
// NewPortManager 新建一个端口管理器
func NewPortManager() *PortManager {
return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
}
// PickEphemeralPort 从端口管理器中随机分配一个端口并调用testPort来检测是否可用。
func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
count := uint16(math.MaxUint16 - FirstEphemeral + 1)
offset := uint16(rand.Int31n(int32(count)))
for i := uint16(0); i < count; i++ {
port = FirstEphemeral + (offset+i)%count
ok, err := testPort(port)
if err != nil {
return 0, err
}
if ok {
return port, nil
}
}
return 0, tcpip.ErrNoPortAvailable
}
// IsPortAvailable tests if the given port is available on all given protocols.
func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
addr tcpip.Address, port uint16) bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.isPortAvailableLocked(networks, transport, addr, port)
}
// isPortAvailableLocked 根据参数判断该端口号是否已经被占用了
func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
addr tcpip.Address, port uint16) bool {
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if addrs, ok := s.allocatedPorts[desc]; ok {
if !addrs.isAvailable(addr) {
return false
}
}
}
return true
}
// ReservePort 将端口和IP地址绑定在一起这样别的程序就无法使用已经被绑定的端口。
// 如果传入的端口不为0那么会尝试绑定该端口若该端口没有被占用那么绑定成功。
// 如果传人的端口等于0那么就是告诉协议栈自己分配端口端口管理器就会随机返回一个端口。
func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
s.mu.Lock()
defer s.mu.Unlock()
// If a port is specified, just try to reserve it for all network
// protocols.
if port != 0 {
if !s.reserveSpecificPort(networks, transport, addr, port) {
return 0, tcpip.ErrPortInUse
}
reservedPort = port
log.Printf("new transport: %d, port: %d", transport, reservedPort)
return
}
// 随机分配一个未占用的端口
reservedPort, err = s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
return s.reserveSpecificPort(networks, transport, addr, p), nil
})
log.Printf("new transport: %d, port: %d", transport, reservedPort)
return
}
// reserveSpecificPort 尝试根据协议号和IP地址绑定一个端口
func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
addr tcpip.Address, port uint16) bool {
if !s.isPortAvailableLocked(networks, transport, addr, port) {
return false
}
// Reserve port on all network protocols.
// 根据给定的网络层协议号IPV4或IPV6绑定端口
for _, network := range networks {
desc := portDescriptor{network, transport, port}
m, ok := s.allocatedPorts[desc]
if !ok {
m = make(bindAddresses)
s.allocatedPorts[desc] = m
}
// 注册该地址被绑定了
m[addr] = struct{}{}
}
return true
}
// ReleasePort releases the reservation on a port/IP combination so that it can
// be reserved by other endpoints.
// 释放绑定的端口,以便别的程序复用。
func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber,
addr tcpip.Address, port uint16) {
s.mu.Lock()
defer s.mu.Unlock()
// 删除绑定关系
for _, network := range networks {
desc := portDescriptor{network, transport, port}
if m, ok := s.allocatedPorts[desc]; ok {
log.Printf("delete transport: %d, port: %d", transport, port)
delete(m, addr)
if len(m) == 0 {
delete(s.allocatedPorts, desc)
}
}
}
}

View File

@@ -0,0 +1,145 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ports
import (
"testing"
"tcpip/netstack/tcpip"
)
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
)
func TestPortReservation(t *testing.T) {
pm := NewPortManager()
net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
for _, test := range []struct {
port uint16
ip tcpip.Address
want *tcpip.Error
}{
{
port: 80,
ip: fakeIPAddress,
want: nil,
},
{
port: 80,
ip: fakeIPAddress1,
want: nil,
},
{
/* N.B. Order of tests matters! */
port: 80,
ip: anyIPAddress,
want: tcpip.ErrPortInUse,
},
{
port: 22,
ip: anyIPAddress,
want: nil,
},
{
port: 22,
ip: fakeIPAddress,
want: tcpip.ErrPortInUse,
},
{
port: 0,
ip: fakeIPAddress,
want: nil,
},
{
port: 0,
ip: fakeIPAddress,
want: nil,
},
} {
gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port)
if err != test.want {
t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
}
if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) {
t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral)
}
}
// Release port 22 from any IP address, then try to reserve fake IP
// address on 22.
pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22)
if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil {
t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err)
}
}
func TestPickEphemeralPort(t *testing.T) {
pm := NewPortManager()
customErr := &tcpip.Error{}
for _, test := range []struct {
name string
f func(port uint16) (bool, *tcpip.Error)
wantErr *tcpip.Error
wantPort uint16
}{
{
name: "no-port-available",
f: func(port uint16) (bool, *tcpip.Error) {
return false, nil
},
wantErr: tcpip.ErrNoPortAvailable,
},
{
name: "port-tester-error",
f: func(port uint16) (bool, *tcpip.Error) {
return false, customErr
},
wantErr: customErr,
},
{
name: "only-port-16042-available",
f: func(port uint16) (bool, *tcpip.Error) {
if port == FirstEphemeral+42 {
return true, nil
}
return false, nil
},
wantPort: FirstEphemeral + 42,
},
{
name: "only-port-under-16000-available",
f: func(port uint16) (bool, *tcpip.Error) {
if port < FirstEphemeral {
return true, nil
}
return false, nil
},
wantErr: tcpip.ErrNoPortAvailable,
},
} {
t.Run(test.name, func(t *testing.T) {
if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
}
})
}
}

View File

@@ -0,0 +1,67 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package seqnum defines the types and methods for TCP sequence numbers such
// that they fit in 32-bit words and work properly when overflows occur.
package seqnum
// Value represents the value of a sequence number.
type Value uint32
// Size represents the size (length) of a sequence number window.
type Size uint32
// LessThan checks if v is before w, i.e., v < w.
func (v Value) LessThan(w Value) bool {
return int32(v-w) < 0
}
// LessThanEq returns true if v==w or v is before i.e., v < w.
func (v Value) LessThanEq(w Value) bool {
if v == w {
return true
}
return v.LessThan(w)
}
// InRange checks if v is in the range [a,b), i.e., a <= v < b.
func (v Value) InRange(a, b Value) bool {
return v-a < b-a
}
// InWindow checks if v is in the window that starts at 'first' and spans 'size'
// sequence numbers.
func (v Value) InWindow(first Value, size Size) bool {
return v.InRange(first, first.Add(size))
}
// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
func Overlap(a Value, b Size, x Value, y Size) bool {
return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
}
// Add calculates the sequence number following the [v, v+s) window.
func (v Value) Add(s Size) Value {
return v + Value(s)
}
// Size calculates the size of the window defined by [v, w).
func (v Value) Size(w Value) Size {
return Size(w - v)
}
// UpdateForward updates v such that it becomes v + s.
func (v *Value) UpdateForward(s Size) {
*v += Value(s)
}

View File

@@ -0,0 +1,308 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"log"
"sync"
"time"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
)
const linkAddrCacheSize = 512 // max cache entries
// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
//
// The entries are stored in a ring buffer, oldest entry replaced first.
//
// This struct is safe for concurrent use.
type linkAddrCache struct {
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
// resolutionTimeout is the amount of time to wait for a link request to
// resolve an address.
resolutionTimeout time.Duration
// resolutionAttempts is the number of times an address is attempted to be
// resolved before failing.
resolutionAttempts int
mu sync.Mutex
cache map[tcpip.FullAddress]*linkAddrEntry
next int // array index of next available entry
entries [linkAddrCacheSize]linkAddrEntry
}
// entryState controls the state of a single entry in the cache.
type entryState int
const (
// incomplete means that there is an outstanding request to resolve the
// address. This is the initial state.
incomplete entryState = iota
// ready means that the address has been resolved and can be used.
ready
// failed means that address resolution timed out and the address
// could not be resolved.
failed
// expired means that the cache entry has expired and the address must be
// resolved again.
expired
)
// String implements Stringer.
func (s entryState) String() string {
switch s {
case incomplete:
return "incomplete"
case ready:
return "ready"
case failed:
return "failed"
case expired:
return "expired"
default:
return fmt.Sprintf("unknown(%d)", s)
}
}
// A linkAddrEntry is an entry in the linkAddrCache.
// This struct is thread-compatible.
type linkAddrEntry struct {
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
expiration time.Time
s entryState
// wakers is a set of waiters for address resolution result. Anytime
// state transitions out of 'incomplete' these waiters are notified.
wakers map[*sleep.Waker]struct{}
done chan struct{}
}
func (e *linkAddrEntry) state() entryState {
if e.s != expired && time.Now().After(e.expiration) {
// Force the transition to ensure waiters are notified.
e.changeState(expired)
}
return e.s
}
func (e *linkAddrEntry) changeState(ns entryState) {
if e.s == ns {
return
}
// Validate state transition.
switch e.s {
case incomplete:
// All transitions are valid.
case ready, failed:
if ns != expired {
panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
}
case expired:
// Terminal state.
panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns))
default:
panic(fmt.Sprintf("invalid state: %s", e.s))
}
// Notify whoever is waiting on address resolution when transitioning
// out of 'incomplete'.
if e.s == incomplete {
for w := range e.wakers {
w.Assert()
}
e.wakers = nil
if e.done != nil {
close(e.done)
}
}
e.s = ns
}
func (e *linkAddrEntry) addWaker(w *sleep.Waker) {
e.wakers[w] = struct{}{}
}
func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
delete(e.wakers, w)
}
// add adds a k -> v mapping to the cache.
func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
log.Printf("add link cache: %v-%v", k, v)
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.cache[k]
if ok {
s := entry.state()
if s != expired && entry.linkAddr == v {
// Disregard repeated calls.
return
}
// Check if entry is waiting for address resolution.
if s == incomplete {
entry.linkAddr = v
} else {
// Otherwise create a new entry to replace it.
entry = c.makeAndAddEntry(k, v)
}
} else {
entry = c.makeAndAddEntry(k, v)
}
entry.changeState(ready)
}
// makeAndAddEntry is a helper function to create and add a new
// entry to the cache map and evict older entry as needed.
func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
// Take over the next entry.
entry := &c.entries[c.next]
if c.cache[entry.addr] == entry {
delete(c.cache, entry.addr)
}
// Mark the soon-to-be-replaced entry as expired, just in case there is
// someone waiting for address resolution on it.
entry.changeState(expired)
*entry = linkAddrEntry{
addr: k,
linkAddr: v,
expiration: time.Now().Add(c.ageLimit),
wakers: make(map[*sleep.Waker]struct{}),
done: make(chan struct{}),
}
c.cache[k] = entry
c.next = (c.next + 1) % len(c.entries)
return entry
}
// get reports any known link address for k.
func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
log.Printf("link addr get linkRes: %#v, addr: %+v", linkRes, k)
if linkRes != nil {
if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
return addr, nil, nil
}
}
c.mu.Lock()
defer c.mu.Unlock()
// 尝试从缓存中得到MAC地址
if entry, ok := c.cache[k]; ok {
switch s := entry.state(); s {
case expired:
case ready:
return entry.linkAddr, nil, nil
case failed:
return "", nil, tcpip.ErrNoLinkAddress
case incomplete:
// Address resolution is still in progress.
entry.addWaker(waker)
return "", entry.done, tcpip.ErrWouldBlock
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
if linkRes == nil {
return "", nil, tcpip.ErrNoLinkAddress
}
// Add 'incomplete' entry in the cache to mark that resolution is in progress.
e := c.makeAndAddEntry(k, "")
e.addWaker(waker)
go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done)
return "", e.done, tcpip.ErrWouldBlock
}
// removeWaker removes a waker previously added through get().
func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
c.mu.Lock()
defer c.mu.Unlock()
if entry, ok := c.cache[k]; ok {
entry.removeWaker(waker)
}
}
func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) {
for i := 0; ; i++ {
// Send link request, then wait for the timeout limit and check
// whether the request succeeded.
linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
select {
case <-time.After(c.resolutionTimeout):
if stop := c.checkLinkRequest(k, i); stop {
return
}
case <-done:
return
}
}
}
// checkLinkRequest checks whether previous attempt to resolve address has succeeded
// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
// can stop, false if another request should be sent.
func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.cache[k]
if !ok {
// Entry was evicted from the cache.
return true
}
switch s := entry.state(); s {
case ready, failed, expired:
// Entry was made ready by resolver or failed. Either way we're done.
return true
case incomplete:
if attempt+1 >= c.resolutionAttempts {
// Max number of retries reached, mark entry as failed.
entry.changeState(failed)
return true
}
// No response yet, need to send another ARP request.
return false
default:
panic(fmt.Sprintf("invalid cache entry state: %s", s))
}
}
func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
return &linkAddrCache{
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
}
}

View File

@@ -0,0 +1,266 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"fmt"
"sync"
"testing"
"time"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
)
type testaddr struct {
addr tcpip.FullAddress
linkAddr tcpip.LinkAddress
}
var testaddrs []testaddr
type testLinkAddressResolver struct {
cache *linkAddrCache
delay time.Duration
}
func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
go func() {
if r.delay > 0 {
time.Sleep(r.delay)
}
r.fakeRequest(addr)
}()
return nil
}
func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
for _, ta := range testaddrs {
if ta.addr.Addr == addr {
r.cache.add(ta.addr, ta.linkAddr)
break
}
}
}
func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == "broadcast" {
return "mac_broadcast", true
}
return "", false
}
func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return 1
}
func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
w := sleep.Waker{}
s := sleep.Sleeper{}
s.AddWaker(&w, 123)
defer s.Done()
for {
if got, _, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
return got, err
}
s.Fetch(true)
}
}
func init() {
for i := 0; i < 4*linkAddrCacheSize; i++ {
addr := fmt.Sprintf("Addr%06d", i)
testaddrs = append(testaddrs, testaddr{
addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
linkAddr: tcpip.LinkAddress("Link" + addr),
})
}
}
func TestCacheOverflow(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
for i := len(testaddrs) - 1; i >= 0; i-- {
e := testaddrs[i]
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
}
}
// Expect to find at least half of the most recent entries.
for i := 0; i < linkAddrCacheSize/2; i++ {
e := testaddrs[i]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
}
}
// The earliest entries should no longer be in the cache.
for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
e := testaddrs[i]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
}
}
}
func TestCacheConcurrent(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
var wg sync.WaitGroup
for r := 0; r < 16; r++ {
wg.Add(1)
go func() {
for _, e := range testaddrs {
c.add(e.addr, e.linkAddr)
c.get(e.addr, nil, "", nil, nil) // make work for gotsan
}
wg.Done()
}()
}
wg.Wait()
// All goroutines add in the same order and add more values than
// can fit in the cache, so our eviction strategy requires that
// the last entry be present and the first be missing.
e := testaddrs[len(testaddrs)-1]
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
e = testaddrs[0]
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
func TestCacheAgeLimit(t *testing.T) {
c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
e := testaddrs[0]
c.add(e.addr, e.linkAddr)
time.Sleep(50 * time.Millisecond)
if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
func TestCacheReplace(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
e := testaddrs[0]
l2 := e.linkAddr + "2"
c.add(e.addr, e.linkAddr)
got, _, err := c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
c.add(e.addr, l2)
got, _, err = c.get(e.addr, nil, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
if got != l2 {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
}
}
func TestCacheResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testaddrs {
got, err := getBlocking(c, ta.addr, linkRes)
if err != nil {
t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
}
if got != ta.linkAddr {
t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
}
}
// Check that after resolved, address stays in the cache and never returns WouldBlock.
for i := 0; i < 10; i++ {
e := testaddrs[len(testaddrs)-1]
got, _, err := c.get(e.addr, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
}
}
func TestCacheResolutionFailed(t *testing.T) {
c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
// First, sanity check that resolution is working...
e := testaddrs[0]
got, err := getBlocking(c, e.addr, linkRes)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
}
if got != e.linkAddr {
t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
}
e.addr.Addr += "2"
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 50 * time.Millisecond
expiration := resolverDelay / 2
c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testaddrs[0]
if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
}
}
// TestStaticResolution checks that static link addresses are resolved immediately and don't
// send resolution requests.
func TestStaticResolution(t *testing.T) {
c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
addr := tcpip.Address("broadcast")
want := tcpip.LinkAddress("mac_broadcast")
got, _, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
if err != nil {
t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
}
if got != want {
t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
}
}

668
netstack/tcpip/stack/nic.go Normal file
View File

@@ -0,0 +1,668 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"strings"
"sync"
"sync/atomic"
"tcpip/netstack/ilist"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
)
// NIC represents a "network interface card" to which the networking stack is
// attached.
// 代表一个网卡对象
type NIC struct {
stack *Stack
// 每个网卡的唯一标识号
id tcpip.NICID
// 网卡名,可有可无
name string
// 链路层端
linkEP LinkEndpoint
// 传输层的解复用
demux *transportDemuxer
mu sync.RWMutex
spoofing bool
promiscuous bool
primary map[tcpip.NetworkProtocolNumber]*ilist.List
// 网络层端的记录
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
// 子网的记录
subnets []tcpip.Subnet
}
// PrimaryEndpointBehavior is an enumeration of an endpoint's primacy behavior.
type PrimaryEndpointBehavior int
const (
// CanBePrimaryEndpoint indicates the endpoint can be used as a primary
// endpoint for new connections with no local address. This is the
// default when calling NIC.AddAddress.
CanBePrimaryEndpoint PrimaryEndpointBehavior = iota
// FirstPrimaryEndpoint indicates the endpoint should be the first
// primary endpoint considered. If there are multiple endpoints with
// this behavior, the most recently-added one will be first.
FirstPrimaryEndpoint
// NeverPrimaryEndpoint indicates the endpoint should never be a
// primary endpoint.
NeverPrimaryEndpoint
)
// 根据参数新建一个NIC
func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
return &NIC{
stack: stack,
id: id,
name: name,
linkEP: ep,
demux: newTransportDemuxer(stack),
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
}
}
// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
// to start delivering packets.
// attachLinkEndpoint添加当前的NIC到链路层设备激活该NIC
func (n *NIC) attachLinkEndpoint() {
n.linkEP.Attach(n)
}
// setPromiscuousMode enables or disables promiscuous mode.
// 设备网卡为混杂模式
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
n.promiscuous = enable
n.mu.Unlock()
}
// 判断网卡是否开启混杂模式
func (n *NIC) isPromiscuousMode() bool {
n.mu.RLock()
rv := n.promiscuous
n.mu.RUnlock()
return rv
}
// setSpoofing enables or disables address spoofing.
func (n *NIC) setSpoofing(enable bool) {
n.mu.Lock()
n.spoofing = enable
n.mu.Unlock()
}
func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
n.mu.RLock()
defer n.mu.RUnlock()
var r *referencedNetworkEndpoint
// Check for a primary endpoint.
if list, ok := n.primary[protocol]; ok {
for e := list.Front(); e != nil; e = e.Next() {
ref := e.(*referencedNetworkEndpoint)
if ref.holdsInsertRef && ref.tryIncRef() {
r = ref
break
}
}
}
// If no primary endpoints then check for other endpoints.
if r == nil {
for _, ref := range n.endpoints {
if ref.holdsInsertRef && ref.tryIncRef() {
r = ref
break
}
}
}
if r == nil {
return "", tcpip.Subnet{}, tcpip.ErrNoLinkAddress
}
address := r.ep.ID().LocalAddress
r.decRef()
// Find the least-constrained matching subnet for the address, if one
// exists, and return it.
var subnet tcpip.Subnet
for _, s := range n.subnets {
if s.Contains(address) && !subnet.Contains(s.ID()) {
subnet = s
}
}
return address, subnet, nil
}
// primaryEndpoint returns the primary endpoint of n for the given network
// protocol.
// 根据网络层协议号找到对应的网络层端
func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
n.mu.RLock()
defer n.mu.RUnlock()
list := n.primary[protocol]
if list == nil {
return nil
}
for e := list.Front(); e != nil; e = e.Next() {
r := e.(*referencedNetworkEndpoint)
// TODO: allow broadcast address when SO_BROADCAST is set.
switch r.ep.ID().LocalAddress {
case header.IPv4Broadcast, header.IPv4Any:
continue
}
if r.tryIncRef() {
return r
}
}
return nil
}
// findEndpoint finds the endpoint, if any, with the given address.
// 根据address参数查找对应的网络层端
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address,
peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
ref := n.endpoints[id]
if ref != nil && !ref.tryIncRef() {
ref = nil
}
spoofing := n.spoofing
n.mu.RUnlock()
if ref != nil || !spoofing {
return ref
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
ref = n.endpoints[id]
if ref == nil || !ref.tryIncRef() {
ref, _ = n.addAddressLocked(protocol, address, peb, true)
if ref != nil {
ref.holdsInsertRef = false
}
}
n.mu.Unlock()
return ref
}
// 在NIC上添加addr地址注册和初始化网络层协议
// 相当于给网卡添加ip地址
func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address,
peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
// 查看是否支持该协议,若不支持则返回错误
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
// Create the new network endpoint.
// 比如netProto为ipv4会调用ipv4.NewEndpoint新建一个网络层端
ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
if err != nil {
return nil, err
}
// 获取网络层端的id其实就是ip地址
id := *ep.ID()
if ref, ok := n.endpoints[id]; ok {
// 不是替换且该id否存在返回错误
if !replace {
return nil, tcpip.ErrDuplicateAddress
}
n.removeEndpointLocked(ref)
}
ref := &referencedNetworkEndpoint{
refs: 1,
ep: ep,
nic: n,
protocol: protocol,
holdsInsertRef: true,
}
// Set up cache if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
ref.linkCache = n.stack
}
}
// 注册该网络端
n.endpoints[id] = ref
l, ok := n.primary[protocol]
if !ok {
l = &ilist.List{}
n.primary[protocol] = l
}
switch peb {
case CanBePrimaryEndpoint:
l.PushBack(ref)
case FirstPrimaryEndpoint:
l.PushFront(ref)
}
return ref, nil
}
// AddAddress adds a new address to n, so that it starts accepting packets
// targeted at the given address (and network protocol).
// AddAddress向n添加一个新地址以便它开始接受针对给定地址和网络协议的数据包
// 最终会生成一个网络端注册到 endpoints
func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
return n.AddAddressWithOptions(protocol, addr, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
// whether the new endpoint can be primary or not.
// AddAddressWithOptions与AddAddress相同但允许您指定新端点是否可以是主端点。
func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address,
peb PrimaryEndpointBehavior) *tcpip.Error {
// Add the endpoint.
n.mu.Lock()
_, err := n.addAddressLocked(protocol, addr, peb, false)
n.mu.Unlock()
return err
}
// Addresses returns the addresses associated with this NIC.
// 获取网卡已分配的协议和对应的地址
func (n *NIC) Addresses() []tcpip.ProtocolAddress {
n.mu.RLock()
defer n.mu.RUnlock()
addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
for nid, ep := range n.endpoints {
addrs = append(addrs, tcpip.ProtocolAddress{
Protocol: ep.protocol,
Address: nid.LocalAddress,
})
}
return addrs
}
// AddSubnet adds a new subnet to n, so that it starts accepting packets
// targeted at the given address and network protocol.
// AddSubnet向n添加一个新子网以便它开始接受针对给定地址和网络协议的数据包。
func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
n.subnets = append(n.subnets, subnet)
n.mu.Unlock()
}
// RemoveSubnet removes the given subnet from n.
// 从n中删除一个子网
func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
tmp := n.subnets[:0]
for _, sub := range n.subnets {
if sub != subnet {
tmp = append(tmp, sub)
}
}
n.subnets = tmp
n.mu.Unlock()
}
// ContainsSubnet reports whether this NIC contains the given subnet.
// 判断 subnet 这个子网是否在该网卡下
func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
for _, s := range n.Subnets() {
if s == subnet {
return true
}
}
return false
}
// Subnets returns the Subnets associated with this NIC.
// 获取该网卡的所有子网
func (n *NIC) Subnets() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
for nid := range n.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
// This should never happen as the mask has been carefully crafted to
// match the address.
panic("Invalid endpoint subnet: " + err.Error())
}
sns = append(sns, sn)
}
return append(sns, n.subnets...)
}
// 删除一个网络端
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
// Nothing to do if the reference has already been replaced with a
// different one.
if n.endpoints[id] != r {
return
}
if r.holdsInsertRef {
panic("Reference count dropped to zero before being removed")
}
delete(n.endpoints, id)
wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
if wasInList {
n.primary[r.protocol].Remove(r)
}
r.ep.Close()
}
func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Lock()
n.removeEndpointLocked(r)
n.mu.Unlock()
}
// RemoveAddress removes an address from n.
func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
n.mu.Lock()
r := n.endpoints[NetworkEndpointID{addr}]
if r == nil || !r.holdsInsertRef {
n.mu.Unlock()
return tcpip.ErrBadLocalAddress
}
r.holdsInsertRef = false
n.mu.Unlock()
r.decRef()
return nil
}
// DeliverNetworkPacket finds the appropriate network protocol endpoint and
// hands the packet over for further processing. This function is called when
// the NIC receives a packet from the physical interface.
// Note that the ownership of the slice backing vv is retained by the caller.
// This rule applies only to the slice itself, not to the items of the slice;
// the ownership of the items is not retained by the caller.
// DeliverNetworkPacket 找到适当的网络协议端点并将数据包交给它进一步处理。
// 当NIC从物理接口接收数据包时将调用此函数。比如protocol是arp协议号 那么会找到arp.HandlePacket来处理数据报。
// protocol是ipv4协议号 那么会找到ipv4.HandlePacket来处理数据报。
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
if len(vv.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(vv.First())
// 根据网络协议和数据包的目的地址,找到网络端
// 然后将数据包分发给网络层
if ref := n.getRef(protocol, dst); ref != nil {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
r.RemoteLinkAddress = remoteLinkAddr
ref.ep.HandlePacket(&r, vv)
ref.decRef()
return
}
// This NIC doesn't care about the packet. Find a NIC that cares about the
// packet and forward it to the NIC.
//
// TODO: Should we be forwarding the packet even if promiscuous?
if n.stack.Forwarding() {
r, err := n.stack.FindRoute(0, "", dst, protocol)
if err != nil {
n.stack.stats.IP.InvalidAddressesReceived.Increment()
return
}
defer r.Release()
r.LocalLinkAddress = n.linkEP.LinkAddress()
r.RemoteLinkAddress = remoteLinkAddr
// Found a NIC.
n := r.ref.nic
n.mu.RLock()
ref, ok := n.endpoints[NetworkEndpointID{dst}]
n.mu.RUnlock()
if ok && ref.tryIncRef() {
ref.ep.HandlePacket(&r, vv)
ref.decRef()
} else {
// n doesn't have a destination endpoint.
// Send the packet out of n.
hdr := buffer.NewPrependableFromView(vv.First())
vv.RemoveFirst()
n.linkEP.WritePacket(&r, hdr, vv, protocol)
}
return
}
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
// 根据协议类型和目标地址找出关联的Endpoint
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
id := NetworkEndpointID{dst}
n.mu.RLock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.RUnlock()
return ref
}
promiscuous := n.promiscuous
// Check if the packet is for a subnet this NIC cares about.
if !promiscuous {
for _, sn := range n.subnets {
if sn.Contains(dst) {
promiscuous = true
break
}
}
}
n.mu.RUnlock()
if promiscuous {
// Try again with the lock in exclusive mode. If we still can't
// get the endpoint, create a new "temporary" one. It will only
// exist while there's a route through it.
n.mu.Lock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.Unlock()
return ref
}
ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
n.mu.Unlock()
if err == nil {
ref.holdsInsertRef = false
return ref
}
}
return nil
}
// DeliverTransportPacket delivers the packets to the appropriate transport
// protocol endpoint.
// DeliverTransportPacket 按照传输层协议号和传输层ID将数据包分发给相应的传输层端
// 比如 protocol=6表示为tcp协议将会给相应的tcp端分发消息。
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
// 先查找协议栈是否注册了该传输层协议
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
transProto := state.proto
// 如果报文长度比该协议最小报文长度还小,那么丢弃它
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
// 解析报文得到源端口和目的端口
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
// 调用分流器根据传输层协议和传输层id分发数据报文
if n.demux.deliverPacket(r, protocol, vv, id) {
return
}
if n.stack.demux.deliverPacket(r, protocol, vv, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
if state.defaultHandler(r, id, vv) {
return
}
}
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
// DeliverTransportControlPacket delivers control packets to the appropriate
// transport protocol endpoint.
func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber,
trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView) {
state, ok := n.stack.transportProtocols[trans]
if !ok {
return
}
transProto := state.proto
// ICMPv4 only guarantees that 8 bytes of the transport protocol will
// be present in the payload. We know that the ports are within the
// first 8 bytes for all known transport protocols.
if len(vv.First()) < 8 {
return
}
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
if err != nil {
return
}
id := TransportEndpointID{srcPort, local, dstPort, remote}
if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
return
}
if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
return
}
}
// ID returns the identifier of n.
func (n *NIC) ID() tcpip.NICID {
return n.id
}
type referencedNetworkEndpoint struct {
ilist.Entry
refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
// linkCache is set if link address resolution is enabled for this
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
// holdsInsertRef is protected by the NIC's mutex. It indicates whether
// the reference count is biased by 1 due to the insertion of the
// endpoint. It is reset to false when RemoveAddress is called on the
// NIC.
holdsInsertRef bool
}
// decRef decrements the ref count and cleans up the endpoint once it reaches
// zero.
func (r *referencedNetworkEndpoint) decRef() {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpoint(r)
}
}
// incRef increments the ref count. It must only be called when the caller is
// known to be holding a reference to the endpoint, otherwise tryIncRef should
// be used.
func (r *referencedNetworkEndpoint) incRef() {
atomic.AddInt32(&r.refs, 1)
}
// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
// not zero. That is, it will increment the count if the endpoint is still
// alive, and do nothing if it has already been clean up.
func (r *referencedNetworkEndpoint) tryIncRef() bool {
for {
v := atomic.LoadInt32(&r.refs)
if v == 0 {
return false
}
if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
return true
}
}
}

View File

@@ -0,0 +1,373 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"sync"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/waiter"
)
// NetworkEndpointID is the identifier of a network layer protocol endpoint.
// Currently the local address is sufficient because all supported protocols
// (i.e., IPv4 and IPv6) have different sizes for their addresses.
type NetworkEndpointID struct {
LocalAddress tcpip.Address
}
// TransportEndpointID is the identifier of a transport layer protocol endpoint.
//
// +stateify savable
type TransportEndpointID struct {
// LocalPort is the local port associated with the endpoint.
LocalPort uint16
// LocalAddress is the local [network layer] address associated with
// the endpoint.
LocalAddress tcpip.Address
// RemotePort is the remote port associated with the endpoint.
RemotePort uint16
// RemoteAddress it the remote [network layer] address associated with
// the endpoint.
RemoteAddress tcpip.Address
}
// ControlType is the type of network control message.
type ControlType int
// The following are the allowed values for ControlType values.
const (
ControlPacketTooBig ControlType = iota
ControlPortUnreachable
ControlUnknown
)
// TransportEndpoint is the interface that needs to be implemented by transport
// protocol (e.g., tcp, udp) endpoints that can handle packets.
type TransportEndpoint interface {
// HandlePacket is called by the stack when new packets arrive to
// this transport endpoint.
HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView)
// HandleControlPacket is called by the stack when new control (e.g.,
// ICMP) packets arrive to this transport endpoint.
HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
// TransportProtocol is the interface that needs to be implemented by transport
// protocols (e.g., tcp, udp) that want to be part of the networking stack.
type TransportProtocol interface {
// Number returns the transport protocol number.
Number() tcpip.TransportProtocolNumber
// NewEndpoint creates a new endpoint of the transport protocol.
NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
// MinimumPacketSize returns the minimum valid packet size of this
// transport protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
MinimumPacketSize() int
// ParsePorts returns the source and destination ports stored in a
// packet of this protocol.
ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
// HandleUnknownDestinationPacket handles packets targeted at this
// protocol but that don't match any existing endpoint. For example,
// it is targeted at a port that have no listeners.
//
// The return value indicates whether the packet was well-formed (for
// stats purposes only).
HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
SetOption(option interface{}) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
}
// TransportDispatcher contains the methods used by the network stack to deliver
// packets to the appropriate transport endpoint after it has been handled by
// the network layer.
type TransportDispatcher interface {
// DeliverTransportPacket delivers packets to the appropriate
// transport protocol endpoint.
DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView)
// DeliverTransportControlPacket delivers control packets to the
// appropriate transport protocol endpoint.
DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView)
}
// NetworkEndpoint is the interface that needs to be implemented by endpoints
// of network layer protocols (e.g., ipv4, ipv6).
// NetworkEndpoint是需要由网络层协议例如ipv4ipv6的端点实现的接口。
type NetworkEndpoint interface {
// DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
// MTU is the maximum transmission unit for this endpoint. This is
// generally calculated as the MTU of the underlying data link endpoint
// minus the network endpoint max header length.
MTU() uint32
// Capabilities returns the set of capabilities supported by the
// underlying link-layer endpoint.
Capabilities() LinkEndpointCapabilities
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
// building.
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
// protocol.
WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
HandlePacket(r *Route, vv buffer.VectorisedView)
// Close is called when the endpoint is reomved from a stack.
Close()
}
// NetworkProtocol is the interface that needs to be implemented by network
// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
type NetworkProtocol interface {
// Number returns the network protocol number.
Number() tcpip.NetworkProtocolNumber
// MinimumPacketSize returns the minimum valid packet size of this
// network protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
MinimumPacketSize() int
// ParsePorts returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// NewEndpoint creates a new endpoint of this protocol.
NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
SetOption(option interface{}) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
}
// NetworkDispatcher contains the methods used by the network stack to deliver
// packets to the appropriate network endpoint after it has been handled by
// the data link layer.
type NetworkDispatcher interface {
// DeliverNetworkPacket finds the appropriate network protocol
// endpoint and hands the packet over for further processing.
DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView)
}
// LinkEndpointCapabilities is the type associated with the capabilities
// supported by a link-layer endpoint. It is a set of bitfields.
type LinkEndpointCapabilities uint
// The following are the supported link endpoint capabilities.
const (
CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota
CapabilityResolutionRequired
CapabilitySaveRestore
CapabilityDisconnectOk
CapabilityLoopback
)
// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
// ethernet, loopback, raw) and used by network layer protocols to send packets
// out through the implementer's data link endpoint.
// LinkEndpoint是由数据链路层协议例如以太网环回原始实现的接口并由网络层协议用于通过实施者的数据链路端点发送数据包。
type LinkEndpoint interface {
// MTU is the maximum transmission unit for this endpoint. This is
// usually dictated by the backing physical network; when such a
// physical network doesn't exist, the limit is generally 64k, which
// includes the maximum size of an IP packet.
// MTU是此端点的最大传输单位。这通常由支持物理网络决定;
// 当这种物理网络不存在时限制通常为64k其中包括IP数据包的最大大小。
MTU() uint32
// Capabilities returns the set of capabilities supported by the
// endpoint.
// Capabilities返回链路层端点支持的功能集。
Capabilities() LinkEndpointCapabilities
// MaxHeaderLength returns the maximum size the data link (and
// lower level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
// building.
// MaxHeaderLength 返回数据链接(和较低级别的图层组合)标头可以具有的最大大小。
// 较高级别使用此信息来保留它们正在构建的数据包前面预留空间。
MaxHeaderLength() uint16
// LinkAddress returns the link address (typically a MAC) of the
// link endpoint.
// 本地链路层地址
LinkAddress() tcpip.LinkAddress
// WritePacket writes a packet with the given protocol through the given
// route.
//
// To participate in transparent bridging, a LinkEndpoint implementation
// should call eth.Encode with header.EthernetFields.SrcAddr set to
// r.LocalLinkAddress if it is provided.
// WritePacket通过给定的路由写入具有给定协议的数据包。
// 要参与透明桥接LinkEndpoint实现应调用eth.Encode并将header.EthernetFields.SrcAddr设置为r.LocalLinkAddress如果已提供
WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
// Attach attaches the data link layer endpoint to the network-layer
// dispatcher of the stack.
// Attach 将数据链路层端点附加到协议栈的网络层调度程序。
Attach(dispatcher NetworkDispatcher)
// IsAttached returns whether a NetworkDispatcher is attached to the
// endpoint.
// 是否已经添加了网络层调度器
IsAttached() bool
}
// A LinkAddressResolver is an extension to a NetworkProtocol that
// can resolve link addresses.
type LinkAddressResolver interface {
// LinkAddressRequest sends a request for the LinkAddress of addr.
// The request is sent on linkEP with localAddr as the source.
//
// A valid response will cause the discovery protocol's network
// endpoint to call AddLinkAddress.
LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
// ResolveStaticAddress attempts to resolve address without sending
// requests. It either resolves the name immediately or returns the
// empty LinkAddress.
//
// It can be used to resolve broadcast addresses for example.
ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
// LinkAddressProtocol returns the network protocol of the
// addresses this this resolver can resolve.
LinkAddressProtocol() tcpip.NetworkProtocolNumber
}
// A LinkAddressCache caches link addresses.
type LinkAddressCache interface {
// CheckLocalAddress determines if the given local address exists, and if it
// does not exist.
CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
// AddLinkAddress adds a link address to the cache.
AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
// GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
// If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
// registered with the network protocol, the cache attempts to resolve the address
// and returns ErrWouldBlock. Waker is notified when address resolution is
// complete (success or not).
//
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error)
// RemoveWaker removes a waker that has been added in GetLinkAddress().
RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
}
// TransportProtocolFactory functions are used by the stack to instantiate
// transport protocols.
type TransportProtocolFactory func() TransportProtocol
// NetworkProtocolFactory provides methods to be used by the stack to
// instantiate network protocols.
type NetworkProtocolFactory func() NetworkProtocol
var (
// 传输层协议的注册储存结构
transportProtocols = make(map[string]TransportProtocolFactory)
// 网络层协议的注册存储结构
networkProtocols = make(map[string]NetworkProtocolFactory)
linkEPMu sync.RWMutex
nextLinkEndpointID tcpip.LinkEndpointID = 1
linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
)
// RegisterTransportProtocolFactory registers a new transport protocol factory
// with the stack so that it becomes available to users of the stack. This
// function is intended to be called by init() functions of the protocols.
// RegisterTransportProtocolFactory 在协议栈中注册一个新的传输协议工厂,
// 以便协议栈可以使用它。此函数在由协议的init函数中使用。
func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
transportProtocols[name] = p
}
// RegisterNetworkProtocolFactory registers a new network protocol factory with
// the stack so that it becomes available to users of the stack. This function
// is intended to be called by init() functions of the protocols.
// RegisterNetworkProtocolFactory 在协议栈中注册一个新的网络协议工厂,
// 以便协议栈可以使用它。此函数在由协议的init函数中使用。
func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
networkProtocols[name] = p
}
// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an
// ID that can be used to refer to it.
// 注册一个链路层设备
func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
linkEPMu.Lock()
defer linkEPMu.Unlock()
v := nextLinkEndpointID
nextLinkEndpointID++
linkEndpoints[v] = linkEP
return v
}
// FindLinkEndpoint finds the link endpoint associated with the given ID.
// 根据ID找到网卡设备
func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
linkEPMu.RLock()
defer linkEPMu.RUnlock()
return linkEndpoints[id]
}

View File

@@ -0,0 +1,186 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
)
// Route represents a route through the networking stack to a given destination.
// 贯穿整个协议栈的路由,也就是在链路层和网络层都可以路由
// 如果目标地址是链路层地址,那么在链路层路由,
// 如果目标地址是网络层地址,那么在网络层路由。
type Route struct {
// 远端网络层地址ipv4或者ipv6地址
RemoteAddress tcpip.Address
// RemoteLinkAddress is the link-layer (MAC) address of the
// final destination of the route.
// 远端网卡MAC地址
RemoteLinkAddress tcpip.LinkAddress
// LocalAddress is the local address where the route starts.
// 本地网络层地址ipv4或者ipv6地址
LocalAddress tcpip.Address
// LocalLinkAddress is the link-layer (MAC) address of the
// where the route starts.
// 本地网卡MAC地址
LocalLinkAddress tcpip.LinkAddress
// NextHop is the next node in the path to the destination.
// 下一跳网络层地址
NextHop tcpip.Address
// NetProto is the network-layer protocol.
// 网络层协议号
NetProto tcpip.NetworkProtocolNumber
// ref a reference to the network endpoint through which the route
// starts.
// 相关的网络终端
ref *referencedNetworkEndpoint
}
// makeRoute initializes a new route. It takes ownership of the provided
// reference to a network endpoint.
// 根据参数新建一个路由,并关联一个网络层端
func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address,
localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
return Route{
NetProto: netProto,
LocalAddress: localAddr,
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
}
}
// NICID returns the id of the NIC from which this route originates.
func (r *Route) NICID() tcpip.NICID {
return r.ref.ep.NICID()
}
// MaxHeaderLength forwards the call to the network endpoint's implementation.
func (r *Route) MaxHeaderLength() uint16 {
return r.ref.ep.MaxHeaderLength()
}
// Stats returns a mutable copy of current stats.
func (r *Route) Stats() tcpip.Stats {
return r.ref.nic.stack.Stats()
}
// PseudoHeaderChecksum forwards the call to the network endpoint's
// implementation.
// udp或tcp伪首部校验和的计算
func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 {
return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
}
// Capabilities returns the link-layer capabilities of the route.
func (r *Route) Capabilities() LinkEndpointCapabilities {
return r.ref.ep.Capabilities()
}
// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
// notified when address resolution is complete (success or not).
//
// If address resolution is required, ErrNoLinkAddress and a notification channel is
// returned for the top level caller to block. Channel is closed once address resolution
// is complete (success or not).
//
// Resolve 如有必要解决尝试解析链接地址的问题。如果地址解析需要阻塞则返回ErrWouldBlock
// 例如等待ARP回复。地址解析完成成功与否时通知Waker。
// 如果需要地址解析则返回ErrNoLinkAddress和通知通道以阻止顶级调用者。
// 地址解析完成后,通道关闭(不管成功与否)。
func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) {
if !r.IsResolutionRequired() {
// Nothing to do if there is no cache (which does the resolution on cache miss) or
// link address is already known.
return nil, nil
}
nextAddr := r.NextHop
if nextAddr == "" {
// Local link address is already known.
if r.RemoteAddress == r.LocalAddress {
r.RemoteLinkAddress = r.LocalLinkAddress
return nil, nil
}
nextAddr = r.RemoteAddress
}
// 调用地址解析协议来解析IP地址
linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
if err != nil {
return ch, err
}
r.RemoteLinkAddress = linkAddr
return nil, nil
}
// RemoveWaker removes a waker that has been added in Resolve().
func (r *Route) RemoveWaker(waker *sleep.Waker) {
nextAddr := r.NextHop
if nextAddr == "" {
nextAddr = r.RemoteAddress
}
r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
}
// IsResolutionRequired returns true if Resolve() must be called to resolve
// the link address before the this route can be written to.
func (r *Route) IsResolutionRequired() bool {
return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
}
// WritePacket writes the packet through the given route.
func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView,
protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl)
if err == tcpip.ErrNoRoute {
r.Stats().IP.OutgoingPacketErrors.Increment()
}
return err
}
// DefaultTTL returns the default TTL of the underlying network endpoint.
func (r *Route) DefaultTTL() uint8 {
return r.ref.ep.DefaultTTL()
}
// MTU returns the MTU of the underlying network endpoint.
func (r *Route) MTU() uint32 {
return r.ref.ep.MTU()
}
// Release frees all resources associated with the route.
func (r *Route) Release() {
if r.ref != nil {
r.ref.decRef()
r.ref = nil
}
}
// Clone Clone a route such that the original one can be released and the new
// one will remain valid.
func (r *Route) Clone() Route {
r.ref.incRef()
return *r
}

View File

@@ -0,0 +1,980 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package stack provides the glue between networking protocols and the
// consumers of the networking stack.
//
// For consumers, the only function of interest is New(), everything else is
// provided by the tcpip/public package.
//
// For protocol implementers, RegisterTransportProtocolFactory() and
// RegisterNetworkProtocolFactory() are used to register protocol factories with
// the stack, which will then be used to instantiate protocol objects when
// consumers interact with the stack.
package stack
import (
"sync"
"time"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/ports"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/waiter"
)
const (
// ageLimit is set to the same cache stale time used in Linux.
ageLimit = 1 * time.Minute
// resolutionTimeout is set to the same ARP timeout used in Linux.
resolutionTimeout = 1 * time.Second
// resolutionAttempts is set to the same ARP retries used in Linux.
resolutionAttempts = 3
)
type transportProtocolState struct {
proto TransportProtocol
defaultHandler func(*Route, TransportEndpointID, buffer.VectorisedView) bool
}
// TCPProbeFunc is the expected function type for a TCP probe function to be
// passed to stack.AddTCPProbe.
type TCPProbeFunc func(s TCPEndpointState)
// TCPCubicState is used to hold a copy of the internal cubic state when the
// TCPProbeFunc is invoked.
type TCPCubicState struct {
WLastMax float64
WMax float64
T time.Time
TimeSinceLastCongestion time.Duration
C float64
K float64
Beta float64
WC float64
WEst float64
}
// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
type TCPEndpointID struct {
// LocalPort is the local port associated with the endpoint.
LocalPort uint16
// LocalAddress is the local [network layer] address associated with
// the endpoint.
LocalAddress tcpip.Address
// RemotePort is the remote port associated with the endpoint.
RemotePort uint16
// RemoteAddress it the remote [network layer] address associated with
// the endpoint.
RemoteAddress tcpip.Address
}
// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
// TCP endpoint.
type TCPFastRecoveryState struct {
// Active if true indicates the endpoint is in fast recovery.
Active bool
// First is the first unacknowledged sequence number being recovered.
First seqnum.Value
// Last is the 'recover' sequence number that indicates the point at
// which we should exit recovery barring any timeouts etc.
Last seqnum.Value
// MaxCwnd is the maximum value we are permitted to grow the congestion
// window during recovery. This is set at the time we enter recovery.
MaxCwnd int
}
// TCPReceiverState holds a copy of the internal state of the receiver for
// a given TCP endpoint.
type TCPReceiverState struct {
// RcvNxt is the TCP variable RCV.NXT.
RcvNxt seqnum.Value
// RcvAcc is the TCP variable RCV.ACC.
RcvAcc seqnum.Value
// RcvWndScale is the window scaling to use for inbound segments.
RcvWndScale uint8
// PendingBufUsed is the number of bytes pending in the receive
// queue.
PendingBufUsed seqnum.Size
// PendingBufSize is the size of the socket receive buffer.
PendingBufSize seqnum.Size
}
// TCPSenderState holds a copy of the internal state of the sender for
// a given TCP Endpoint.
type TCPSenderState struct {
// LastSendTime is the time at which we sent the last segment.
LastSendTime time.Time
// DupAckCount is the number of Duplicate ACK's received.
DupAckCount int
// SndCwnd is the size of the sending congestion window in packets.
SndCwnd int
// Ssthresh is the slow start threshold in packets.
Ssthresh int
// SndCAAckCount is the number of packets consumed in congestion
// avoidance mode.
SndCAAckCount int
// Outstanding is the number of packets in flight.
Outstanding int
// SndWnd is the send window size in bytes.
SndWnd seqnum.Size
// SndUna is the next unacknowledged sequence number.
SndUna seqnum.Value
// SndNxt is the sequence number of the next segment to be sent.
SndNxt seqnum.Value
// RTTMeasureSeqNum is the sequence number being used for the latest RTT
// measurement.
RTTMeasureSeqNum seqnum.Value
// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
RTTMeasureTime time.Time
// Closed indicates that the caller has closed the endpoint for sending.
Closed bool
// SRTT is the smoothed round-trip time as defined in section 2 of
// RFC 6298.
SRTT time.Duration
// RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
RTO time.Duration
// RTTVar is the round-trip time variation as defined in section 2 of
// RFC 6298.
RTTVar time.Duration
// SRTTInited if true indicates take a valid RTT measurement has been
// completed.
SRTTInited bool
// MaxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
MaxPayloadSize int
// SndWndScale is the number of bits to shift left when reading the send
// window size from a segment.
SndWndScale uint8
// MaxSentAck is the highest acknowledgement number sent till now.
MaxSentAck seqnum.Value
// FastRecovery holds the fast recovery state for the endpoint.
FastRecovery TCPFastRecoveryState
// Cubic holds the state related to CUBIC congestion control.
Cubic TCPCubicState
}
// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
type TCPSACKInfo struct {
// Blocks is the list of SACK block currently received by the
// TCP endpoint.
Blocks []header.SACKBlock
}
// TCPEndpointState is a copy of the internal state of a TCP endpoint.
type TCPEndpointState struct {
// ID is a copy of the TransportEndpointID for the endpoint.
ID TCPEndpointID
// SegTime denotes the absolute time when this segment was received.
SegTime time.Time
// RcvBufSize is the size of the receive socket buffer for the endpoint.
RcvBufSize int
// RcvBufUsed is the amount of bytes actually held in the receive socket
// buffer for the endpoint.
RcvBufUsed int
// RcvClosed if true, indicates the endpoint has been closed for reading.
RcvClosed bool
// SendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per
// RFC7323#section-1.1.
SendTSOk bool
// RecentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint.
RecentTS uint32
// TSOffset is a randomized offset added to the value of the TSVal field
// in the timestamp option.
TSOffset uint32
// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK.
SACKPermitted bool
// SACK holds TCP SACK related information for this endpoint.
SACK TCPSACKInfo
// SndBufSize is the size of the socket send buffer.
SndBufSize int
// SndBufUsed is the number of bytes held in the socket send buffer.
SndBufUsed int
// SndClosed indicates that the endpoint has been closed for sends.
SndClosed bool
// SndBufInQueue is the number of bytes in the send queue.
SndBufInQueue seqnum.Size
// PacketTooBigCount is used to notify the main protocol routine how
// many times a "packet too big" control packet is received.
PacketTooBigCount int
// SndMTU is the smallest MTU seen in the control packets received.
SndMTU int
// Receiver holds variables related to the TCP receiver for the endpoint.
Receiver TCPReceiverState
// Sender holds state related to the TCP Sender for the endpoint.
Sender TCPSenderState
}
// Stack is a networking stack, with all supported protocols, NICs, and route
// table.
type Stack struct {
transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
demux *transportDemuxer
stats tcpip.Stats
linkAddrCache *linkAddrCache
mu sync.RWMutex
nics map[tcpip.NICID]*NIC
forwarding bool
// route is the route table passed in by the user via SetRouteTable(),
// it is used by FindRoute() to build a route for a specific
// destination.
routeTable []tcpip.Route
*ports.PortManager
// If not nil, then any new endpoints will have this probe function
// invoked everytime they receive a TCP segment.
tcpProbeFunc TCPProbeFunc
// clock is used to generate user-visible times.
clock tcpip.Clock
}
// Options contains optional Stack configuration.
type Options struct {
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
Clock tcpip.Clock
// Stats are optional statistic counters.
Stats tcpip.Stats
}
// New allocates a new networking stack with only the requested networking and
// transport protocols configured with default options.
//
// Protocol options can be changed by calling the
// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
// stack. Please refer to individual protocol implementations as to what options
// are supported.
// 新建一个协议栈对象
func New(network []string, transport []string, opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
}
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
}
// Add specified network protocols.
// 添加指定的网络层协议端如IPV4
for _, name := range network {
netProtoFactory, ok := networkProtocols[name]
if !ok {
continue
}
netProto := netProtoFactory()
s.networkProtocols[netProto.Number()] = netProto
// 判断该协议是否支持链路层地址解析协议接口,如果支持添加到 s.linkAddrResolvers 中,
// 如ARP协议会添加 IPV4-ARP 的对应关系
// 后面需要地址解析协议的时候会更改网络层协议号来找到相应的地址解析协议
if r, ok := netProto.(LinkAddressResolver); ok {
s.linkAddrResolvers[r.LinkAddressProtocol()] = r
}
}
// Add specified transport protocols.
// 添加传输层协议
for _, name := range transport {
transProtoFactory, ok := transportProtocols[name]
if !ok {
continue
}
transProto := transProtoFactory()
s.transportProtocols[transProto.Number()] = &transportProtocolState{
proto: transProto,
}
}
// Create the global transport demuxer.
// 新建协议栈全局传输层分流器
s.demux = newTransportDemuxer(s)
return s
}
// SetNetworkProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
}
return netProto.SetOption(option)
}
// NetworkProtocolOption allows retrieving individual protocol level option
// values. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation.
// e.g.
// var v ipv4.MyOption
// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
// if err != nil {
// ...
// }
func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
netProto, ok := s.networkProtocols[network]
if !ok {
return tcpip.ErrUnknownProtocol
}
return netProto.Option(option)
}
// SetTransportProtocolOption allows configuring individual protocol level
// options. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation or the provided value
// is incorrect.
func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
}
return transProtoState.proto.SetOption(option)
}
// TransportProtocolOption allows retrieving individual protocol level option
// values. This method returns an error if the protocol is not supported or
// option is not supported by the protocol implementation.
// var v tcp.SACKEnabled
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
// ...
// }
func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
transProtoState, ok := s.transportProtocols[transport]
if !ok {
return tcpip.ErrUnknownProtocol
}
return transProtoState.proto.Option(option)
}
// SetTransportProtocolHandler sets the per-stack default handler for the given
// protocol.
//
// It must be called only during initialization of the stack. Changing it as the
// stack is operating is not supported.
func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
state := s.transportProtocols[p]
if state != nil {
state.defaultHandler = h
}
}
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
func (s *Stack) NowNanoseconds() int64 {
return s.clock.NowNanoseconds()
}
// Stats returns a mutable copy of the current stats.
//
// This is not generally exported via the public interface, but is available
// internally.
func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
// SetForwarding enables or disables the packet forwarding between NICs.
func (s *Stack) SetForwarding(enable bool) {
// TODO: Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.Lock()
s.forwarding = enable
s.mu.Unlock()
}
// Forwarding returns if the packet forwarding between NICs is enabled.
func (s *Stack) Forwarding() bool {
// TODO: Expose via /proc/sys/net/ipv4/ip_forward.
s.mu.RLock()
defer s.mu.RUnlock()
return s.forwarding
}
// SetRouteTable assigns the route table to be used by this stack. It
// specifies which NIC to use for given destination address ranges.
func (s *Stack) SetRouteTable(table []tcpip.Route) {
s.mu.Lock()
defer s.mu.Unlock()
s.routeTable = table
}
// GetRouteTable returns the route table which is currently in use.
func (s *Stack) GetRouteTable() []tcpip.Route {
s.mu.Lock()
defer s.mu.Unlock()
return append([]tcpip.Route(nil), s.routeTable...)
}
// NewEndpoint creates a new transport layer endpoint of the given protocol.
func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
t, ok := s.transportProtocols[transport]
if !ok {
return nil, tcpip.ErrUnknownProtocol
}
return t.proto.NewEndpoint(s, network, waiterQueue)
}
// createNIC creates a NIC with the provided id and link-layer endpoint, and
// optionally enable it.
// 新建一个网卡对象,并且激活它,激活的意思就是准备好从网卡中读取和写入数据。
func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
ep := FindLinkEndpoint(linkEP)
if ep == nil {
return tcpip.ErrBadLinkEndpoint
}
s.mu.Lock()
defer s.mu.Unlock()
// Make sure id is unique.
if _, ok := s.nics[id]; ok {
return tcpip.ErrDuplicateNICID
}
n := newNIC(s, id, name, ep)
s.nics[id] = n
if enabled {
n.attachLinkEndpoint()
}
return nil
}
// CreateNIC creates a NIC with the provided id and link-layer endpoint.
// 根据nic id和linkEP id来创建和注册一个网卡对象
func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, "", linkEP, true)
}
// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
// and a human-readable name.
func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, name, linkEP, true)
}
// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
// but leave it disable. Stack.EnableNIC must be called before the link-layer
// endpoint starts delivering packets to it.
func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, "", linkEP, false)
}
// CreateDisabledNamedNIC is a combination of CreateNamedNIC and
// CreateDisabledNIC.
func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
return s.createNIC(id, name, linkEP, false)
}
// EnableNIC enables the given NIC so that the link-layer endpoint can start
// delivering packets to it.
func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[id]
if nic == nil {
return tcpip.ErrUnknownNICID
}
nic.attachLinkEndpoint()
return nil
}
// NICSubnets returns a map of NICIDs to their associated subnets.
func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
s.mu.RLock()
defer s.mu.RUnlock()
nics := map[tcpip.NICID][]tcpip.Subnet{}
for id, nic := range s.nics {
nics[id] = append(nics[id], nic.Subnets()...)
}
return nics
}
// NICInfo captures the name and addresses assigned to a NIC.
type NICInfo struct {
Name string
LinkAddress tcpip.LinkAddress
ProtocolAddresses []tcpip.ProtocolAddress
// Flags indicate the state of the NIC.
Flags NICStateFlags
// MTU is the maximum transmission unit.
MTU uint32
}
// NICInfo returns a map of NICIDs to their associated information.
func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
s.mu.RLock()
defer s.mu.RUnlock()
nics := make(map[tcpip.NICID]NICInfo)
for id, nic := range s.nics {
flags := NICStateFlags{
Up: true, // Netstack interfaces are always up.
Running: nic.linkEP.IsAttached(),
Promiscuous: nic.isPromiscuousMode(),
Loopback: nic.linkEP.Capabilities()&CapabilityLoopback != 0,
}
nics[id] = NICInfo{
Name: nic.name,
LinkAddress: nic.linkEP.LinkAddress(),
ProtocolAddresses: nic.Addresses(),
Flags: flags,
MTU: nic.linkEP.MTU(),
}
}
return nics
}
// NICStateFlags holds information about the state of an NIC.
type NICStateFlags struct {
// Up indicates whether the interface is running.
Up bool
// Running indicates whether resources are allocated.
Running bool
// Promiscuous indicates whether the interface is in promiscuous mode.
Promiscuous bool
// Loopback indicates whether the interface is a loopback.
Loopback bool
}
// AddAddress adds a new network-layer address to the specified NIC.
func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
return s.AddAddressWithOptions(id, protocol, addr, CanBePrimaryEndpoint)
}
// AddAddressWithOptions is the same as AddAddress, but allows you to specify
// whether the new endpoint can be primary or not.
func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, peb PrimaryEndpointBehavior) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[id]
if nic == nil {
return tcpip.ErrUnknownNICID
}
return nic.AddAddressWithOptions(protocol, addr, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
nic.AddSubnet(protocol, subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
// RemoveSubnet removes the subnet range from the specified NIC.
func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
nic.RemoveSubnet(subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
// ContainsSubnet reports whether the specified NIC contains the specified
// subnet.
func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
return nic.ContainsSubnet(subnet), nil
}
return false, tcpip.ErrUnknownNICID
}
// RemoveAddress removes an existing network-layer address from the specified
// NIC.
func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
return nic.RemoveAddress(addr)
}
return tcpip.ErrUnknownNICID
}
// GetMainNICAddress returns the first primary address (and the subnet that
// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's
// address if no primary addresses exist. Returns an error if the NIC doesn't
// exist or has no endpoints.
func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.Address, tcpip.Subnet, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
return nic.getMainNICAddress(protocol)
}
return "", tcpip.Subnet{}, tcpip.ErrUnknownNICID
}
// FindRoute creates a route to the given destination address, leaving through
// the given nic and local address (if provided).
// 路由查找实现比如当tcp建立连接时会用该函数得到路由信息
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
// 遍历路由表
for i := range s.routeTable {
if (id != 0 && id != s.routeTable[i].NIC) || (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
continue
}
nic := s.nics[s.routeTable[i].NIC]
if nic == nil {
continue
}
var ref *referencedNetworkEndpoint
if len(localAddr) != 0 {
ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint)
} else {
ref = nic.primaryEndpoint(netProto)
}
if ref == nil {
continue
}
if len(remoteAddr) == 0 {
// If no remote address was provided, then the route
// provided will refer to the link local address.
remoteAddr = ref.ep.ID().LocalAddress
}
r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref)
r.NextHop = s.routeTable[i].Gateway
return r, nil
}
return Route{}, tcpip.ErrNoRoute
}
// CheckNetworkProtocol checks if a given network protocol is enabled in the
// stack.
func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
_, ok := s.networkProtocols[protocol]
return ok
}
// CheckLocalAddress determines if the given local address exists, and if it
// does, returns the id of the NIC it's bound to. Returns 0 if the address
// does not exist.
func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
// If a NIC is specified, we try to find the address there only.
if nicid != 0 {
nic := s.nics[nicid]
if nic == nil {
return 0
}
ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref == nil {
return 0
}
ref.decRef()
return nic.id
}
// Go through all the NICs.
for _, nic := range s.nics {
ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref != nil {
ref.decRef()
return nic.id
}
}
return 0
}
// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[nicID]
if nic == nil {
return tcpip.ErrUnknownNICID
}
nic.setPromiscuousMode(enable)
return nil
}
// SetSpoofing enables or disables address spoofing in the given NIC, allowing
// endpoints to bind to any address in the NIC.
func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[nicID]
if nic == nil {
return tcpip.ErrUnknownNICID
}
nic.setSpoofing(enable)
return nil
}
// AddLinkAddress adds a link address to the stack link cache.
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
// TODO: provide a way for a transport endpoint to receive a signal
// that AddLinkAddress for a particular address has been called.
}
// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
// 获取链路层地址的方法
func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
s.mu.RLock()
// 获取网卡对象
nic := s.nics[nicid]
if nic == nil {
s.mu.RUnlock()
return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
// 根据网络层协议号找到对应的地址解析协议,
// 如IPV4协议会得到ARP协议
linkRes := s.linkAddrResolvers[protocol]
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
}
// RemoveWaker implements LinkAddressCache.RemoveWaker.
func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic := s.nics[nicid]; nic == nil {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.removeWaker(fullAddr, waker)
}
}
// RegisterTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
// RegisterTransportEndpoint 协议栈或者NIC的分流器注册给定传输层端点。
// 收到的与提供的id匹配的数据包将被传送到给定的端点;指定nic是可选的但特定于nic的ID优先于全局ID。
// 最终调用 demuxer.registerEndpoint 函数来实现注册。
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
if nicID == 0 {
return s.demux.registerEndpoint(netProtos, protocol, id, ep)
}
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[nicID]
if nic == nil {
return tcpip.ErrUnknownNICID
}
return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
if nicID == 0 {
s.demux.unregisterEndpoint(netProtos, protocol, id)
return
}
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[nicID]
if nic != nil {
nic.demux.unregisterEndpoint(netProtos, protocol, id)
}
}
// NetworkProtocolInstance returns the protocol instance in the stack for the
// specified network protocol. This method is public for protocol implementers
// and tests to use.
func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
if p, ok := s.networkProtocols[num]; ok {
return p
}
return nil
}
// TransportProtocolInstance returns the protocol instance in the stack for the
// specified transport protocol. This method is public for protocol implementers
// and tests to use.
func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
if pState, ok := s.transportProtocols[num]; ok {
return pState.proto
}
return nil
}
// AddTCPProbe installs a probe function that will be invoked on every segment
// received by a given TCP endpoint. The probe function is passed a copy of the
// TCP endpoint state.
//
// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
// created prior to this call will not call the probe function.
//
// Further, installing two different probes back to back can result in some
// endpoints calling the first one and some the second one. There is no
// guarantee provided on which probe will be invoked. Ideally this should only
// be called once per stack.
func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
s.mu.Lock()
s.tcpProbeFunc = probe
s.mu.Unlock()
}
// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
// otherwise.
func (s *Stack) GetTCPProbe() TCPProbeFunc {
s.mu.Lock()
p := s.tcpProbeFunc
s.mu.Unlock()
return p
}
// RemoveTCPProbe removes an installed TCP probe.
//
// NOTE: This only ensures that endpoints created after this call do not
// have a probe attached. Endpoints already created will continue to invoke
// TCP probe.
func (s *Stack) RemoveTCPProbe() {
s.mu.Lock()
s.tcpProbeFunc = nil
s.mu.Unlock()
}
// JoinGroup joins the given multicast group on the given NIC.
func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
// TODO: notify network of subscription via igmp protocol.
return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint)
}
// LeaveGroup leaves the given multicast group on the given NIC.
func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error {
return s.RemoveAddress(nicID, multicastAddr)
}

View File

@@ -0,0 +1,849 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package stack_test contains tests for the stack. It is in its own package so
// that the tests can also validate that all definitions needed to implement
// transport and network protocols are properly exported by the stack package.
package stack_test
import (
"math"
"strings"
"testing"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/stack"
)
const (
fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
fakeNetHeaderLen = 12
// fakeControlProtocol is used for control packets that represent
// destination port unreachable.
fakeControlProtocol tcpip.TransportProtocolNumber = 2
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
)
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
//
// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
// use the first three: destination address, source address, and transport
// protocol. They're all one byte fields to simplify parsing.
type fakeNetworkEndpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
linkEP stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
return f.nicid
}
func (*fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
// Increment the received packet count in the protocol descriptor.
f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
// Consume the network header.
b := vv.First()
vv.TrimFront(fakeNetHeaderLen)
// Handle control packets.
if b[2] == uint8(fakeControlProtocol) {
nb := vv.First()
if len(nb) < fakeNetHeaderLen {
return
}
vv.TrimFront(fakeNetHeaderLen)
f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
return
}
// Dispatch the packet to the transport protocol.
f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
return 0
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.linkEP.Capabilities()
}
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8) *tcpip.Error {
// Increment the sent packet count in the protocol descriptor.
f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
// Add the protocol's header to the packet and send it to the link
// endpoint.
b := hdr.Prepend(fakeNetHeaderLen)
b[0] = r.RemoteAddress[0]
b[1] = f.id.LocalAddress[0]
b[2] = byte(protocol)
return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
}
func (*fakeNetworkEndpoint) Close() {}
type fakeNetGoodOption bool
type fakeNetBadOption bool
type fakeNetInvalidValueOption int
type fakeNetOptions struct {
good bool
}
// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
// number of packets sent and received via endpoints of this protocol. The index
// where packets are added is given by the packet's destination address MOD 10.
type fakeNetworkProtocol struct {
packetCount [10]int
sendPacketCount [10]int
opts fakeNetOptions
}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return fakeNetNumber
}
func (f *fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
}
func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{addr},
proto: f,
dispatcher: dispatcher,
linkEP: linkEP,
}, nil
}
func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case fakeNetGoodOption:
f.opts.good = bool(v)
return nil
case fakeNetInvalidValueOption:
return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *fakeNetGoodOption:
*v = fakeNetGoodOption(f.opts.good)
return nil
default:
return tcpip.ErrUnknownProtocolOption
}
}
func TestNetworkReceive(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and two
// addresses attached to it: 1 & 2.
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
// Make sure packet with wrong address is not delivered.
buf[0] = 3
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
if fakeNet.packetCount[2] != 0 {
t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
}
// Make sure packet is delivered to first endpoint.
buf[0] = 1
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
if fakeNet.packetCount[2] != 0 {
t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
}
// Make sure packet is delivered to second endpoint.
buf[0] = 2
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
if fakeNet.packetCount[2] != 1 {
t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
}
// Make sure packet is not delivered if protocol number is wrong.
linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
if fakeNet.packetCount[2] != 1 {
t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
}
// Make sure packet that is too small is dropped.
buf.CapLength(2)
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
if fakeNet.packetCount[2] != 1 {
t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
}
}
func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
r, err := s.FindRoute(0, "", addr, fakeNetNumber)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
defer r.Release()
hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
if err := r.WritePacket(hdr, buffer.VectorisedView{}, fakeTransNumber, 123); err != nil {
t.Errorf("WritePacket failed: %v", err)
return
}
}
func TestNetworkSend(t *testing.T) {
// Create a stack with the fake network protocol, one nic, and one
// address: 1. The route table sends all packets through the only
// existing nic.
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("NewNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
// Make sure that the link-layer endpoint received the outbound packet.
sendTo(t, s, "\x03")
if c := linkEP.Drain(); c != 1 {
t.Errorf("packetCount = %d, want %d", c, 1)
}
}
func TestNetworkSendMultiRoute(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, linkEP1 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
id2, linkEP2 := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, id2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
s.SetRouteTable([]tcpip.Route{
{"\x01", "\x01", "\x00", 1},
{"\x00", "\x01", "\x00", 2},
})
// Send a packet to an odd destination.
sendTo(t, s, "\x05")
if c := linkEP1.Drain(); c != 1 {
t.Errorf("packetCount = %d, want %d", c, 1)
}
// Send a packet to an even destination.
sendTo(t, s, "\x06")
if c := linkEP2.Drain(); c != 1 {
t.Errorf("packetCount = %d, want %d", c, 1)
}
}
func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
defer r.Release()
if r.LocalAddress != expectedSrcAddr {
t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress)
}
if r.RemoteAddress != dstAddr {
t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress)
}
}
func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
_, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
}
}
func TestRoutes(t *testing.T) {
// Create a stack with the fake network protocol, two nics, and two
// addresses per nic, the first nic has odd address, the second one has
// even addresses.
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id1); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
id2, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(2, id2); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
// Set a route table that sends all packets with odd destination
// addresses through the first NIC, and all even destination address
// through the second one.
s.SetRouteTable([]tcpip.Route{
{"\x01", "\x01", "\x00", 1},
{"\x00", "\x01", "\x00", 2},
})
// Test routes to odd address.
testRoute(t, s, 0, "", "\x05", "\x01")
testRoute(t, s, 0, "\x01", "\x05", "\x01")
testRoute(t, s, 1, "\x01", "\x05", "\x01")
testRoute(t, s, 0, "\x03", "\x05", "\x03")
testRoute(t, s, 1, "\x03", "\x05", "\x03")
// Test routes to even address.
testRoute(t, s, 0, "", "\x06", "\x02")
testRoute(t, s, 0, "\x02", "\x06", "\x02")
testRoute(t, s, 2, "\x02", "\x06", "\x02")
testRoute(t, s, 0, "\x04", "\x06", "\x04")
testRoute(t, s, 2, "\x04", "\x06", "\x04")
// Try to send to odd numbered address from even numbered ones, then
// vice-versa.
testNoRoute(t, s, 0, "\x02", "\x05")
testNoRoute(t, s, 2, "\x02", "\x05")
testNoRoute(t, s, 0, "\x04", "\x05")
testNoRoute(t, s, 2, "\x04", "\x05")
testNoRoute(t, s, 0, "\x01", "\x06")
testNoRoute(t, s, 1, "\x01", "\x06")
testNoRoute(t, s, 0, "\x03", "\x06")
testNoRoute(t, s, 1, "\x03", "\x06")
}
func TestAddressRemoval(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
// Write a packet, and check that it gets delivered.
fakeNet.packetCount[1] = 0
buf[0] = 1
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
// Remove the address, then check that packet doesn't get delivered
// anymore.
if err := s.RemoveAddress(1, "\x01"); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
// Check that removing the same address fails.
if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress failed: %v", err)
}
}
func TestDelayedRemovalDueToRoute(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", 1},
})
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
// Write a packet, and check that it gets delivered.
fakeNet.packetCount[1] = 0
buf[0] = 1
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
// Get a route, check that packet is still deliverable.
r, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 2 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
}
// Remove the address, then check that packet is still deliverable
// because the route is keeping the address alive.
if err := s.RemoveAddress(1, "\x01"); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 3 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
}
// Check that removing the same address fails.
if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
t.Fatalf("RemoveAddress failed: %v", err)
}
// Release the route, then check that packet is not deliverable anymore.
r.Release()
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 3 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
}
}
func TestPromiscuousMode(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", 1},
})
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
// Write a packet, and check that it doesn't get delivered as we don't
// have a matching endpoint.
fakeNet.packetCount[1] = 0
buf[0] = 1
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
// Set promiscuous mode, then check that packet is delivered.
if err := s.SetPromiscuousMode(1, true); err != nil {
t.Fatalf("SetPromiscuousMode failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
// Check that we can't get a route as there is no local address.
_, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
if err != tcpip.ErrNoRoute {
t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
}
// Set promiscuous mode to false, then check that packet can't be
// delivered anymore.
if err := s.SetPromiscuousMode(1, false); err != nil {
t.Fatalf("SetPromiscuousMode failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
}
func TestAddressSpoofing(t *testing.T) {
srcAddr := tcpip.Address("\x01")
dstAddr := tcpip.Address("\x02")
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", 1},
})
// With address spoofing disabled, FindRoute does not permit an address
// that was not added to the NIC to be used as the source.
r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
if err == nil {
t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
}
// With address spoofing enabled, FindRoute permits any address to be used
// as the source.
if err := s.SetSpoofing(1, true); err != nil {
t.Fatalf("SetSpoofing failed: %v", err)
}
r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
if err != nil {
t.Fatalf("FindRoute failed: %v", err)
}
if r.LocalAddress != srcAddr {
t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
}
if r.RemoteAddress != dstAddr {
t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
}
}
// Set the subnet, then check that packet is delivered.
func TestSubnetAcceptsMatchingPacket(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", 1},
})
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
buf[0] = 1
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatalf("NewSubnet failed: %v", err)
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatalf("AddSubnet failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 1 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
}
}
// Set destination outside the subnet, then check it doesn't get delivered.
func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, linkEP := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{"\x00", "\x00", "\x00", 1},
})
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
buf := buffer.NewView(30)
buf[0] = 1
fakeNet.packetCount[1] = 0
subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
if err != nil {
t.Fatalf("NewSubnet failed: %v", err)
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatalf("AddSubnet failed: %v", err)
}
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeNet.packetCount[1] != 0 {
t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
}
}
func TestNetworkOptions(t *testing.T) {
s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{})
// Try an unsupported network protocol.
if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
}
testCases := []struct {
option interface{}
wantErr *tcpip.Error
verifier func(t *testing.T, p stack.NetworkProtocol)
}{
{fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
t.Helper()
fakeNet := p.(*fakeNetworkProtocol)
if fakeNet.opts.good != true {
t.Fatalf("fakeNet.opts.good = false, want = true")
}
var v fakeNetGoodOption
if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err)
}
if v != true {
t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
}
}},
{fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
{fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
}
for _, tc := range testCases {
if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
}
if tc.verifier != nil {
tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
}
}
}
func TestSubnetAddRemove(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
addr := tcpip.Address("\x01\x01\x01\x01")
mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
subnet, err := tcpip.NewSubnet(addr, mask)
if err != nil {
t.Fatalf("NewSubnet failed: %v", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
t.Fatalf("ContainsSubnet failed: %v", err)
} else if contained {
t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatalf("AddSubnet failed: %v", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
t.Fatalf("ContainsSubnet failed: %v", err)
} else if !contained {
t.Fatal("got s.ContainsSubnet(...) = false, want = true")
}
if err := s.RemoveSubnet(1, subnet); err != nil {
t.Fatalf("RemoveSubnet failed: %v", err)
}
if contained, err := s.ContainsSubnet(1, subnet); err != nil {
t.Fatalf("ContainsSubnet failed: %v", err)
} else if contained {
t.Fatal("got s.ContainsSubnet(...) = true, want = false")
}
}
func TestGetMainNICAddress(t *testing.T) {
s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id, _ := channel.New(10, defaultMTU, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
for _, tc := range []struct {
name string
address tcpip.Address
}{
{"IPv4", "\x01\x01\x01\x01"},
{"IPv6", "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"},
} {
t.Run(tc.name, func(t *testing.T) {
address := tc.address
mask := tcpip.AddressMask(strings.Repeat("\xff", len(address)))
subnet, err := tcpip.NewSubnet(address, mask)
if err != nil {
t.Fatalf("NewSubnet failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, address); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
t.Fatalf("AddSubnet failed: %v", err)
}
// Check that we get the right initial address and subnet.
if gotAddress, gotSubnet, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil {
t.Fatalf("GetMainNICAddress failed: %v", err)
} else if gotAddress != address {
t.Fatalf("got GetMainNICAddress = (%v, ...), want = (%v, ...)", gotAddress, address)
} else if gotSubnet != subnet {
t.Fatalf("got GetMainNICAddress = (..., %v), want = (..., %v)", gotSubnet, subnet)
}
if err := s.RemoveSubnet(1, subnet); err != nil {
t.Fatalf("RemoveSubnet failed: %v", err)
}
if err := s.RemoveAddress(1, address); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
// Check that we get an error after removal.
if _, _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress {
t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %v", err, tcpip.ErrNoLinkAddress)
}
})
}
}
func init() {
stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
})
}

View File

@@ -0,0 +1,192 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack
import (
"sync"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
)
// 网络层协议号和传输层协议号的组合当作分流器的key值
type protocolIDs struct {
network tcpip.NetworkProtocolNumber
transport tcpip.TransportProtocolNumber
}
// transportEndpoints manages all endpoints of a given protocol. It has its own
// mutex so as to reduce interference between protocols.
// transportEndpoints 管理给定协议的所有端点。
type transportEndpoints struct {
mu sync.RWMutex
endpoints map[TransportEndpointID]TransportEndpoint
}
// transportDemuxer demultiplexes packets targeted at a transport endpoint
// (i.e., after they've been parsed by the network layer). It does two levels
// of demultiplexing: first based on the network and transport protocols, then
// based on endpoints IDs.
// transportDemuxer 解复用针对传输端点的数据包(即,在它们被网络层解析之后)。
// 它执行两级解复用首先基于网络层协议和传输协议然后基于端点ID。
type transportDemuxer struct {
protocol map[protocolIDs]*transportEndpoints
}
// 新建一个分流器
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
// Add each network and transport pair to the demuxer.
for netProto := range stack.networkProtocols {
for proto := range stack.transportProtocols {
d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
}
}
return d
}
// registerEndpoint registers the given endpoint with the dispatcher such that
// packets that match the endpoint ID are delivered to it.
// registerEndpoint 向分发器注册给定端点以便将与端点ID匹配的数据包传递给它。
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
for i, n := range netProtos {
if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
d.unregisterEndpoint(netProtos[:i], protocol, id)
return err
}
}
return nil
}
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
eps, ok := d.protocol[protocolIDs{netProto, protocol}]
if !ok {
return nil
}
eps.mu.Lock()
defer eps.mu.Unlock()
if _, ok := eps.endpoints[id]; ok {
return tcpip.ErrPortInUse
}
eps.endpoints[id] = ep
return nil
}
// unregisterEndpoint unregisters the endpoint with the given id such that it
// won't receive any more packets.
// unregisterEndpoint 使用给定的id注销端点使其不再接收任何数据包。
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
for _, n := range netProtos {
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
eps.mu.Lock()
delete(eps.endpoints, id)
eps.mu.Unlock()
}
}
}
// deliverPacket attempts to deliver the given packet. Returns true if it found
// an endpoint, false otherwise.
// 根据传输层的id来找到对应的传输端再将数据包交给这个传输端处理
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
// 先看看分流器里有没有注册相关协议端如果没有则返回false
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
if !ok {
return false
}
// 从 eps 中找符合 id 的传输端
eps.mu.RLock()
ep := d.findEndpointLocked(eps, vv, id)
eps.mu.RUnlock()
// Fail if we didn't find one.
if ep == nil {
// UDP packet could not be delivered to an unknown destination port.
if protocol == header.UDPProtocolNumber {
r.Stats().UDP.UnknownPortErrors.Increment()
}
return false
}
// Deliver the packet.
// 找到传输端到,让它来处理数据包
ep.HandlePacket(r, id, vv)
return true
}
// deliverControlPacket attempts to deliver the given control packet. Returns
// true if it found an endpoint, false otherwise.
func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
eps, ok := d.protocol[protocolIDs{net, trans}]
if !ok {
return false
}
// Try to find the endpoint.
eps.mu.RLock()
ep := d.findEndpointLocked(eps, vv, id)
eps.mu.RUnlock()
// Fail if we didn't find one.
if ep == nil {
return false
}
// Deliver the packet.
ep.HandleControlPacket(id, typ, extra, vv)
return true
}
// 根据传输层id来找到相应的传输层端
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
// Try to find a match with the id as provided.
// 从 endpoints 中看有没有匹配到id的传输层端
if ep := eps.endpoints[id]; ep != nil {
return ep
}
// Try to find a match with the id minus the local address.
nid := id
// 如果上面的 endpoints 没有找到那么去掉本地ip地址看看有没有相应的传输层端
// 因为有时候传输层监听的时候没有绑定本地ip也就是 any address此时的 LocalAddress
// 为空。
nid.LocalAddress = ""
if ep := eps.endpoints[nid]; ep != nil {
return ep
}
// Try to find a match with the id minus the remote part.
nid.LocalAddress = id.LocalAddress
nid.RemoteAddress = ""
nid.RemotePort = 0
if ep := eps.endpoints[nid]; ep != nil {
return ep
}
// Try to find a match with only the local port.
nid.LocalAddress = ""
return eps.endpoints[nid]
}

View File

@@ -0,0 +1,422 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package stack_test
import (
"testing"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
const (
fakeTransNumber tcpip.TransportProtocolNumber = 1
fakeTransHeaderLen = 3
)
// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
//
// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
// use it.
type fakeTransportEndpoint struct {
id stack.TransportEndpointID
stack *stack.Stack
netProto tcpip.NetworkProtocolNumber
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
}
func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto}
}
func (f *fakeTransportEndpoint) Close() {
f.route.Release()
}
func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return mask
}
func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return buffer.View{}, tcpip.ControlMessages{}, nil
}
func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
if len(f.route.RemoteAddress) == 0 {
return 0, nil, tcpip.ErrNoRoute
}
hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
v, err := p.Get(p.Size())
if err != nil {
return 0, nil, err
}
if err := f.route.WritePacket(hdr, buffer.View(v).ToVectorisedView(), fakeTransNumber, 123); err != nil {
return 0, nil, err
}
return uintptr(len(v)), nil, nil
}
func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch opt.(type) {
case tcpip.ErrorOption:
return nil
}
return tcpip.ErrInvalidEndpointState
}
func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
f.peerAddr = addr.Addr
// Find the route.
r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber)
if err != nil {
return tcpip.ErrNoRoute
}
defer r.Release()
// Try to register so that we can start receiving packets.
f.id.RemoteAddress = addr.Addr
err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f)
if err != nil {
return err
}
f.route = r.Clone()
return nil
}
func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Reset() {
}
func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, nil
}
func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
return commit()
}
func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, nil
}
func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
return tcpip.FullAddress{}, nil
}
func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
// Increment the number of received control packets.
f.proto.controlCount++
}
type fakeTransportGoodOption bool
type fakeTransportBadOption bool
type fakeTransportInvalidValueOption int
type fakeTransportProtocolOptions struct {
good bool
}
// fakeTransportProtocol is a transport-layer protocol descriptor. It
// aggregates the number of packets received via endpoints of this protocol.
type fakeTransportProtocol struct {
packetCount int
controlCount int
opts fakeTransportProtocolOptions
}
func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
return fakeTransNumber
}
func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newFakeTransportEndpoint(stack, f, netProto), nil
}
func (*fakeTransportProtocol) MinimumPacketSize() int {
return fakeTransHeaderLen
}
func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
return 0, 0, nil
}
func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
return true
}
func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case fakeTransportGoodOption:
f.opts.good = bool(v)
return nil
case fakeTransportInvalidValueOption:
return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *fakeTransportGoodOption:
*v = fakeTransportGoodOption(f.opts.good)
return nil
default:
return tcpip.ErrUnknownProtocolOption
}
}
func TestTransportReceive(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
// Create endpoint and connect to remote address.
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
t.Fatalf("Connect failed: %v", err)
}
fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
// Create buffer that will hold the packet.
buf := buffer.NewView(30)
// Make sure packet with wrong protocol is not delivered.
buf[0] = 1
buf[2] = 0
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
// Make sure packet from the wrong source is not delivered.
buf[0] = 1
buf[1] = 3
buf[2] = byte(fakeTransNumber)
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 0 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
}
// Make sure packet is delivered.
buf[0] = 1
buf[1] = 2
buf[2] = byte(fakeTransNumber)
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.packetCount != 1 {
t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
}
}
func TestTransportControlReceive(t *testing.T) {
id, linkEP := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
// Create endpoint and connect to remote address.
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
t.Fatalf("Connect failed: %v", err)
}
fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
// Create buffer that will hold the control packet.
buf := buffer.NewView(2*fakeNetHeaderLen + 30)
// Outer packet contains the control protocol number.
buf[0] = 1
buf[1] = 0xfe
buf[2] = uint8(fakeControlProtocol)
// Make sure packet with wrong protocol is not delivered.
buf[fakeNetHeaderLen+0] = 0
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = 0
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
// Make sure packet from the wrong source is not delivered.
buf[fakeNetHeaderLen+0] = 3
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.controlCount != 0 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
}
// Make sure packet is delivered.
buf[fakeNetHeaderLen+0] = 2
buf[fakeNetHeaderLen+1] = 1
buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
linkEP.Inject(fakeNetNumber, buf.ToVectorisedView())
if fakeTrans.controlCount != 1 {
t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
}
}
func TestTransportSend(t *testing.T) {
id, _ := channel.New(10, defaultMTU, "")
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
// Create endpoint and bind it.
wq := waiter.Queue{}
ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
t.Fatalf("Connect failed: %v", err)
}
// Create buffer that will hold the payload.
view := buffer.NewView(30)
_, _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
if err != nil {
t.Fatalf("write failed: %v", err)
}
fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
if fakeNet.sendPacketCount[2] != 1 {
t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
}
}
func TestTransportOptions(t *testing.T) {
s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
// Try an unsupported transport protocol.
if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
}
testCases := []struct {
option interface{}
wantErr *tcpip.Error
verifier func(t *testing.T, p stack.TransportProtocol)
}{
{fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
t.Helper()
fakeTrans := p.(*fakeTransportProtocol)
if fakeTrans.opts.good != true {
t.Fatalf("fakeTrans.opts.good = false, want = true")
}
var v fakeTransportGoodOption
if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err)
}
if v != true {
t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
}
}},
{fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
{fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
}
for _, tc := range testCases {
if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
}
if tc.verifier != nil {
tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
}
}
}
func init() {
stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
return &fakeTransportProtocol{}
})
}

834
netstack/tcpip/tcpip.go Normal file
View File

@@ -0,0 +1,834 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tcpip provides the interfaces and related types that users of the
// tcpip stack will use in order to create endpoints used to send and receive
// data over the network stack.
//
// The starting point is the creation and configuration of a stack. A stack can
// be created by calling the New() function of the tcpip/stack/stack package;
// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
// adding network addresses (via calls to Stack.AddAddress()), and
// setting a route table (via a call to Stack.SetRouteTable()).
//
// Once a stack is configured, endpoints can be created by calling
// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
// to peers, listen for connections, accept connections, etc., depending on the
// transport protocol selected.
package tcpip
import (
"errors"
"fmt"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/waiter"
)
// Error represents an error in the netstack error space. Using a special type
// ensures that errors outside of this space are not accidentally introduced.
//
// Note: to support save / restore, it is important that all tcpip errors have
// distinct error messages.
type Error struct {
msg string
ignoreStats bool
}
// String implements fmt.Stringer.String.
func (e *Error) String() string {
return e.msg
}
// IgnoreStats indicates whether this error type should be included in failure
// counts in tcpip.Stats structs.
func (e *Error) IgnoreStats() bool {
return e.ignoreStats
}
// Errors that can be returned by the network stack.
var (
ErrUnknownProtocol = &Error{msg: "unknown protocol"}
ErrUnknownNICID = &Error{msg: "unknown nic id"}
ErrUnknownProtocolOption = &Error{msg: "unknown option for protocol"}
ErrDuplicateNICID = &Error{msg: "duplicate nic id"}
ErrDuplicateAddress = &Error{msg: "duplicate address"}
ErrNoRoute = &Error{msg: "no route"}
ErrBadLinkEndpoint = &Error{msg: "bad link layer endpoint"}
ErrAlreadyBound = &Error{msg: "endpoint already bound", ignoreStats: true}
ErrInvalidEndpointState = &Error{msg: "endpoint is in invalid state"}
ErrAlreadyConnecting = &Error{msg: "endpoint is already connecting", ignoreStats: true}
ErrAlreadyConnected = &Error{msg: "endpoint is already connected", ignoreStats: true}
ErrNoPortAvailable = &Error{msg: "no ports are available"}
ErrPortInUse = &Error{msg: "port is in use"}
ErrBadLocalAddress = &Error{msg: "bad local address"}
ErrClosedForSend = &Error{msg: "endpoint is closed for send"}
ErrClosedForReceive = &Error{msg: "endpoint is closed for receive"}
ErrWouldBlock = &Error{msg: "operation would block", ignoreStats: true}
ErrConnectionRefused = &Error{msg: "connection was refused"}
ErrTimeout = &Error{msg: "operation timed out"}
ErrAborted = &Error{msg: "operation aborted"}
ErrConnectStarted = &Error{msg: "connection attempt started", ignoreStats: true}
ErrDestinationRequired = &Error{msg: "destination address is required"}
ErrNotSupported = &Error{msg: "operation not supported"}
ErrQueueSizeNotSupported = &Error{msg: "queue size querying not supported"}
ErrNotConnected = &Error{msg: "endpoint not connected"}
ErrConnectionReset = &Error{msg: "connection reset by peer"}
ErrConnectionAborted = &Error{msg: "connection aborted"}
ErrNoSuchFile = &Error{msg: "no such file"}
ErrInvalidOptionValue = &Error{msg: "invalid option value specified"}
ErrNoLinkAddress = &Error{msg: "no remote link address"}
ErrBadAddress = &Error{msg: "bad address"}
ErrNetworkUnreachable = &Error{msg: "network is unreachable"}
ErrMessageTooLong = &Error{msg: "message too long"}
ErrNoBufferSpace = &Error{msg: "no buffer space available"}
)
// Errors related to Subnet
var (
errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
)
// ErrSaveRejection indicates a failed save due to unsupported networking state.
// This type of errors is only used for save logic.
type ErrSaveRejection struct {
Err error
}
// Error returns a sensible description of the save rejection error.
func (e ErrSaveRejection) Error() string {
return "save rejected due to unsupported networking state: " + e.Err.Error()
}
// A Clock provides the current time.
//
// Times returned by a Clock should always be used for application-visible
// time, but never for netstack internal timekeeping.
type Clock interface {
// NowNanoseconds returns the current real time as a number of
// nanoseconds since the Unix epoch.
NowNanoseconds() int64
// NowMonotonic returns a monotonic time value.
NowMonotonic() int64
}
// Address is a byte slice cast as a string that represents the address of a
// network node. Or, in the case of unix endpoints, it may represent a path.
type Address string
// AddressMask is a bitmask for an address.
type AddressMask string
// String implements Stringer.
func (a AddressMask) String() string {
return Address(a).String()
}
// Subnet is a subnet defined by its address and mask.
type Subnet struct {
address Address
mask AddressMask
}
// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
func NewSubnet(a Address, m AddressMask) (Subnet, error) {
if len(a) != len(m) {
return Subnet{}, errSubnetLengthMismatch
}
for i := 0; i < len(a); i++ {
if a[i]&^m[i] != 0 {
return Subnet{}, errSubnetAddressMasked
}
}
return Subnet{a, m}, nil
}
// Contains returns true iff the address is of the same length and matches the
// subnet address and mask.
func (s *Subnet) Contains(a Address) bool {
if len(a) != len(s.address) {
return false
}
for i := 0; i < len(a); i++ {
if a[i]&s.mask[i] != s.address[i] {
return false
}
}
return true
}
// ID returns the subnet ID.
func (s *Subnet) ID() Address {
return s.address
}
// Bits returns the number of ones (network bits) and zeros (host bits) in the
// subnet mask.
func (s *Subnet) Bits() (ones int, zeros int) {
for _, b := range []byte(s.mask) {
for i := uint(0); i < 8; i++ {
if b&(1<<i) == 0 {
zeros++
} else {
ones++
}
}
}
return
}
// Prefix returns the number of bits before the first host bit.
func (s *Subnet) Prefix() int {
for i, b := range []byte(s.mask) {
for j := 7; j >= 0; j-- {
if b&(1<<uint(j)) == 0 {
return i*8 + 7 - j
}
}
}
return len(s.mask) * 8
}
// Mask returns the subnet mask.
func (s *Subnet) Mask() AddressMask {
return s.mask
}
// NICID is a number that uniquely identifies a NIC.
type NICID int32
// ShutdownFlags represents flags that can be passed to the Shutdown() method
// of the Endpoint interface.
type ShutdownFlags int
// Values of the flags that can be passed to the Shutdown() method. They can
// be OR'ed together.
const (
ShutdownRead ShutdownFlags = 1 << iota
ShutdownWrite
)
// FullAddress represents a full transport node address, as required by the
// Connect() and Bind() methods.
//
// +stateify savable
type FullAddress struct {
// NIC is the ID of the NIC this address refers to.
//
// This may not be used by all endpoint types.
NIC NICID
// Addr is the network address.
Addr Address
// Port is the transport port.
//
// This may not be used by all endpoint types.
Port uint16
}
func (fa FullAddress) String() string {
return fmt.Sprintf("%d:%s:%d", fa.NIC, fa.Addr, fa.Port)
}
// Payload provides an interface around data that is being sent to an endpoint.
// This allows the endpoint to request the amount of data it needs based on
// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
type Payload interface {
// Get returns a slice containing exactly 'min(size, p.Size())' bytes.
Get(size int) ([]byte, *Error)
// Size returns the payload size.
Size() int
}
// SlicePayload implements Payload on top of slices for convenience.
type SlicePayload []byte
// Get implements Payload.
func (s SlicePayload) Get(size int) ([]byte, *Error) {
if size > s.Size() {
size = s.Size()
}
return s[:size], nil
}
// Size implements Payload.
func (s SlicePayload) Size() int {
return len(s)
}
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
type ControlMessages struct {
// HasTimestamp indicates whether Timestamp is valid/set.
HasTimestamp bool
// Timestamp is the time (in ns) that the last packed used to create
// the read data was received.
Timestamp int64
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
type Endpoint interface {
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
Close()
// Read reads data from the endpoint and optionally returns the sender.
//
// This method does not block if there is no data pending. It will also
// either return an error or data, never both.
//
// A timestamp (in ns) is optionally returned. A zero value indicates
// that no timestamp was available.
Read(*FullAddress) (buffer.View, ControlMessages, *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
//
// Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
// successfully written to the Endpoint. That is, if a call to
// Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and
// the caller should not use data[:n] after Write returns.
//
// Note that unlike io.Writer.Write, it is not an error for Write to
// perform a partial write.
//
// For UDP and Ping sockets if address resolution is required,
// ErrNoLinkAddress and a notification channel is returned for the caller to
// block. Channel is closed once address resolution is complete (success or
// not). The channel is only non-nil in this case.
Write(Payload, WriteOptions) (uintptr, <-chan struct{}, *Error)
// Peek reads data without consuming it from the endpoint.
//
// This method does not block if there is no data pending.
//
// A timestamp (in ns) is optionally returned. A zero value indicates
// that no timestamp was available.
Peek([][]byte) (uintptr, ControlMessages, *Error)
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
// There are three classes of return values:
// nil -- the attempt to connect succeeded.
// ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
// but hasn't completed yet. In this case, the caller must call Connect
// or GetSockOpt(ErrorOption) when the endpoint becomes writable to
// get the actual result. The first call to Connect after the socket has
// connected returns nil. Calling connect again results in ErrAlreadyConnected.
// Anything else -- the attempt to connect failed.
Connect(address FullAddress) *Error
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
Shutdown(flags ShutdownFlags) *Error
// Listen puts the endpoint in "listen" mode, which allows it to accept
// new connections.
Listen(backlog int) *Error
// Accept returns a new endpoint if a peer has established a connection
// to an endpoint previously set to listen mode. This method does not
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
Accept() (Endpoint, *waiter.Queue, *Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
//
// An optional commit function will be executed atomically with respect
// to binding the endpoint. If this returns an error, the bind will not
// occur and the error will be propagated back to the caller.
Bind(address FullAddress, commit func() *Error) *Error
// GetLocalAddress returns the address to which the endpoint is bound.
GetLocalAddress() (FullAddress, *Error)
// GetRemoteAddress returns the address to which the endpoint is
// connected.
GetRemoteAddress() (FullAddress, *Error)
// Readiness returns the current readiness of the endpoint. For example,
// if waiter.EventIn is set, the endpoint is immediately readable.
Readiness(mask waiter.EventMask) waiter.EventMask
// SetSockOpt sets a socket option. opt should be one of the *Option types.
SetSockOpt(opt interface{}) *Error
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// *Option types.
GetSockOpt(opt interface{}) *Error
}
// WriteOptions contains options for Endpoint.Write.
type WriteOptions struct {
// If To is not nil, write to the given address instead of the endpoint's
// peer.
To *FullAddress
// More has the same semantics as Linux's MSG_MORE.
More bool
// EndOfRecord has the same semantics as Linux's MSG_EOR.
EndOfRecord bool
}
// ErrorOption is used in GetSockOpt to specify that the last error reported by
// the endpoint should be cleared and returned.
type ErrorOption struct{}
// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
// buffer size option.
type SendBufferSizeOption int
// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
// receive buffer size option.
type ReceiveBufferSizeOption int
// SendQueueSizeOption is used in GetSockOpt to specify that the number of
// unread bytes in the output buffer should be returned.
type SendQueueSizeOption int
// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
// unread bytes in the input buffer should be returned.
type ReceiveQueueSizeOption int
// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
type V6OnlyOption int
// NoDelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
// sent out immediately by the transport protocol. For TCP, it determines if the
// Nagle algorithm is on or off.
type NoDelayOption int
// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
// should allow reuse of local address.
type ReuseAddressOption int
// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
// SCM_CREDENTIALS socket control messages are enabled.
//
// Only supported on Unix sockets.
type PasscredOption int
// TimestampOption is used by SetSockOpt/GetSockOpt to specify whether
// SO_TIMESTAMP socket control messages are enabled.
type TimestampOption int
// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
//
// TODO: Add and populate stat fields.
type TCPInfoOption struct {
RTT time.Duration
RTTVar time.Duration
}
// KeepaliveEnabledOption is used by SetSockOpt/GetSockOpt to specify whether
// TCP keepalive is enabled for this socket.
type KeepaliveEnabledOption int
// KeepaliveIdleOption is used by SetSockOpt/GetSockOpt to specify the time a
// connection must remain idle before the first TCP keepalive packet is sent.
// Once this time is reached, KeepaliveIntervalOption is used instead.
type KeepaliveIdleOption time.Duration
// KeepaliveIntervalOption is used by SetSockOpt/GetSockOpt to specify the
// interval between sending TCP keepalive packets.
type KeepaliveIntervalOption time.Duration
// KeepaliveCountOption is used by SetSockOpt/GetSockOpt to specify the number
// of un-ACKed TCP keepalives that will be sent before the connection is
// closed.
type KeepaliveCountOption int
// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default
// TTL value for multicast messages. The default is 1.
type MulticastTTLOption uint8
// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to
// AddMembershipOption and RemoveMembershipOption.
type MembershipOption struct {
NIC NICID
InterfaceAddr Address
MulticastAddr Address
}
// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast
// group identified by the given multicast address, on the interface matching
// the given interface address.
type AddMembershipOption MembershipOption
// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast
// group identified by the given multicast address, on the interface matching
// the given interface address.
type RemoveMembershipOption MembershipOption
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination adddress in the row.
type Route struct {
// Destination is the address that must be matched against the masked
// target address to check if this row is viable.
Destination Address
// Mask specifies which bits of the Destination and the target address
// must match for this row to be viable.
Mask AddressMask
// Gateway is the gateway to be used if this row is viable.
Gateway Address
// NIC is the id of the nic to be used if this row is viable.
NIC NICID
}
// Match determines if r is viable for the given destination address.
func (r *Route) Match(addr Address) bool {
if len(addr) != len(r.Destination) {
return false
}
for i := 0; i < len(r.Destination); i++ {
if (addr[i] & r.Mask[i]) != r.Destination[i] {
return false
}
}
return true
}
// LinkEndpointID represents a data link layer endpoint.
type LinkEndpointID uint64
// TransportProtocolNumber is the number of a transport protocol.
type TransportProtocolNumber uint32
// NetworkProtocolNumber is the number of a network protocol.
type NetworkProtocolNumber uint32
// A StatCounter keeps track of a statistic.
type StatCounter struct {
count uint64
}
// Increment adds one to the counter.
func (s *StatCounter) Increment() {
s.IncrementBy(1)
}
// Value returns the current value of the counter.
func (s *StatCounter) Value() uint64 {
return atomic.LoadUint64(&s.count)
}
// IncrementBy increments the counter by v.
func (s *StatCounter) IncrementBy(v uint64) {
atomic.AddUint64(&s.count, v)
}
// IPStats collects IP-specific stats (both v4 and v6).
type IPStats struct {
// PacketsReceived is the total number of IP packets received from the link
// layer in nic.DeliverNetworkPacket.
PacketsReceived *StatCounter
// InvalidAddressesReceived is the total number of IP packets received
// with an unknown or invalid destination address.
InvalidAddressesReceived *StatCounter
// PacketsDelivered is the total number of incoming IP packets that
// are successfully delivered to the transport layer via HandlePacket.
PacketsDelivered *StatCounter
// PacketsSent is the total number of IP packets sent via WritePacket.
PacketsSent *StatCounter
// OutgoingPacketErrors is the total number of IP packets which failed
// to write to a link-layer endpoint.
OutgoingPacketErrors *StatCounter
}
// TCPStats collects TCP-specific stats.
type TCPStats struct {
// ActiveConnectionOpenings is the number of connections opened successfully
// via Connect.
ActiveConnectionOpenings *StatCounter
// PassiveConnectionOpenings is the number of connections opened
// successfully via Listen.
PassiveConnectionOpenings *StatCounter
// FailedConnectionAttempts is the number of calls to Connect or Listen
// (active and passive openings, respectively) that end in an error.
FailedConnectionAttempts *StatCounter
// ValidSegmentsReceived is the number of TCP segments received that the
// transport layer successfully parsed.
ValidSegmentsReceived *StatCounter
// InvalidSegmentsReceived is the number of TCP segments received that
// the transport layer could not parse.
InvalidSegmentsReceived *StatCounter
// SegmentsSent is the number of TCP segments sent.
SegmentsSent *StatCounter
// ResetsSent is the number of TCP resets sent.
ResetsSent *StatCounter
// ResetsReceived is the number of TCP resets received.
ResetsReceived *StatCounter
}
// UDPStats collects UDP-specific stats.
type UDPStats struct {
// PacketsReceived is the number of UDP datagrams received via
// HandlePacket.
PacketsReceived *StatCounter
// UnknownPortErrors is the number of incoming UDP datagrams dropped
// because they did not have a known destination port.
UnknownPortErrors *StatCounter
// ReceiveBufferErrors is the number of incoming UDP datagrams dropped
// due to the receiving buffer being in an invalid state.
ReceiveBufferErrors *StatCounter
// MalformedPacketsReceived is the number of incoming UDP datagrams
// dropped due to the UDP header being in a malformed state.
MalformedPacketsReceived *StatCounter
// PacketsSent is the number of UDP datagrams sent via sendUDP.
PacketsSent *StatCounter
}
// Stats holds statistics about the networking stack.
//
// All fields are optional.
type Stats struct {
// UnknownProtocolRcvdPackets is the number of packets received by the
// stack that were for an unknown or unsupported protocol.
UnknownProtocolRcvdPackets *StatCounter
// MalformedRcvPackets is the number of packets received by the stack
// that were deemed malformed.
MalformedRcvdPackets *StatCounter
// DroppedPackets is the number of packets dropped due to full queues.
DroppedPackets *StatCounter
// IP breaks out IP-specific stats (both v4 and v6).
IP IPStats
// TCP breaks out TCP-specific stats.
TCP TCPStats
// UDP breaks out UDP-specific stats.
UDP UDPStats
}
func fillIn(v reflect.Value) {
for i := 0; i < v.NumField(); i++ {
v := v.Field(i)
switch v.Kind() {
case reflect.Ptr:
if s, ok := v.Addr().Interface().(**StatCounter); ok {
if *s == nil {
*s = &StatCounter{}
}
}
case reflect.Struct:
fillIn(v)
}
}
}
// FillIn returns a copy of s with nil fields initialized to new StatCounters.
func (s Stats) FillIn() Stats {
fillIn(reflect.ValueOf(&s).Elem())
return s
}
// String implements the fmt.Stringer interface.
func (a Address) String() string {
switch len(a) {
case 4:
return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
case 16:
// Find the longest subsequence of hexadecimal zeros.
start, end := -1, -1
for i := 0; i < len(a); i += 2 {
j := i
for j < len(a) && a[j] == 0 && a[j+1] == 0 {
j += 2
}
if j > i+2 && j-i > end-start {
start, end = i, j
}
}
var b strings.Builder
for i := 0; i < len(a); i += 2 {
if i == start {
b.WriteString("::")
i = end
if end >= len(a) {
break
}
} else if i > 0 {
b.WriteByte(':')
}
v := uint16(a[i+0])<<8 | uint16(a[i+1])
if v == 0 {
b.WriteByte('0')
} else {
const digits = "0123456789abcdef"
for i := uint(3); i < 4; i-- {
if v := v >> (i * 4); v != 0 {
b.WriteByte(digits[v&0xf])
}
}
}
}
return b.String()
default:
return fmt.Sprintf("%x", []byte(a))
}
}
// To4 converts the IPv4 address to a 4-byte representation.
// If the address is not an IPv4 address, To4 returns "".
func (a Address) To4() Address {
const (
ipv4len = 4
ipv6len = 16
)
if len(a) == ipv4len {
return a
}
if len(a) == ipv6len &&
isZeros(a[0:10]) &&
a[10] == 0xff &&
a[11] == 0xff {
return a[12:16]
}
return ""
}
// isZeros reports whether a is all zeros.
func isZeros(a Address) bool {
for i := 0; i < len(a); i++ {
if a[i] != 0 {
return false
}
}
return true
}
// LinkAddress is a byte slice cast as a string that represents a link address.
// It is typically a 6-byte MAC address.
type LinkAddress string
// String implements the fmt.Stringer interface.
func (a LinkAddress) String() string {
switch len(a) {
case 6:
return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
default:
return fmt.Sprintf("%x", []byte(a))
}
}
// ParseMACAddress parses an IEEE 802 address.
//
// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
func ParseMACAddress(s string) (LinkAddress, error) {
parts := strings.FieldsFunc(s, func(c rune) bool {
return c == ':' || c == '-'
})
if len(parts) != 6 {
return "", fmt.Errorf("inconsistent parts: %s", s)
}
addr := make([]byte, 0, len(parts))
for _, part := range parts {
u, err := strconv.ParseUint(part, 16, 8)
if err != nil {
return "", fmt.Errorf("invalid hex digits: %s", s)
}
addr = append(addr, byte(u))
}
return LinkAddress(addr), nil
}
// ProtocolAddress is an address and the network protocol it is associated
// with.
type ProtocolAddress struct {
// Protocol is the protocol of the address.
Protocol NetworkProtocolNumber
// Address is a network address.
Address Address
}
// danglingEndpointsMu protects access to danglingEndpoints.
var danglingEndpointsMu sync.Mutex
// danglingEndpoints tracks all dangling endpoints no longer owned by the app.
var danglingEndpoints = make(map[Endpoint]struct{})
// GetDanglingEndpoints returns all dangling endpoints.
func GetDanglingEndpoints() []Endpoint {
es := make([]Endpoint, 0, len(danglingEndpoints))
danglingEndpointsMu.Lock()
for e := range danglingEndpoints {
es = append(es, e)
}
danglingEndpointsMu.Unlock()
return es
}
// AddDanglingEndpoint adds a dangling endpoint.
func AddDanglingEndpoint(e Endpoint) {
danglingEndpointsMu.Lock()
danglingEndpoints[e] = struct{}{}
danglingEndpointsMu.Unlock()
}
// DeleteDanglingEndpoint removes a dangling endpoint.
func DeleteDanglingEndpoint(e Endpoint) {
danglingEndpointsMu.Lock()
delete(danglingEndpoints, e)
danglingEndpointsMu.Unlock()
}
// AsyncLoading is the global barrier for asynchronous endpoint loading
// activities.
var AsyncLoading sync.WaitGroup

View File

@@ -0,0 +1,195 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpip
import (
"net"
"testing"
)
func TestSubnetContains(t *testing.T) {
tests := []struct {
s Address
m AddressMask
a Address
want bool
}{
{"\xa0", "\xf0", "\x90", false},
{"\xa0", "\xf0", "\xa0", true},
{"\xa0", "\xf0", "\xa5", true},
{"\xa0", "\xf0", "\xaf", true},
{"\xa0", "\xf0", "\xb0", false},
{"\xa0", "\xf0", "", false},
{"\xa0", "\xf0", "\xa0\x00", false},
{"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
{"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
{"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
{"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
}
for _, tt := range tests {
s, err := NewSubnet(tt.s, tt.m)
if err != nil {
t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err)
continue
}
if got := s.Contains(tt.a); got != tt.want {
t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want)
}
}
}
func TestSubnetBits(t *testing.T) {
tests := []struct {
a AddressMask
want1 int
want0 int
}{
{"\x00", 0, 8},
{"\x00\x00", 0, 16},
{"\x36", 4, 4},
{"\x5c", 4, 4},
{"\x5c\x5c", 8, 8},
{"\x5c\x36", 8, 8},
{"\x36\x5c", 8, 8},
{"\x36\x36", 8, 8},
{"\xff", 8, 0},
{"\xff\xff", 16, 0},
}
for _, tt := range tests {
s := &Subnet{mask: tt.a}
got1, got0 := s.Bits()
if got1 != tt.want1 || got0 != tt.want0 {
t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0)
}
}
}
func TestSubnetPrefix(t *testing.T) {
tests := []struct {
a AddressMask
want int
}{
{"\x00", 0},
{"\x00\x00", 0},
{"\x36", 0},
{"\x86", 1},
{"\xc5", 2},
{"\xff\x00", 8},
{"\xff\x36", 8},
{"\xff\x8c", 9},
{"\xff\xc8", 10},
{"\xff", 8},
{"\xff\xff", 16},
}
for _, tt := range tests {
s := &Subnet{mask: tt.a}
got := s.Prefix()
if got != tt.want {
t.Errorf("Subnet{mask: %x}.Bits() = %d want %d", tt.a, got, tt.want)
}
}
}
func TestSubnetCreation(t *testing.T) {
tests := []struct {
a Address
m AddressMask
want error
}{
{"\xa0", "\xf0", nil},
{"\xa0\xa0", "\xf0", errSubnetLengthMismatch},
{"\xaa", "\xf0", errSubnetAddressMasked},
{"", "", nil},
}
for _, tt := range tests {
if _, err := NewSubnet(tt.a, tt.m); err != tt.want {
t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want)
}
}
}
func TestRouteMatch(t *testing.T) {
tests := []struct {
d Address
m AddressMask
a Address
want bool
}{
{"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
{"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
{"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
{"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
}
for _, tt := range tests {
r := Route{Destination: tt.d, Mask: tt.m}
if got := r.Match(tt.a); got != tt.want {
t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want)
}
}
}
func TestAddressString(t *testing.T) {
for _, want := range []string{
// Taken from stdlib.
"2001:db8::123:12:1",
"2001:db8::1",
"2001:db8:0:1:0:1:0:1",
"2001:db8:1:0:1:0:1:0",
"2001::1:0:0:1",
"2001:db8:0:0:1::",
"2001:db8::1:0:0:1",
"2001:db8::a:b:c:d",
// Leading zeros.
"::1",
// Trailing zeros.
"8::",
// No zeros.
"1:1:1:1:1:1:1:1",
// Longer sequence is after other zeros, but not at the end.
"1:0:0:1::1",
// Longer sequence is at the beginning, shorter sequence is at
// the end.
"::1:1:1:0:0",
// Longer sequence is not at the beginning, shorter sequence is
// at the end.
"1::1:1:0:0",
// Longer sequence is at the beginning, shorter sequence is not
// at the end.
"::1:1:0:0:1",
// Neither sequence is at an end, longer is after shorter.
"1:0:0:1::1",
// Shorter sequence is at the beginning, longer sequence is not
// at the end.
"0:0:1:1::1",
// Shorter sequence is at the beginning, longer sequence is at
// the end.
"0:0:1:1:1::",
// Short sequences at both ends, longer one in the middle.
"0:1:1::1:1:0",
// Short sequences at both ends, longer one in the middle.
"0:1::1:0:0",
// Short sequences at both ends, longer one in the middle.
"0:0:1::1:0",
// Longer sequence surrounded by shorter sequences, but none at
// the end.
"1:0:1::1:0:1",
} {
addr := Address(net.ParseIP(want))
if got := addr.String(); got != want {
t.Errorf("Address(%x).String() = '%s', want = '%s'", addr, got, want)
}
}
}

15
netstack/tcpip/time.s Normal file
View File

@@ -0,0 +1,15 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Empty assembly file so empty func definitions work.

View File

@@ -0,0 +1,42 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build go1.9
package tcpip
import (
_ "time" // Used with go:linkname.
_ "unsafe" // Required for go:linkname.
)
// StdClock implements Clock with the time package.
type StdClock struct{}
var _ Clock = (*StdClock)(nil)
//go:linkname now time.now
func now() (sec int64, nsec int32, mono int64)
// NowNanoseconds implements Clock.NowNanoseconds.
func (*StdClock) NowNanoseconds() int64 {
sec, nsec, _ := now()
return sec*1e9 + int64(nsec)
}
// NowMonotonic implements Clock.NowMonotonic.
func (*StdClock) NowMonotonic() int64 {
_, _, mono := now()
return mono
}

View File

@@ -0,0 +1,717 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package ping
import (
"encoding/binary"
"sync"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
// +stateify savable
// ping包的结构体
type pingPacket struct {
pingPacketEntry
senderAddress tcpip.FullAddress
data buffer.VectorisedView
timestamp int64
hasTimestamp bool
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
}
type endpointState int
const (
stateInitial endpointState = iota
stateBound
stateConnected
stateClosed
)
// endpoint represents a ping endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
// 表示ping端用作用户端和协议栈之间的接口
type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack
netProto tcpip.NetworkProtocolNumber
transProto tcpip.TransportProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
rcvMu sync.Mutex
rcvReady bool
rcvList pingPacketList
rcvBufSizeMax int
rcvBufSize int
rcvClosed bool
rcvTimestamp bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex
sndBufSize int
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
id stack.TransportEndpointID
state endpointState
bindNICID tcpip.NICID
bindAddr tcpip.Address
regNICID tcpip.NICID
route stack.Route
}
// 根据参数新建一个ping端
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
stack: stack,
netProto: netProto,
transProto: transProto,
waiterQueue: waiterQueue,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
}
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
// 关闭ping端回收资源
func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id)
}
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
e.rcvBufSize = 0
for !e.rcvList.Empty() {
p := e.rcvList.Front()
e.rcvList.Remove(p)
}
e.rcvMu.Unlock()
e.route.Release()
// Update the state.
e.state = stateClosed
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
return buffer.View{}, tcpip.ControlMessages{}, err
}
p := e.rcvList.Front()
e.rcvList.Remove(p)
e.rcvBufSize -= p.data.Size()
ts := e.rcvTimestamp
e.rcvMu.Unlock()
if addr != nil {
*addr = p.senderAddress
}
if ts && !p.hasTimestamp {
// Linux uses the current time.
p.timestamp = e.stack.NowNanoseconds()
}
return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
// binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
case stateInitial:
case stateConnected:
return false, nil
case stateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
return false, nil
default:
return false, tcpip.ErrInvalidEndpointState
}
e.mu.RUnlock()
defer e.mu.RLock()
e.mu.Lock()
defer e.mu.Unlock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
if e.state != stateInitial {
return true, nil
}
// The state is still 'initial', so try to bind the endpoint.
if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
return false, err
}
return true, nil
}
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
to := opts.To
e.mu.RLock()
defer e.mu.RUnlock()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
return 0, nil, tcpip.ErrClosedForSend
}
// Prepare for write.
for {
retry, err := e.prepareForWrite(to)
if err != nil {
return 0, nil, err
}
if !retry {
break
}
}
var route *stack.Route
if to == nil {
route = &e.route
if route.IsResolutionRequired() {
// Promote lock to exclusive if using a shared route, given that it may
// need to change in Route.Resolve() call below.
e.mu.RUnlock()
defer e.mu.RLock()
e.mu.Lock()
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
if e.state != stateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
} else {
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicid := to.NIC
if e.bindNICID != 0 {
if nicid != 0 && nicid != e.bindNICID {
return 0, nil, tcpip.ErrNoRoute
}
nicid = e.bindNICID
}
toCopy := *to
to = &toCopy
netProto, err := e.checkV4Mapped(to, true)
if err != nil {
return 0, nil, err
}
// Find the enpoint.
r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
}
if route.IsResolutionRequired() {
waker := &sleep.Waker{}
if ch, err := route.Resolve(waker); err != nil {
if err == tcpip.ErrWouldBlock {
// Link address needs to be resolved. Resolution was triggered the
// background. Better luck next time.
route.RemoveWaker(waker)
return 0, ch, tcpip.ErrNoLinkAddress
}
return 0, nil, err
}
}
v, err := p.Get(p.Size())
if err != nil {
return 0, nil, err
}
switch e.netProto {
case header.IPv4ProtocolNumber:
err = sendPing4(route, e.id.LocalPort, v)
case header.IPv6ProtocolNumber:
err = sendPing6(route, e.id.LocalPort, v)
}
return uintptr(len(v)), nil, err
}
// Peek only returns data from a single datagram, so do nothing here.
func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option. Currently not supported.
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.TimestampOption:
e.rcvMu.Lock()
e.rcvTimestamp = v != 0
e.rcvMu.Unlock()
}
return nil
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
case *tcpip.SendBufferSizeOption:
e.mu.Lock()
*o = tcpip.SendBufferSizeOption(e.sndBufSize)
e.mu.Unlock()
return nil
case *tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
*o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
e.rcvMu.Unlock()
return nil
case *tcpip.ReceiveQueueSizeOption:
e.rcvMu.Lock()
if e.rcvList.Empty() {
*o = 0
} else {
p := e.rcvList.Front()
*o = tcpip.ReceiveQueueSizeOption(p.data.Size())
}
e.rcvMu.Unlock()
return nil
case *tcpip.TimestampOption:
e.rcvMu.Lock()
*o = 0
if e.rcvTimestamp {
*o = 1
}
e.rcvMu.Unlock()
}
return tcpip.ErrUnknownProtocolOption
}
func sendPing4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
if len(data) < header.ICMPv4EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
// Set the ident. Sequence number is provided by the user.
binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], ident)
hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
copy(icmpv4, data)
data = data[header.ICMPv4EchoMinimumSize:]
// Linux performs these basic checks.
if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
return tcpip.ErrInvalidEndpointState
}
icmpv4.SetChecksum(0)
icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
}
func sendPing6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
if len(data) < header.ICMPv6EchoMinimumSize {
return tcpip.ErrInvalidEndpointState
}
// Set the ident. Sequence number is provided by the user.
binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
copy(icmpv6, data)
data = data[header.ICMPv6EchoMinimumSize:]
if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
return tcpip.ErrInvalidEndpointState
}
icmpv6.SetChecksum(0)
icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
return r.WritePacket(hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
}
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) {
return 0, tcpip.ErrNoRoute
}
// Fail if we're bound to an address length different from the one we're
// checking.
if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
return netProto, nil
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
nicid := addr.NIC
localPort := uint16(0)
switch e.state {
case stateBound, stateConnected:
localPort = e.id.LocalPort
if e.bindNICID == 0 {
break
}
if nicid != 0 && nicid != e.bindNICID {
return tcpip.ErrInvalidEndpointState
}
nicid = e.bindNICID
default:
return tcpip.ErrInvalidEndpointState
}
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return err
}
// Find a route to the desired destination.
r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
if err != nil {
return err
}
defer r.Release()
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
LocalPort: localPort,
RemoteAddress: r.RemoteAddress,
}
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
id, err = e.registerWithStack(nicid, netProtos, id)
if err != nil {
return err
}
e.id = id
e.route = r.Clone()
e.regNICID = nicid
e.state = stateConnected
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
e.shutdownFlags |= flags
if e.state != stateConnected {
return tcpip.ErrNotConnected
}
if flags&tcpip.ShutdownRead != 0 {
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
e.rcvMu.Unlock()
if !wasClosed {
e.waiterQueue.Notify(waiter.EventIn)
}
}
return nil
}
// Listen is not supported by UDP, it just fails.
func (*endpoint) Listen(int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept is not supported by UDP, it just fails.
func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
return id, err
}
// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e)
switch err {
case nil:
return true, nil
case tcpip.ErrPortInUse:
return false, nil
default:
return false, err
}
})
return id, err
}
func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
if e.state != stateInitial {
return tcpip.ErrInvalidEndpointState
}
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return err
}
// Expand netProtos to include v4 and v6 if the caller is binding to a
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
if len(addr.Addr) != 0 {
// A local address was specified, verify that it's valid.
if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
}
id := stack.TransportEndpointID{
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
id, err = e.registerWithStack(addr.NIC, netProtos, id)
if err != nil {
return err
}
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, e.transProto, id)
return err
}
}
e.id = id
e.regNICID = addr.NIC
// Mark endpoint as bound.
e.state = stateBound
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
err := e.bindLocked(addr, commit)
if err != nil {
return err
}
e.bindNICID = addr.NIC
e.bindAddr = addr.Addr
return nil
}
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
return tcpip.FullAddress{
NIC: e.regNICID,
Addr: e.id.LocalAddress,
Port: e.id.LocalPort,
}, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
NIC: e.regNICID,
Addr: e.id.RemoteAddress,
Port: e.id.RemotePort,
}, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
// waiter.EventIn is set, the endpoint is immediately readable.
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// The endpoint is always writable.
result := waiter.EventOut & mask
// Determine if the endpoint is readable if requested.
if (mask & waiter.EventIn) != 0 {
e.rcvMu.Lock()
if !e.rcvList.Empty() || e.rcvClosed {
result |= waiter.EventIn
}
e.rcvMu.Unlock()
}
return result
}
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
e.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
e.rcvMu.Unlock()
return
}
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
pkt := &pingPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
},
}
pkt.data = vv.Clone(pkt.views[:])
e.rcvList.PushBack(pkt)
e.rcvBufSize += vv.Size()
if e.rcvTimestamp {
pkt.timestamp = e.stack.NowNanoseconds()
pkt.hasTimestamp = true
}
e.rcvMu.Unlock()
// Notify any waiters that there's data to be read now.
if wasEmpty {
e.waiterQueue.Notify(waiter.EventIn)
}
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID,
typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}

View File

@@ -0,0 +1,173 @@
package ping
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type pingPacketElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (pingPacketElementMapper) linkerFor(elem *pingPacket) *pingPacket { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type pingPacketList struct {
head *pingPacket
tail *pingPacket
}
// Reset resets list l to the empty state.
func (l *pingPacketList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
func (l *pingPacketList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
func (l *pingPacketList) Front() *pingPacket {
return l.head
}
// Back returns the last element of list l or nil.
func (l *pingPacketList) Back() *pingPacket {
return l.tail
}
// PushFront inserts the element e at the front of list l.
func (l *pingPacketList) PushFront(e *pingPacket) {
pingPacketElementMapper{}.linkerFor(e).SetNext(l.head)
pingPacketElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
pingPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushBack inserts the element e at the back of list l.
func (l *pingPacketList) PushBack(e *pingPacket) {
pingPacketElementMapper{}.linkerFor(e).SetNext(nil)
pingPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
pingPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
func (l *pingPacketList) PushBackList(m *pingPacketList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
pingPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
pingPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *pingPacketList) InsertAfter(b, e *pingPacket) {
a := pingPacketElementMapper{}.linkerFor(b).Next()
pingPacketElementMapper{}.linkerFor(e).SetNext(a)
pingPacketElementMapper{}.linkerFor(e).SetPrev(b)
pingPacketElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
pingPacketElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
func (l *pingPacketList) InsertBefore(a, e *pingPacket) {
b := pingPacketElementMapper{}.linkerFor(a).Prev()
pingPacketElementMapper{}.linkerFor(e).SetNext(a)
pingPacketElementMapper{}.linkerFor(e).SetPrev(b)
pingPacketElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
pingPacketElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
func (l *pingPacketList) Remove(e *pingPacket) {
prev := pingPacketElementMapper{}.linkerFor(e).Prev()
next := pingPacketElementMapper{}.linkerFor(e).Next()
if prev != nil {
pingPacketElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
pingPacketElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type pingPacketEntry struct {
next *pingPacket
prev *pingPacket
}
// Next returns the entry that follows e in the list.
func (e *pingPacketEntry) Next() *pingPacket {
return e.next
}
// Prev returns the entry that precedes e in the list.
func (e *pingPacketEntry) Prev() *pingPacket {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
func (e *pingPacketEntry) SetNext(elem *pingPacket) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *pingPacketEntry) SetPrev(elem *pingPacket) {
e.prev = elem
}

View File

@@ -0,0 +1,124 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package ping contains the implementation of the ICMP and IPv6-ICMP transport
// protocols for use in ping. To use it in the networking stack, this package
// must be added to the project, and
// activated on the stack by passing ping.ProtocolName (or "ping") and/or
// ping.ProtocolName6 (or "ping6") as one of the transport protocols when
// calling stack.New(). Then endpoints can be created by passing
// ping.ProtocolNumber or ping.ProtocolNumber6 as the transport protocol number
// when calling Stack.NewEndpoint().
package ping
import (
"encoding/binary"
"fmt"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
const (
// ProtocolName4 is the string representation of the ping protocol name.
ProtocolName4 = "ping4"
// ProtocolNumber4 is the ICMP protocol number.
ProtocolNumber4 = header.ICMPv4ProtocolNumber
// ProtocolName6 is the string representation of the ping protocol name.
ProtocolName6 = "ping6"
// ProtocolNumber6 is the IPv6-ICMP protocol number.
ProtocolNumber6 = header.ICMPv6ProtocolNumber
)
type protocol struct {
number tcpip.TransportProtocolNumber
}
// Number returns the ICMP protocol number.
func (p *protocol) Number() tcpip.TransportProtocolNumber {
return p.number
}
func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
switch p.number {
case ProtocolNumber4:
return header.IPv4ProtocolNumber
case ProtocolNumber6:
return header.IPv6ProtocolNumber
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// NewEndpoint creates a new ping endpoint.
func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
if netProto != p.netProto() {
return nil, tcpip.ErrUnknownProtocol
}
return newEndpoint(stack, netProto, p.number, waiterQueue), nil
}
// MinimumPacketSize returns the minimum valid ping packet size.
func (p *protocol) MinimumPacketSize() int {
switch p.number {
case ProtocolNumber4:
return header.ICMPv4EchoMinimumSize
case ProtocolNumber6:
return header.ICMPv6EchoMinimumSize
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// ParsePorts returns the source and destination ports stored in the given ping
// packet.
func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
switch p.number {
case ProtocolNumber4:
return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
case ProtocolNumber6:
return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
}
panic(fmt.Sprint("unknown protocol number: ", p.number))
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
return true
}
// SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
return &protocol{ProtocolNumber4}
})
stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
return &protocol{ProtocolNumber6}
})
}

View File

@@ -0,0 +1,443 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"encoding/binary"
"hash"
"io"
"sync"
"time"
"crypto/sha1"
"tcpip/netstack/rand"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
const (
// tsLen is the length, in bits, of the timestamp in the SYN cookie.
tsLen = 8
// tsMask is a mask for timestamp values (i.e., tsLen bits).
tsMask = (1 << tsLen) - 1
// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
tsOffset = 24
// hashMask is the mask for hash values (i.e., tsOffset bits).
hashMask = (1 << tsOffset) - 1
// maxTSDiff is the maximum allowed difference between a received cookie
// timestamp and the current timestamp. If the difference is greater
// than maxTSDiff, the cookie is expired.
maxTSDiff = 2
)
var (
// SynRcvdCountThreshold is the global maximum number of connections
// that are allowed to be in SYN-RCVD state before TCP starts using SYN
// cookies to accept connections.
//
// It is an exported variable only for testing, and should not otherwise
// be used by importers of this package.
SynRcvdCountThreshold uint64 = 1000
// mssTable is a slice containing the possible MSS values that we
// encode in the SYN cookie with two bits.
mssTable = []uint16{536, 1300, 1440, 1460}
)
func encodeMSS(mss uint16) uint32 {
for i := len(mssTable) - 1; i > 0; i-- {
if mss >= mssTable[i] {
return uint32(i)
}
}
return 0
}
// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
// protected by a mutex so that we can increment only when it's guaranteed not
// to go above a threshold.
var synRcvdCount struct {
sync.Mutex
value uint64
pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects.
type listenContext struct {
stack *stack.Stack
rcvWnd seqnum.Size
nonce [2][sha1.BlockSize]byte
hasherMu sync.Mutex
hasher hash.Hash
v6only bool
netProto tcpip.NetworkProtocolNumber
}
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask
}
// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
// state. It succeeds if the increment doesn't make the count go beyond the
// threshold, and fails otherwise.
func incSynRcvdCount() bool {
synRcvdCount.Lock()
defer synRcvdCount.Unlock()
if synRcvdCount.value >= SynRcvdCountThreshold {
return false
}
synRcvdCount.pending.Add(1)
synRcvdCount.value++
return true
}
// decSynRcvdCount atomically decrements the global number of endpoints in
// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
// succeeded.
func decSynRcvdCount() {
synRcvdCount.Lock()
defer synRcvdCount.Unlock()
synRcvdCount.value--
synRcvdCount.pending.Done()
}
// newListenContext creates a new listen context.
func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{
stack: stack,
rcvWnd: rcvWnd,
hasher: sha1.New(),
v6only: v6only,
netProto: netProto,
}
rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:])
return l
}
// cookieHash calculates the cookieHash for the given id, timestamp and nonce
// index. The hash is used to create and validate cookies.
func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
// Initialize block with fixed-size data: local ports and v.
var payload [8]byte
binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
binary.BigEndian.PutUint32(payload[4:], ts)
// Feed everything to the hasher.
l.hasherMu.Lock()
l.hasher.Reset()
l.hasher.Write(payload[:])
l.hasher.Write(l.nonce[nonceIndex][:])
io.WriteString(l.hasher, string(id.LocalAddress))
io.WriteString(l.hasher, string(id.RemoteAddress))
// Finalize the calculation of the hash and return the first 4 bytes.
h := make([]byte, 0, sha1.Size)
h = l.hasher.Sum(h)
l.hasherMu.Unlock()
return binary.BigEndian.Uint32(h[:])
}
// createCookie creates a SYN cookie for the given id and incoming sequence
// number.
func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
ts := timeStamp()
v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
v += (l.cookieHash(id, ts, 1) + data) & hashMask
return seqnum.Value(v)
}
// isCookieValid checks if the supplied cookie is valid for the given id and
// sequence number. If it is, it also returns the data originally encoded in the
// cookie when createCookie was called.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
ts := timeStamp()
v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
cookieTS := v >> tsOffset
if ((ts - cookieTS) & tsMask) > maxTSDiff {
return 0, false
}
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
}
// createConnectedEndpoint creates a new connected endpoint, with the connection
// parameters given by the arguments.
func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create a new endpoint.
netProto := l.netProto
if netProto == 0 {
netProto = s.route.NetProto
}
n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only
n.id = s.id
n.boundNICID = s.route.NICID()
n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd)
n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it.
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
n.Close()
return nil, err
}
n.isRegistered = true
n.state = stateConnected
// Create sender and receiver.
//
// The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
return n, nil
}
// createEndpoint creates a new endpoint in connected state and then performs
// the TCP 3-way handshake.
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create new endpoint.
irs := s.sequenceNumber
cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
ep, err := l.createConnectedEndpoint(s, cookie, irs, opts)
if err != nil {
return nil, err
}
// Perform the 3-way handshake.
// 执行三次握手
h, err := newHandshake(ep, l.rcvWnd)
if err != nil {
ep.Close()
return nil, err
}
// 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack
h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil {
ep.Close()
return nil, err
}
// Update the receive window scaling. We can't do it before the
// handshake because it's possible that the peer doesn't support window
// scaling.
ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
return ep, nil
}
// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
// endpoint has transitioned out of the listen state, the new endpoint is closed
// instead.
func (e *endpoint) deliverAccepted(n *endpoint) {
e.mu.RLock()
if e.state == stateListen {
e.acceptedChan <- n
e.waiterQueue.Notify(waiter.EventIn)
} else {
n.Close()
}
e.mu.RUnlock()
}
// handleSynSegment is called in its own goroutine once the listening endpoint
// receives a SYN segment. It is responsible for completing the handshake and
// queueing the new endpoint for acceptance.
//
// A limited number of these goroutines are allowed before TCP starts using SYN
// cookies to accept connections.
// 一旦侦听端点收到SYN段handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。
// 在TCP开始使用SYN cookie接受连接之前允许使用有限数量的这些goroutine。
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer decSynRcvdCount()
defer s.decRef()
n, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil {
return
}
// 到这里,三次握手已经完成,那么分发一个新的连接
e.deliverAccepted(n)
}
// handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
switch s.flags {
case flagSyn:
// syn报文处理
// 分析tcp选项
opts := parseSynSegmentOptions(s)
// 如果处于 syn-recv 状态的连接没有超过 1000那么不需要生成 cookie
// 而是直接分配和新建连接。
if incSynRcvdCount() {
s.incRef()
go e.handleSynSegment(ctx, s, &opts)
} else {
cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
// Send SYN with window scaling because we currently
// dont't encode this information in the cookie.
//
// Enable Timestamp option if the original syn did have
// the timestamp option specified.
synOpts := header.TCPSynOptions{
WS: -1,
TS: opts.TS,
TSVal: tcpTimeStamp(timeStampOffset()),
TSEcr: opts.TSVal,
}
// 返回 syn+ack 报文
sendSynTCP(&s.route, s.id, flagSyn|flagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
}
case flagAck:
// 三次握手最后一次 ack 报文
if data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1); ok && int(data) < len(mssTable) {
// Create newly accepted endpoint and deliver it.
rcvdSynOptions := &header.TCPSynOptions{
MSS: mssTable[data],
// Disable Window scaling as original SYN is
// lost.
WS: -1,
}
// When syn cookies are in use we enable timestamp only
// if the ack specifies the timestamp option assuming
// that the other end did in fact negotiate the
// timestamp option in the original SYN.
if s.parsedOptions.TS {
rcvdSynOptions.TS = true
rcvdSynOptions.TSVal = s.parsedOptions.TSVal
rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
}
// 三次握手已经完成新建一个tcp连接
n, err := ctx.createConnectedEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
if err == nil {
// clear the tsOffset for the newly created
// endpoint as the Timestamp was already
// randomly offset when the original SYN-ACK was
// sent above.
n.tsOffset = 0
e.deliverAccepted(n)
}
}
}
}
// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
// its own goroutine and is responsible for handling connection requests.
// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行负责处理连接请求。
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
defer func() {
// Mark endpoint as closed. This will prevent goroutines running
// handleSynSegment() from attempting to queue new connections
// to the endpoint.
e.mu.Lock()
e.state = stateClosed
// Do cleanup if needed.
e.completeWorkerLocked()
if e.drainDone != nil {
close(e.drainDone)
}
e.mu.Unlock()
// Notify waiters that the endpoint is shutdown.
e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
}()
e.mu.Lock()
v6only := e.v6only
e.mu.Unlock()
ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto)
// 事件触发器的初始化和事件的添加
s := sleep.Sleeper{}
s.AddWaker(&e.notificationWaker, wakerForNotification)
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
// for循环拿到事件index然后进行相应的处理
for {
switch index, _ := s.Fetch(true); index {
case wakerForNotification:
n := e.fetchNotifications()
if n&notifyClose != 0 {
return nil
}
if n&notifyDrain != 0 {
for !e.segmentQueue.empty() {
s := e.segmentQueue.dequeue()
e.handleListenSegment(ctx, s)
s.decRef()
}
synRcvdCount.pending.Wait()
close(e.drainDone)
<-e.undrain
}
case wakerForNewSegment:
// Process at most maxSegmentsPerWake segments.
// 接收和处理tcp报文
mayRequeue := true
for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue()
if s == nil {
mayRequeue = false
break
}
// 处理tcp数据段 s
e.handleListenSegment(ctx, s)
s.decRef()
}
// If the queue is not empty, make sure we'll wake up
// in the next iteration.
if mayRequeue && !e.segmentQueue.empty() {
e.newSegmentWaker.Assert()
}
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,255 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"math"
"time"
)
// cubicState 存储与TCP CUBIC拥塞控制算法状态相关的变量。详见: https://tools.ietf.org/html/rfc8312.
type cubicState struct {
// wLastMax is the previous wMax value.
// 上次最大的拥塞窗口
wLastMax float64
// wMax is the value of the congestion window at the
// time of last congestion event.
// 上次拥塞时间时的拥塞窗口大小
wMax float64
// t denotes the time when the current congestion avoidance
// was entered.
// t 表进入拥塞避免阶段的时间。
t time.Time
// numCongestionEvents tracks the number of congestion events since last
// RTO.
// numCongestionEvents 跟踪自上次 RTO 以来的拥塞事件数。
numCongestionEvents int
// c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
// per RFC.
// 三次方函数的系数
c float64
// k is the time period that the above function takes to increase the
// current window size to W_max if there are no further congestion
// events and is calculated using the following equation:
// k是在没有其他拥塞事件时将当前窗口大小增加到W_max所需的时间段并使用以下公式计算
//
// K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
k float64
// beta is the CUBIC multiplication decrease factor. that is, when a
// congestion event is detected, CUBIC reduces its cwnd to
// W_cubic(0)=W_max*beta_cubic.
// beta 是CUBIC乘法减少因子。也就是说当检测到拥塞事件时CUBIC将其cwnd减少到
beta float64
// wC is window computed by CUBIC at time t. It's calculated using the
// formula:
// wC 是由CUBIC在时间t计算的窗口。它使用公式计算
//
// W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
wC float64
// wEst is the window computed by CUBIC at time t+RTT i.e
// wEs t是CUBIC在时间 t+RTT 计算的窗口,即
// W_cubic(t+RTT).
wEst float64
s *sender
}
// newCubicCC returns a partially initialized cubic state with the constants
// beta and c set and t set to current time.
// newCubicCC 返回部分初始化的 cubic 状态常量为beta和ct为当前时间。
func newCubicCC(s *sender) *cubicState {
return &cubicState{
t: time.Now(),
beta: 0.7,
c: 0.4,
s: s,
}
}
// enterCongestionAvoidance is used to initialize cubic in cases where we exit
// SlowStart without a real congestion event taking place. This can happen when
// a connection goes back to slow start due to a retransmit and we exceed the
// previously lowered ssThresh without experiencing packet loss.
//
// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
// enterCongestionAvoidance 用于在我们退出 SlowStart 而没有发生真正的拥塞事件的情况下初始化 cubic。
// 当连接由于重新传输而返回慢启动时会发生这种情况并且我们超过先前降低的ssThresh而不会遇到丢包。
func (c *cubicState) enterCongestionAvoidance() {
// See: https://tools.ietf.org/html/rfc8312#section-4.7 &
// https://tools.ietf.org/html/rfc8312#section-4.8
// 初次进入拥塞避免只记录当前拥塞避免的时间点当前的窗口wMax=cwnd
if c.numCongestionEvents == 0 {
c.k = 0
c.t = time.Now()
c.wLastMax = c.wMax
c.wMax = float64(c.s.sndCwnd)
}
}
// updateSlowStart will update the congestion window as per the slow-start
// algorithm used by NewReno. If after adjusting the congestion window we cross
// the ssThresh then it will return the number of packets that must be consumed
// in congestion avoidance mode.
// updateSlowStart 将根据NewReno使用的慢启动算法更新拥塞窗口。
// 如果在调整拥塞窗口之后我们越过ssThresh那么它将返回在拥塞避免模式下必须消耗的数据包的数量。
func (c *cubicState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
newcwnd := c.s.sndCwnd + packetsAcked
enterCA := false
if newcwnd >= c.s.sndSsthresh {
newcwnd = c.s.sndSsthresh
c.s.sndCAAckCount = 0
enterCA = true
}
packetsAcked -= newcwnd - c.s.sndCwnd
c.s.sndCwnd = newcwnd
if enterCA {
// 进入拥塞避免
c.enterCongestionAvoidance()
}
return packetsAcked
}
// Update updates cubic's internal state variables. It must be called on every
// ACK received.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
// Update 更新 cubic 内部状态变量每次收到ack都必须调用
func (c *cubicState) Update(packetsAcked int) {
if c.s.sndCwnd < c.s.sndSsthresh {
packetsAcked = c.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
}
} else {
// 拥塞避免阶段
c.s.rtt.Lock()
srtt := c.s.rtt.srtt
c.s.rtt.Unlock()
c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
}
}
// cubicCwnd computes the CUBIC congestion window after t seconds from last
// congestion event.
// cubicCwnd 在上次拥塞事件发生t秒后计算CUBIC拥塞窗口。
func (c *cubicState) cubicCwnd(t float64) float64 {
return c.c*math.Pow(t, 3.0) + c.wMax
}
// getCwnd returns the current congestion window as computed by CUBIC.
// Refer: https://tools.ietf.org/html/rfc8312#section-4
// getCwnd 返回由CUBIC计算的当前拥塞窗口。
func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
elapsed := time.Since(c.t).Seconds()
// Compute the window as per Cubic after 'elapsed' time
// since last congestion event.
c.wC = c.cubicCwnd(elapsed - c.k)
// Compute the TCP friendly estimate of the congestion window.
c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
// Make sure in the TCP friendly region CUBIC performs at least
// as well as Reno.
if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
// TCP Friendly region of cubic.
return int(c.wEst)
}
// In Concave/Convex region of CUBIC, calculate what CUBIC window
// will be after 1 RTT and use that to grow congestion window
// for every ack.
tEst := (time.Since(c.t) + srtt).Seconds()
wtRtt := c.cubicCwnd(tEst - c.k)
// As per 4.3 for each received ACK cwnd must be incremented
// by (w_cubic(t+RTT) - cwnd/cwnd.
cwnd := float64(sndCwnd)
for i := 0; i < packetsAcked; i++ {
// Concave/Convex regions of cubic have the same formulas.
// See: https://tools.ietf.org/html/rfc8312#section-4.3
cwnd += (wtRtt - cwnd) / cwnd
}
return int(cwnd)
}
// HandleNDupAcks implements congestionControl.HandleNDupAcks.
// 收到三次重复ack调用 HandleNDupAcks
func (c *cubicState) HandleNDupAcks() {
// See: https://tools.ietf.org/html/rfc8312#section-4.5
c.numCongestionEvents++
c.t = time.Now()
c.wLastMax = c.wMax
c.wMax = float64(c.s.sndCwnd)
c.fastConvergence()
c.reduceSlowStartThreshold()
}
// HandleRTOExpired implements congestionContrl.HandleRTOExpired.
// 发生重传时调用 HandleRTOExpired。
func (c *cubicState) HandleRTOExpired() {
// See: https://tools.ietf.org/html/rfc8312#section-4.6
c.t = time.Now()
c.numCongestionEvents = 0
c.wLastMax = c.wMax
c.wMax = float64(c.s.sndCwnd)
c.fastConvergence()
// We lost a packet, so reduce ssthresh.
c.reduceSlowStartThreshold()
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
c.s.sndCwnd = 1
}
// fastConvergence implements the logic for Fast Convergence algorithm as
// described in https://tools.ietf.org/html/rfc8312#section-4.6.
// 快速收敛
func (c *cubicState) fastConvergence() {
if c.wMax < c.wLastMax {
c.wLastMax = c.wMax
c.wMax = c.wMax * (1.0 + c.beta) / 2.0
} else {
c.wLastMax = c.wMax
}
// Recompute k as wMax may have changed.
c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
}
// PostRecovery implemements congestionControl.PostRecovery.
// 更新t为当前的时间当发送方退出快速恢复阶段时将调用 PostRecovery
func (c *cubicState) PostRecovery() {
c.t = time.Now()
}
// reduceSlowStartThreshold returns new SsThresh as described in
// https://tools.ietf.org/html/rfc8312#section-4.7.
// 按 cubic 的算法更新慢启动和拥塞避免之间的阈值
func (c *cubicState) reduceSlowStartThreshold() {
c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
}

View File

@@ -0,0 +1,560 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp_test
import (
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/checker"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/transport/tcp"
"tcpip/netstack/tcpip/transport/tcp/testing/context"
"tcpip/netstack/waiter"
)
func TestV4MappedConnectOnV6Only(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(true)
// Start connection attempt, it must fail.
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
if err != tcpip.ErrNoRoute {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
}
func testV4Connect(t *testing.T, c *context.Context) {
// Start connection attempt.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
if err != tcpip.ErrConnectStarted {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
// Receive SYN packet.
b := c.GetPacket()
checker.IPv4(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
),
)
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
iss := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
SrcPort: tcp.DestinationPort(),
DstPort: tcp.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
// Receive ACK packet.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-ch:
err = c.EP.GetSockOpt(tcpip.ErrorOption{})
if err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
}
}
func TestV4MappedConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Test the connection request.
testV4Connect(t, c)
}
func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test the connection request.
testV4Connect(t, c)
}
func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to v4 mapped wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test the connection request.
testV4Connect(t, c)
}
func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to v4 mapped address.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test the connection request.
testV4Connect(t, c)
}
func testV6Connect(t *testing.T, c *context.Context) {
// Start connection attempt to IPv6 address.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventOut)
defer c.WQ.EventUnregister(&we)
err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
if err != tcpip.ErrConnectStarted {
t.Fatalf("Unexpected return value from Connect: %v", err)
}
// Receive SYN packet.
b := c.GetV6Packet()
checker.IPv6(t, b,
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagSyn),
),
)
tcp := header.TCP(header.IPv6(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
iss := seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
SrcPort: tcp.DestinationPort(),
DstPort: tcp.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
})
// Receive ACK packet.
checker.IPv6(t, c.GetV6Packet(),
checker.TCP(
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-ch:
err = c.EP.GetSockOpt(tcpip.ErrorOption{})
if err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
}
}
func TestV6Connect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Test the connection request.
testV6Connect(t, c)
}
func TestV6ConnectV6Only(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(true)
// Test the connection request.
testV6Connect(t, c)
}
func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test the connection request.
testV6Connect(t, c)
}
func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to local address.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test the connection request.
testV6Connect(t, c)
}
func TestV4RefuseOnV6Only(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(true)
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Start listening.
if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
})
// Receive the RST reply.
checker.IPv4(t, c.GetPacket(),
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
checker.AckNum(uint32(irs)+1),
),
)
}
func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind and listen.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
// Send a SYN request.
irs := seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
})
// Receive the RST reply.
checker.IPv6(t, c.GetV6Packet(),
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
checker.AckNum(uint32(irs)+1),
),
)
}
func testV4Accept(t *testing.T, c *context.Context) {
// Start listening.
if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
// Send a SYN request.
irs := seqnum.Value(789)
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
checker.IPv4(t, b,
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
checker.AckNum(uint32(irs)+1),
),
)
// Send ACK.
c.SendPacket(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagAck,
SeqNum: irs + 1,
AckNum: iss + 1,
RcvWnd: 30000,
})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
nep, _, err = c.EP.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for accept")
}
}
// Make sure we get the same error when calling the original ep and the
// new one. This validates that v4-mapped endpoints are still able to
// query the V6Only flag, whereas pure v4 endpoints are not.
var v tcpip.V6OnlyOption
expected := c.EP.GetSockOpt(&v)
if err := nep.GetSockOpt(&v); err != expected {
t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
}
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
t.Fatalf("GetRemoteAddress failed failed: %v", err)
}
if addr.Addr != context.TestAddr {
t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
}
}
func TestV4AcceptOnV6(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Accept(t, c)
}
func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind to v4 mapped wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Accept(t, c)
}
func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind and listen.
if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Accept(t, c)
}
func TestV6AcceptOnV6(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateV6Endpoint(false)
// Bind and listen.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
if err := c.EP.Listen(10); err != nil {
t.Fatalf("Listen failed: %v", err)
}
// Send a SYN request.
irs := seqnum.Value(789)
c.SendV6Packet(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagSyn,
SeqNum: irs,
RcvWnd: 30000,
})
// Receive the SYN-ACK reply.
b := c.GetV6Packet()
tcp := header.TCP(header.IPv6(b).Payload())
iss := seqnum.Value(tcp.SequenceNumber())
checker.IPv6(t, b,
checker.TCP(
checker.SrcPort(context.StackPort),
checker.DstPort(context.TestPort),
checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
checker.AckNum(uint32(irs)+1),
),
)
// Send ACK.
c.SendV6Packet(nil, &context.Headers{
SrcPort: context.TestPort,
DstPort: context.StackPort,
Flags: header.TCPFlagAck,
SeqNum: irs + 1,
AckNum: iss + 1,
RcvWnd: 30000,
})
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
nep, _, err := c.EP.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
nep, _, err = c.EP.Accept()
if err != nil {
t.Fatalf("Accept failed: %v", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for accept")
}
}
// Make sure we can still query the v6 only status of the new endpoint,
// that is, that it is in fact a v6 socket.
var v tcpip.V6OnlyOption
if err := nep.GetSockOpt(&v); err != nil {
t.Fatalf("GetSockOpt failed failed: %v", err)
}
// Check the peer address.
addr, err := nep.GetRemoteAddress()
if err != nil {
t.Fatalf("GetRemoteAddress failed failed: %v", err)
}
if addr.Addr != context.TestV6Addr {
t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
}
}
func TestV4AcceptOnV4(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
// Bind to wildcard.
if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Accept(t, c)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,174 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"sync"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
// Forwarder is a connection request forwarder, which allows clients to decide
// what to do with a connection request, for example: ignore it, send a RST, or
// attempt to complete the 3-way handshake.
//
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
// Forwarder是一个连接请求转发器它允许客户端决定如何处理连接请求例如忽略它发送RST或尝试完成3次握手。
//
//使用它的规范方法是将Forwarder.HandlePacket函数传递给stack.SetTransportProtocolHandler。
type Forwarder struct {
maxInFlight int
handler func(*ForwarderRequest)
mu sync.Mutex
inFlight map[stack.TransportEndpointID]struct{}
listen *listenContext
}
// NewForwarder allocates and initializes a new forwarder with the given
// maximum number of in-flight connection attempts. Once the maximum is reached
// new incoming connection requests will be ignored.
//
// If rcvWnd is set to zero, the default buffer size is used instead.
func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
if rcvWnd == 0 {
rcvWnd = DefaultBufferSize
}
return &Forwarder{
maxInFlight: maxInFlight,
handler: handler,
inFlight: make(map[stack.TransportEndpointID]struct{}),
listen: newListenContext(s, seqnum.Size(rcvWnd), true, 0),
}
}
// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if
// it's a SYN packet), returning true if it's the case. Otherwise the packet
// is not handled and false is returned.
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
// We only care about well-formed SYN packets.
if !s.parse() || s.flags != flagSyn {
return false
}
opts := parseSynSegmentOptions(s)
f.mu.Lock()
defer f.mu.Unlock()
// We have an inflight request for this id, ignore this one for now.
if _, ok := f.inFlight[id]; ok {
return true
}
// Ignore the segment if we're beyond the limit.
if len(f.inFlight) >= f.maxInFlight {
return true
}
// Launch a new goroutine to handle the request.
f.inFlight[id] = struct{}{}
s.incRef()
go f.handler(&ForwarderRequest{
forwarder: f,
segment: s,
synOptions: opts,
})
return true
}
// ForwarderRequest represents a connection request received by the forwarder
// and passed to the client. Clients must eventually call Complete() on it, and
// may optionally create an endpoint to represent it via CreateEndpoint.
type ForwarderRequest struct {
mu sync.Mutex
forwarder *Forwarder
segment *segment
synOptions header.TCPSynOptions
}
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
// represents the connection request.
func (r *ForwarderRequest) ID() stack.TransportEndpointID {
return r.segment.id
}
// Complete completes the request, and optionally sends a RST segment back to the
// sender.
func (r *ForwarderRequest) Complete(sendReset bool) {
r.mu.Lock()
defer r.mu.Unlock()
if r.segment == nil {
panic("Completing already completed forwarder request")
}
// Remove request from the forwarder.
r.forwarder.mu.Lock()
delete(r.forwarder.inFlight, r.segment.id)
r.forwarder.mu.Unlock()
// If the caller requested, send a reset.
if sendReset {
replyWithReset(r.segment)
}
// Release all resources.
r.segment.decRef()
r.segment = nil
r.forwarder = nil
}
// CreateEndpoint creates a TCP endpoint for the connection request, performing
// the 3-way handshake in the process.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.segment == nil {
return nil, tcpip.ErrInvalidEndpointState
}
f := r.forwarder
ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
MSS: r.synOptions.MSS,
WS: r.synOptions.WS,
TS: r.synOptions.TS,
TSVal: r.synOptions.TSVal,
TSEcr: r.synOptions.TSEcr,
SACKPermitted: r.synOptions.SACKPermitted,
})
if err != nil {
return nil, err
}
// Start the protocol goroutine.
ep.startAcceptedLoop(queue)
return ep, nil
}

View File

@@ -0,0 +1,241 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tcp contains the implementation of the TCP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing tcp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
package tcp
import (
"strings"
"sync"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
const (
// ProtocolName is the string representation of the tcp protocol name.
ProtocolName = "tcp"
// ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber
// MinBufferSize is the smallest size of a receive or send buffer.
minBufferSize = 4 << 10 // 4096 bytes.
// DefaultBufferSize is the default size of the receive and send buffers.
DefaultBufferSize = 1 << 20 // 1MB
// MaxBufferSize is the largest size a receive and send buffer can grow to.
maxBufferSize = 4 << 20 // 4MB
)
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
// SendBufferSizeOption allows the default, min and max send buffer sizes for
// TCP endpoints to be queried or configured.
type SendBufferSizeOption struct {
Min int
Default int
Max int
}
// ReceiveBufferSizeOption allows the default, min and max receive buffer size
// for TCP endpoints to be queried or configured.
type ReceiveBufferSizeOption struct {
Min int
Default int
Max int
}
const (
ccReno = "reno"
ccCubic = "cubic"
)
// CongestionControlOption sets the current congestion control algorithm.
type CongestionControlOption string
// AvailableCongestionControlOption returns the supported congestion control
// algorithms.
type AvailableCongestionControlOption string
// tcp协议传输层接口的实现
type protocol struct {
mu sync.Mutex
sackEnabled bool
sendBufferSize SendBufferSizeOption
recvBufferSize ReceiveBufferSizeOption
congestionControl string
availableCongestionControl []string
allowedCongestionControl []string
}
// Number returns the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
return ProtocolNumber
}
// NewEndpoint creates a new tcp endpoint.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize
}
// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
//
// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
// a reset is sent in response to any incoming segment except another reset. In
// particular, SYNs addressed to a non-existent connection are rejected by this
// means."
func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
s := newSegment(r, id, vv)
defer s.decRef()
if !s.parse() {
return false
}
// There's nothing to do if this is already a reset packet.
if s.flagIsSet(flagRst) {
return true
}
replyWithReset(s)
return true
}
// replyWithReset replies to the given segment with a reset segment.
func replyWithReset(s *segment) {
// Get the seqnum from the packet if the ack flag is set.
seq := seqnum.Value(0)
if s.flagIsSet(flagAck) {
seq = s.ackNumber
}
ack := s.sequenceNumber.Add(s.logicalLen())
sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), flagRst|flagAck, seq, ack, 0, nil)
}
// SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
case SACKEnabled:
p.mu.Lock()
p.sackEnabled = bool(v)
p.mu.Unlock()
return nil
case SendBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
p.sendBufferSize = v
p.mu.Unlock()
return nil
case ReceiveBufferSizeOption:
if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
return tcpip.ErrInvalidOptionValue
}
p.mu.Lock()
p.recvBufferSize = v
p.mu.Unlock()
return nil
case CongestionControlOption:
for _, c := range p.availableCongestionControl {
if string(v) == c {
p.mu.Lock()
p.congestionControl = string(v)
p.mu.Unlock()
return nil
}
}
return tcpip.ErrInvalidOptionValue
default:
return tcpip.ErrUnknownProtocolOption
}
}
// Option implements TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
case *SACKEnabled:
p.mu.Lock()
*v = SACKEnabled(p.sackEnabled)
p.mu.Unlock()
return nil
case *SendBufferSizeOption:
p.mu.Lock()
*v = p.sendBufferSize
p.mu.Unlock()
return nil
case *ReceiveBufferSizeOption:
p.mu.Lock()
*v = p.recvBufferSize
p.mu.Unlock()
return nil
case *CongestionControlOption:
p.mu.Lock()
*v = CongestionControlOption(p.congestionControl)
p.mu.Unlock()
return nil
case *AvailableCongestionControlOption:
p.mu.Lock()
*v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
p.mu.Unlock()
return nil
default:
return tcpip.ErrUnknownProtocolOption
}
}
func init() {
// tcp协议在协议栈中的注册
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{
sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
congestionControl: ccReno,
availableCongestionControl: []string{ccReno, ccCubic},
}
})
}

View File

@@ -0,0 +1,251 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"container/heap"
"tcpip/netstack/tcpip/seqnum"
)
// receiver holds the state necessary to receive TCP segments and turn them
// into a stream of bytes.
//
// +stateify savable
type receiver struct {
ep *endpoint
rcvNxt seqnum.Value
// rcvAcc is one beyond the last acceptable sequence number. That is,
// the "largest" sequence value that the receiver has announced to the
// its peer that it's willing to accept. This may be different than
// rcvNxt + rcvWnd if the receive window is reduced; in that case we
// have to reduce the window as we receive more data instead of
// shrinking it.
// rcvAcc 超出了最后一个可接受的序列号。也就是说,接收方向其同行宣布它愿意接受的“最大”序列值。
// 如果接收窗口减少这可能与rcvNxt + rcvWnd不同;在这种情况下,我们必须减少窗口,因为我们收到更多数据而不是缩小它。
rcvAcc seqnum.Value
rcvWndScale uint8
closed bool
pendingRcvdSegments segmentHeap
pendingBufUsed seqnum.Size
pendingBufSize seqnum.Size
}
// 初始化接收器
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
return &receiver{
ep: ep,
rcvNxt: irs + 1,
rcvAcc: irs.Add(rcvWnd + 1),
rcvWndScale: rcvWndScale,
pendingBufSize: rcvWnd,
}
}
// acceptable checks if the segment sequence number range is acceptable
// according to the table on page 26 of RFC 793.
// tcp流量控制判断 segSeq 在窗口內
func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
rcvWnd := r.rcvNxt.Size(r.rcvAcc)
if rcvWnd == 0 {
return segLen == 0 && segSeq == r.rcvNxt
}
return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
}
// getSendParams returns the parameters needed by the sender when building
// segments to send.
// getSendParams 在构建要发送的段时,返回发送方所需的参数。
// 并且更新接收窗口的指标 rcvAcc
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
// Calculate the window size based on the current buffer size.
n := r.ep.receiveBufferAvailable()
acc := r.rcvNxt.Add(seqnum.Size(n))
if r.rcvAcc.LessThan(acc) {
r.rcvAcc = acc
}
return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
}
// nonZeroWindow is called when the receive window grows from zero to nonzero;
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
// tcp流量控制当接收窗口从零增长到非零时调用 nonZeroWindow;在这种情况下,
// 我们可能需要发送一个 ack以便向对端表明它可以恢复发送数据。
func (r *receiver) nonZeroWindow() {
if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
// We never got around to announcing a zero window size, so we
// don't need to immediately announce a nonzero one.
return
}
// Immediately send an ack.
r.ep.snd.sendAck()
}
// consumeSegment attempts to consume a segment that was received by r. The
// segment may have just been received or may have been received earlier but
// wasn't ready to be consumed then.
//
// Returns true if the segment was consumed, false if it cannot be consumed
// yet because of a missing segment.
// tcp可靠性consumeSegment 尝试使用r接收tcp段。该数据段可能刚刚收到或可能已经收到但尚未准备好被消费。
// 如果数据段被消耗则返回true如果由于缺少段而无法消耗则返回false。
func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
if segLen > 0 {
// If the segment doesn't include the seqnum we're expecting to
// consume now, we're missing a segment. We cannot proceed until
// we receive that segment though.
// 我们期望接收到的序列号范围应该是 seqStart <= rcvNxt < seqEnd
// 如果不在这个范围内说明我们少了数据段返回false表示不能立马消费
if !r.rcvNxt.InWindow(segSeq, segLen) {
return false
}
// Trim segment to eliminate already acknowledged data.
// 尝试去除已经确认过的数据
if segSeq.LessThan(r.rcvNxt) {
diff := segSeq.Size(r.rcvNxt)
segLen -= diff
segSeq.UpdateForward(diff)
s.sequenceNumber.UpdateForward(diff)
s.data.TrimFront(int(diff))
}
// Move segment to ready-to-deliver list. Wakeup any waiters.
// 将tcp段插入接收链表并通知应用层用数据来了
r.ep.readyToRead(s)
} else if segSeq != r.rcvNxt {
return false
}
// Update the segment that we're expecting to consume.
// 因为前面已经收到正确按序到达的数据,那么我们应该更新一下我们期望下次收到的序列号了
r.rcvNxt = segSeq.Add(segLen)
// Trim SACK Blocks to remove any SACK information that covers
// sequence numbers that have been consumed.
// 修剪SACK块以删除任何涵盖已消耗序列号的SACK信息。
TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
// 如果收到 fin 报文
if s.flagIsSet(flagFin) {
// 控制报文消耗一个字节的序列号因此这边期望下次收到的序列号加1
r.rcvNxt++
// 收到 fin立即回复 ack
r.ep.snd.sendAck()
// Tell any readers that no more data will come.
// 标记接收器关闭
// 触发上层应用可以读取
r.closed = true
r.ep.readyToRead(nil)
// Flush out any pending segments, except the very first one if
// it happens to be the one we're handling now because the
// caller is using it.
first := 0
if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
first = 1
}
for i := first; i < len(r.pendingRcvdSegments); i++ {
r.pendingRcvdSegments[i].decRef()
}
r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
}
return true
}
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
// 从 handleSegments 接收到tcp段然后进行处理消费所谓的消费就是将负载内容插入到接收队列中
func (r *receiver) handleRcvdSegment(s *segment) {
// We don't care about receive processing anymore if the receive side
// is closed.
if r.closed {
return
}
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// If the sequence number range is outside the acceptable range, just
// send an ACK. This is according to RFC 793, page 37.
// tcp流量控制判断该数据段的序列号是否在接收窗口内如果不在立即返回ack给对端。
if !r.acceptable(segSeq, segLen) {
r.ep.snd.sendAck()
return
}
// Defer segment processing if it can't be consumed now.
// tcp可靠性r.consumeSegment 返回值是个bool类型如果是true表示已经消费该数据段
// 如果不是,那么进行下面的处理,插入到 pendingRcvdSegments且进行堆排序。
if !r.consumeSegment(s, segSeq, segLen) {
// 如果有负载数据或者是 fin 报文,立即回复一个 ack 报文
if segLen > 0 || s.flagIsSet(flagFin) {
// We only store the segment if it's within our buffer
// size limit.
// tcp可靠性对于乱序的tcp段应该在等待处理段中缓存
if r.pendingBufUsed < r.pendingBufSize {
r.pendingBufUsed += s.logicalLen()
s.incRef()
// 插入堆中,且进行排序
heap.Push(&r.pendingRcvdSegments, s)
}
// tcp的可靠性更新 sack 块信息
UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
// Immediately send an ack so that the peer knows it may
// have to retransmit.
r.ep.snd.sendAck()
}
return
}
// By consuming the current segment, we may have filled a gap in the
// sequence number domain that allows pending segments to be consumed
// now. So try to do it.
// tcp的可靠性通过使用当前段我们可能填补了序列号域中的间隙该间隙允许现在使用待处理段。
// 所以试着去消费等待处理段。
for !r.closed && r.pendingRcvdSegments.Len() > 0 {
s := r.pendingRcvdSegments[0]
segLen := seqnum.Size(s.data.Size())
segSeq := s.sequenceNumber
// Skip segment altogether if it has already been acknowledged.
if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
!r.consumeSegment(s, segSeq, segLen) {
break
}
// 如果该tcp段已经被正确消费那么中等待处理段中删除
heap.Pop(&r.pendingRcvdSegments)
r.pendingBufUsed -= s.logicalLen()
s.decRef()
}
}

View File

@@ -0,0 +1,125 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
// renoState stores the variables related to TCP New Reno congestion
// control algorithm.
//
// +stateify savable
type renoState struct {
s *sender
}
// newRenoCC initializes the state for the NewReno congestion control algorithm.
// 新建 reno 算法对象
func newRenoCC(s *sender) *renoState {
return &renoState{s: s}
}
// updateSlowStart will update the congestion window as per the slow-start
// algorithm used by NewReno. If after adjusting the congestion window
// we cross the SSthreshold then it will return the number of packets that
// must be consumed in congestion avoidance mode.
// updateSlowStart 将根据NewReno使用的慢启动算法更新拥塞窗口。
// 如果在调整拥塞窗口后我们越过了 SSthreshold ,那么它将返回在拥塞避免模式下必须消耗的数据包的数量。
func (r *renoState) updateSlowStart(packetsAcked int) int {
// Don't let the congestion window cross into the congestion
// avoidance range.
// 在慢启动阶段每次收到acksndCwnd加上已确认的段数
newcwnd := r.s.sndCwnd + packetsAcked
// 判断增大过后的拥塞窗口是否超过慢启动阀值 sndSsthresh
// 如果超过 sndSsthresh ,将窗口调整为 sndSsthresh
if newcwnd >= r.s.sndSsthresh {
newcwnd = r.s.sndSsthresh
r.s.sndCAAckCount = 0
}
// 是否超过 sndSsthresh packetsAcked>0表示超过
packetsAcked -= newcwnd - r.s.sndCwnd
// 更新拥塞窗口
r.s.sndCwnd = newcwnd
return packetsAcked
}
// updateCongestionAvoidance will update congestion window in congestion
// avoidance mode as described in RFC5681 section 3.1
// updateCongestionAvoidance 将在拥塞避免模式下更新拥塞窗口,
// 如RFC5681第3.1节所述
func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
// Consume the packets in congestion avoidance mode.
// sndCAAckCount 累计收到的tcp段数
r.s.sndCAAckCount += packetsAcked
// 如果累计的段数超过当前的拥塞窗口,那么 sndCwnd 加上 sndCAAckCount/sndCwnd 的整数倍
if r.s.sndCAAckCount >= r.s.sndCwnd {
r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
}
}
// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
// page 6, eq. 4. It is called when we detect congestion in the network.
// 当检测到网络拥塞时,调用 reduceSlowStartThreshold。
// 它将 sndSsthresh 变为 outstanding 的一半。
// sndSsthresh 最小为2因为至少要比丢包后的拥塞窗口cwnd=1来的大才会进入慢启动阶段。
func (r *renoState) reduceSlowStartThreshold() {
r.s.sndSsthresh = r.s.outstanding / 2
if r.s.sndSsthresh < 2 {
r.s.sndSsthresh = 2
}
}
// Update updates the congestion state based on the number of packets that
// were acknowledged.
// Update implements congestionControl.Update.
// packetsAcked 表示已确认的tcp段数
func (r *renoState) Update(packetsAcked int) {
// 当拥塞窗口没有超过慢启动阀值的时候,使用慢启动来增大窗口,
// 否则进入拥塞避免阶段
if r.s.sndCwnd < r.s.sndSsthresh {
packetsAcked = r.updateSlowStart(packetsAcked)
if packetsAcked == 0 {
return
}
}
// 进入拥塞避免阶段
r.updateCongestionAvoidance(packetsAcked)
}
// HandleNDupAcks implements congestionControl.HandleNDupAcks.
// 当收到三个重复ack时调用 HandleNDupAcks 来处理。
func (r *renoState) HandleNDupAcks() {
// A retransmit was triggered due to nDupAckThreshold
// being hit. Reduce our slow start threshold.
// 减小慢启动阀值
r.reduceSlowStartThreshold()
}
// HandleRTOExpired implements congestionControl.HandleRTOExpired.
// 当当发生重传包时,调用 HandleRTOExpired 来处理。
func (r *renoState) HandleRTOExpired() {
// We lost a packet, so reduce ssthresh.
// 减小慢启动阀值
r.reduceSlowStartThreshold()
// Reduce the congestion window to 1, i.e., enter slow-start. Per
// RFC 5681, page 7, we must use 1 regardless of the value of the
// initial congestion window.
// 更新拥塞窗口为1这样就会重新进入慢启动
r.s.sndCwnd = 1
}
// PostRecovery implements congestionControl.PostRecovery.
func (r *renoState) PostRecovery() {
// noop.
}

View File

@@ -0,0 +1,103 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
)
const (
// MaxSACKBlocks is the maximum number of SACK blocks stored
// at receiver side.
// MaxSACKBlocks 是接收端存储的最大SACK块数。
MaxSACKBlocks = 6
)
// UpdateSACKBlocks updates the list of SACK blocks to include the segment
// specified by segStart->segEnd. If the segment happens to be an out of order
// delivery then the first block in the sack.blocks always includes the
// segment identified by segStart->segEnd.
// tcp的可靠性UpdateSACKBlocks 更新SACK块列表以包含 segStart-segEnd 指定的段只有没有被消费掉的seg才会被用来更新sack。
// 如果该段恰好是无序传递那么sack.blocks中的第一个块总是包括由 segStart-segEnd 标识的段。
func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
newSB := header.SACKBlock{Start: segStart, End: segEnd}
if sack.NumBlocks == 0 {
sack.Blocks[0] = newSB
sack.NumBlocks = 1
return
}
var n = 0
for i := 0; i < sack.NumBlocks; i++ {
start, end := sack.Blocks[i].Start, sack.Blocks[i].End
if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
// Discard any invalid blocks where end is before start
// and discard any sack blocks that are before rcvNxt as
// those have already been acked.
continue
}
if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
// Merge this SACK block into newSB and discard this SACK
// block.
if start.LessThan(newSB.Start) {
newSB.Start = start
}
if newSB.End.LessThan(end) {
newSB.End = end
}
} else {
// Save this block.
sack.Blocks[n] = sack.Blocks[i]
n++
}
}
if rcvNxt.LessThan(newSB.Start) {
// If this was an out of order segment then make sure that the
// first SACK block is the one that includes the segment.
//
// See the first bullet point in
// https://tools.ietf.org/html/rfc2018#section-4
if n == MaxSACKBlocks {
// If the number of SACK blocks is equal to
// MaxSACKBlocks then discard the last SACK block.
n--
}
for i := n - 1; i >= 0; i-- {
sack.Blocks[i+1] = sack.Blocks[i]
}
sack.Blocks[0] = newSB
n++
}
sack.NumBlocks = n
}
// TrimSACKBlockList updates the sack block list by removing/modifying any block
// where start is < rcvNxt.
// tcp的可靠性TrimSACKBlockList 通过删除/修改 start为 <rcvNxt 的任何块来更新sack块列表。
func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
n := 0
for i := 0; i < sack.NumBlocks; i++ {
if sack.Blocks[i].End.LessThanEq(rcvNxt) {
continue
}
if sack.Blocks[i].Start.LessThan(rcvNxt) {
// Shrink this SACK block.
sack.Blocks[i].Start = rcvNxt
}
sack.Blocks[n] = sack.Blocks[i]
n++
}
sack.NumBlocks = n
}

View File

@@ -0,0 +1,184 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"strings"
"sync/atomic"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/stack"
)
// Flags that may be set in a TCP segment.
const (
flagFin = 1 << iota
flagSyn
flagRst
flagPsh
flagAck
flagUrg
)
func flagString(flags uint8) string {
var s []string
if (flags & flagAck) != 0 {
s = append(s, "ack")
}
if (flags & flagFin) != 0 {
s = append(s, "fin")
}
if (flags & flagPsh) != 0 {
s = append(s, "psh")
}
if (flags & flagRst) != 0 {
s = append(s, "rst")
}
if (flags & flagSyn) != 0 {
s = append(s, "syn")
}
if (flags & flagUrg) != 0 {
s = append(s, "urg")
}
return strings.Join(s, "|")
}
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
//
// +stateify savable
type segment struct {
segmentEntry
refCnt int32
id stack.TransportEndpointID
route stack.Route
data buffer.VectorisedView
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
// viewToDeliver keeps track of the next View that should be
// delivered by the Read endpoint.
viewToDeliver int
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
window seqnum.Size
// parsedOptions stores the parsed values from the options in the segment.
parsedOptions header.TCPOptions
options []byte
}
func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.data = vv.Clone(s.views[:])
return s
}
func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{
refCnt: 1,
id: id,
route: r.Clone(),
}
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
return s
}
func (s *segment) clone() *segment {
t := &segment{
refCnt: 1,
id: s.id,
sequenceNumber: s.sequenceNumber,
ackNumber: s.ackNumber,
flags: s.flags,
window: s.window,
route: s.route.Clone(),
viewToDeliver: s.viewToDeliver,
}
t.data = s.data.Clone(t.views[:])
return t
}
func (s *segment) flagIsSet(flag uint8) bool {
return (s.flags & flag) != 0
}
func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 {
s.route.Release()
}
}
func (s *segment) incRef() {
atomic.AddInt32(&s.refCnt, 1)
}
// logicalLen is the segment length in the sequence number space. It's defined
// as the data length plus one for each of the SYN and FIN bits set.
// 计算tcp段的逻辑长度包括负载数据的长度如果有控制标记需要加1
func (s *segment) logicalLen() seqnum.Size {
l := seqnum.Size(s.data.Size())
if s.flagIsSet(flagSyn) {
l++
}
if s.flagIsSet(flagFin) {
l++
}
return l
}
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the data. Returns boolean indicating if the parsing was successful.
// parse 解析存储在数据中的TCP头填充tcp段的序列和确认号标志和窗口字段。
// 返回 bool值指示解析是否成功。
func (s *segment) parse() bool {
h := header.TCP(s.data.First())
// h is the header followed by the payload. We check that the offset to
// the data respects the following constraints:
// 1. That it's at least the minimum header size; if we don't do this
// then part of the header would be delivered to user.
// 2. That the header fits within the buffer; if we don't do this, we
// would panic when we tried to access data beyond the buffer.
//
// N.B. The segment has already been validated as having at least the
// minimum TCP size before reaching here, so it's safe to read the
// fields.
offset := int(h.DataOffset())
if offset < header.TCPMinimumSize || offset > len(h) {
return false
}
s.options = []byte(h[header.TCPMinimumSize:offset])
s.parsedOptions = header.ParseTCPOptions(s.options)
s.data.TrimFront(offset)
s.sequenceNumber = seqnum.Value(h.SequenceNumber())
s.ackNumber = seqnum.Value(h.AckNumber())
s.flags = h.Flags()
s.window = seqnum.Size(h.WindowSize())
return true
}

View File

@@ -0,0 +1,48 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
// tcp段的堆用来暂存失序的tcp段
// 实现了堆排序
type segmentHeap []*segment
// Len returns the length of h.
func (h segmentHeap) Len() int {
return len(h)
}
// Less determines whether the i-th element of h is less than the j-th element.
func (h segmentHeap) Less(i, j int) bool {
return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
}
// Swap swaps the i-th and j-th elements of h.
func (h segmentHeap) Swap(i, j int) {
h[i], h[j] = h[j], h[i]
}
// Push adds x as the last element of h.
func (h *segmentHeap) Push(x interface{}) {
*h = append(*h, x.(*segment))
}
// Pop removes the last element of h and returns it.
func (h *segmentHeap) Pop() interface{} {
old := *h
n := len(old)
x := old[n-1]
*h = old[:n-1]
return x
}

View File

@@ -0,0 +1,81 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"sync"
"tcpip/netstack/tcpip/header"
)
// segmentQueue is a bounded, thread-safe queue of TCP segments.
//
// +stateify savable
type segmentQueue struct {
mu sync.Mutex
list segmentList
limit int
used int
}
// empty determines if the queue is empty.
func (q *segmentQueue) empty() bool {
q.mu.Lock()
r := q.used == 0
q.mu.Unlock()
return r
}
// setLimit updates the limit. No segments are immediately dropped in case the
// queue becomes full due to the new limit.
func (q *segmentQueue) setLimit(limit int) {
q.mu.Lock()
q.limit = limit
q.mu.Unlock()
}
// enqueue adds the given segment to the queue.
//
// Returns true when the segment is successfully added to the queue, in which
// case ownership of the reference is transferred to the queue. And returns
// false if the queue is full, in which case ownership is retained by the
// caller.
func (q *segmentQueue) enqueue(s *segment) bool {
q.mu.Lock()
r := q.used < q.limit
if r {
q.list.PushBack(s)
q.used += s.data.Size() + header.TCPMinimumSize
}
q.mu.Unlock()
return r
}
// dequeue removes and returns the next segment from queue, if one exists.
// Ownership is transferred to the caller, who is responsible for decrementing
// the ref count when done.
func (q *segmentQueue) dequeue() *segment {
q.mu.Lock()
s := q.list.Front()
if s != nil {
q.list.Remove(s)
q.used -= s.data.Size() + header.TCPMinimumSize
}
q.mu.Unlock()
return s
}

View File

@@ -0,0 +1,797 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"math"
"sync"
"time"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
)
const (
// minRTO is the minimum allowed value for the retransmit timeout.
minRTO = 200 * time.Millisecond
// InitialCwnd is the initial congestion window.
// 初始拥塞窗口大小
InitialCwnd = 10
// nDupAckThreshold is the number of duplicate ACK's required
// before fast-retransmit is entered.
nDupAckThreshold = 3
)
// congestionControl is an interface that must be implemented by any supported
// congestion control algorithm.
// tcp拥塞控制拥塞控制算法的接口
type congestionControl interface {
// HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
// just before entering fast retransmit.
// 在进入快速重新传输之前,当 sender.dupAckCount> = nDupAckThreshold 时调用HandleNDupAcks。
HandleNDupAcks()
// HandleRTOExpired is invoked when the retransmit timer expires.
// 当重新传输计时器到期时调用HandleRTOExpired。
HandleRTOExpired()
// Update is invoked when processing inbound acks. It's passed the
// number of packet's that were acked by the most recent cumulative
// acknowledgement.
// 已经有数据包被确认时调用 Update。它传递了最近累积确认所确认的数据包数。
Update(packetsAcked int)
// PostRecovery is invoked when the sender is exiting a fast retransmit/
// recovery phase. This provides congestion control algorithms a way
// to adjust their state when exiting recovery.
// 当发送方退出快速重新传输/恢复阶段时将调用PostRecovery。
// 这为拥塞控制算法提供了一种在退出恢复时调整其状态的方法。
PostRecovery()
}
/*
+-------> sndwnd <-------+
| |
---------------------+-------------+----------+--------------------
| acked | * * * * * * | # # # # #| unable send
---------------------+-------------+----------+--------------------
^ ^
| |
sndUna sndNxt
***** in flight data
##### able send date
*/
// sender holds the state necessary to send TCP segments.
//
// +stateify savable
// tcp发送器它维护了tcp必要的状态
type sender struct {
ep *endpoint
// lastSendTime is the timestamp when the last packet was sent.
// lastSendTime 是发送最后一个数据包的时间戳。
lastSendTime time.Time
// dupAckCount is the number of duplicated acks received. It is used for
// fast retransmit.
// dupAckCount 是收到的重复ack数。它用于快速重传。
dupAckCount int
// fr holds state related to fast recovery.
// fr 持有与快速恢复有关的状态。
fr fastRecovery
// sndCwnd is the congestion window, in packets.
// sndCwnd 是拥塞窗口,单位是包
sndCwnd int
// sndSsthresh is the threshold between slow start and congestion
// avoidance.
// sndSsthresh 是慢启动和拥塞避免之间的阈值。
sndSsthresh int
// sndCAAckCount is the number of packets acknowledged during congestion
// avoidance. When enough packets have been ack'd (typically cwnd
// packets), the congestion window is incremented by one.
// sndCAAckCount 是拥塞避免期间确认的数据包数。当已经确认了足够的分组通常是cwnd分组拥塞窗口增加1。
sndCAAckCount int
// outstanding is the number of outstanding packets, that is, packets
// that have been sent but not yet acknowledged.
// outstanding 是正在发送的数据包的数量,即已发送但尚未确认的数据包。
outstanding int
// sndWnd is the send window size.
// 发送窗口大小,单位是字节
sndWnd seqnum.Size
// sndUna is the next unacknowledged sequence number.
// sndUna 是下一个未确认的序列号
sndUna seqnum.Value
// sndNxt is the sequence number of the next segment to be sent.
// sndNxt 是要发送的下一个段的序列号。
sndNxt seqnum.Value
// sndNxtList is the sequence number of the next segment to be added to
// the send list.
// sndNxtList 是要添加到发送列表的下一个段的序列号。
sndNxtList seqnum.Value
// rttMeasureSeqNum is the sequence number being used for the latest RTT
// measurement.
rttMeasureSeqNum seqnum.Value
// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
rttMeasureTime time.Time
closed bool
writeNext *segment
// 发送链表
writeList segmentList
resendTimer timer
resendWaker sleep.Waker
// rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
// "round-trip time variation" and "retransmit timeout", as defined in
// section 2 of RFC 6298.
rtt rtt
rto time.Duration
srttInited bool
// maxPayloadSize is the maximum size of the payload of a given segment.
// It is initialized on demand.
maxPayloadSize int
// sndWndScale is the number of bits to shift left when reading the send
// window size from a segment.
sndWndScale uint8
// maxSentAck is the maxium acknowledgement actually sent.
maxSentAck seqnum.Value
// cc is the congestion control algorithm in use for this sender.
// cc 是实现拥塞控制算法的接口
cc congestionControl
}
// rtt is a synchronization wrapper used to appease stateify. See the comment
// in sender, where it is used.
//
// +stateify savable
type rtt struct {
sync.Mutex
srtt time.Duration
rttvar time.Duration
}
// fastRecovery holds information related to fast recovery from a packet loss.
//
// +stateify savable
// fastRecovery 保存与数据包丢失快速恢复相关的信息
type fastRecovery struct {
// active whether the endpoint is in fast recovery. The following fields
// are only meaningful when active is true.
active bool
// first and last represent the inclusive sequence number range being
// recovered.
first seqnum.Value
last seqnum.Value
// maxCwnd is the maximum value the congestion window may be inflated to
// due to duplicate acks. This exists to avoid attacks where the
// receiver intentionally sends duplicate acks to artificially inflate
// the sender's cwnd.
maxCwnd int
}
// 新建并初始化发送器
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
s := &sender{
ep: ep,
sndCwnd: InitialCwnd,
sndSsthresh: math.MaxInt64,
sndWnd: sndWnd,
sndUna: iss + 1,
sndNxt: iss + 1,
sndNxtList: iss + 1,
rto: 1 * time.Second,
rttMeasureSeqNum: iss + 1,
lastSendTime: time.Now(),
maxPayloadSize: int(mss),
maxSentAck: irs + 1,
fr: fastRecovery{
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
last: iss,
},
}
// 拥塞控制算法的初始化
s.cc = s.initCongestionControl(ep.cc)
// A negative sndWndScale means that no scaling is in use, otherwise we
// store the scaling value.
if sndWndScale > 0 {
s.sndWndScale = uint8(sndWndScale)
}
s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
s.resendTimer.init(&s.resendWaker)
return s
}
// tcp拥塞控制根据算法名新建拥塞控制算法和初始化
func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
switch congestionControlName {
case ccCubic:
return newCubicCC(s)
case ccReno:
fallthrough
default:
return newRenoCC(s)
}
}
// updateMaxPayloadSize updates the maximum payload size based on the given
// MTU. If this is in response to "packet too big" control packets (indicated
// by the count argument), it also reduces the number of outstanding packets and
// attempts to retransmit the first packet above the MTU size.
func (s *sender) updateMaxPayloadSize(mtu, count int) {
m := mtu - header.TCPMinimumSize
// Calculate the maximum option size.
// 计算MSS的大小
var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
options := s.ep.makeOptions(maxSackBlocks[:])
m -= len(options)
putOptions(options)
// We don't adjust up for now.
if m >= s.maxPayloadSize {
return
}
// Make sure we can transmit at least one byte.
if m <= 0 {
m = 1
}
s.maxPayloadSize = m
s.outstanding -= count
if s.outstanding < 0 {
s.outstanding = 0
}
// Rewind writeNext to the first segment exceeding the MTU. Do nothing
// if it is already before such a packet.
for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
if seg == s.writeNext {
// We got to writeNext before we could find a segment
// exceeding the MTU.
break
}
if seg.data.Size() > m {
// We found a segment exceeding the MTU. Rewind
// writeNext and try to retransmit it.
s.writeNext = seg
break
}
}
// Since we likely reduced the number of outstanding packets, we may be
// ready to send some more.
s.sendData()
}
// sendAck sends an ACK segment.
// 发送一个ack tcp段
func (s *sender) sendAck() {
s.sendSegment(buffer.VectorisedView{}, flagAck, s.sndNxt)
}
// updateRTO updates the retransmit timeout when a new roud-trip time is
// available. This is done in accordance with section 2 of RFC 6298.
// 根据rtt来更新计算rto
func (s *sender) updateRTO(rtt time.Duration) {
s.rtt.Lock()
// 第一次计算
if !s.srttInited {
s.rtt.rttvar = rtt / 2
s.rtt.srtt = rtt
s.srttInited = true
} else {
diff := s.rtt.srtt - rtt
if diff < 0 {
diff = -diff
}
// Use RFC6298 standard algorithm to update rttvar and srtt when
// no timestamps are available.
if !s.ep.sendTSOk {
s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
} else {
// When we are taking RTT measurements of every ACK then
// we need to use a modified method as specified in
// https://tools.ietf.org/html/rfc7323#appendix-G
if s.outstanding == 0 {
s.rtt.Unlock()
return
}
// Netstack measures congestion window/inflight all in
// terms of packets and not bytes. This is similar to
// how linux also does cwnd and inflight. In practice
// this approximation works as expected.
expectedSamples := math.Ceil(float64(s.outstanding) / 2)
// alpha & beta values are the original values as recommended in
// https://tools.ietf.org/html/rfc6298#section-2.3.
const alpha = 0.125
const beta = 0.25
alphaPrime := alpha / expectedSamples
betaPrime := beta / expectedSamples
rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
s.rtt.srtt = time.Duration(srtt * float64(time.Second))
}
}
// 更新 rto 值
s.rto = s.rtt.srtt + 4*s.rtt.rttvar
s.rtt.Unlock()
if s.rto < minRTO {
s.rto = minRTO
}
}
// resendSegment resends the first unacknowledged segment.
// tcp的拥塞控制快速重传
func (s *sender) resendSegment() {
// Don't use any segments we already sent to measure RTT as they may
// have been affected by packets being lost.
s.rttMeasureSeqNum = s.sndNxt
// Resend the segment.
if seg := s.writeList.Front(); seg != nil {
s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
}
}
// retransmitTimerExpired is called when the retransmit timer expires, and
// unacknowledged segments are assumed lost, and thus need to be resent.
// Returns true if the connection is still usable, or false if the connection
// is deemed lost.
// tcp的可靠性重传定时器触发的时候调用这个函数也就是超时重传
// tcp的拥塞控制发生重传即认为发送丢包拥塞控制需要对丢包进行相应的处理。
func (s *sender) retransmitTimerExpired() bool {
// Check if the timer actually expired or if it's a spurious wake due
// to a previously orphaned runtime timer.
// 检查计时器是否真的到期
if !s.resendTimer.checkExpiration() {
return true
}
// Give up if we've waited more than a minute since the last resend.
// 如果rto已经超过了1分钟直接放弃发送返回错误
if s.rto >= 60*time.Second {
return false
}
// Set new timeout. The timer will be restarted by the call to sendData
// below.
// 每次超时rto都变成原来的2倍
s.rto *= 2
// tcp的拥塞控制如果已经是快速恢复阶段那么退出因为现在丢包了
if s.fr.active {
// We were attempting fast recovery but were not successful.
// Leave the state. We don't need to update ssthresh because it
// has already been updated when entered fast-recovery.
// 我们试图快速恢复,但没有成功,退出这个状态。
// 我们不需要更新ssthresh因为它在进入快速恢复时已经更新。
s.leaveFastRecovery()
}
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
// We store the highest sequence number transmitted in cases where
// we were not in fast recovery.
s.fr.last = s.sndNxt - 1
// tcp的拥塞控制处理丢包的情况
s.cc.HandleRTOExpired()
// Mark the next segment to be sent as the first unacknowledged one and
// start sending again. Set the number of outstanding packets to 0 so
// that we'll be able to retransmit.
//
// We'll keep on transmitting (or retransmitting) as we get acks for
// the data we transmit.
// tcp可靠性将下一个段标记为第一个未确认的段然后再次开始发送。将未完成的数据包数设置为0以便我们能够重新传输。
// 当我们收到我们传输的数据时,我们将继续传输(或重新传输)。
s.outstanding = 0
s.writeNext = s.writeList.Front()
// 重新发送数据包
s.sendData()
return true
}
// sendData sends new data segments. It is called when data becomes available or
// when the send window opens up.
// 发送数据段,最终掉用 sendSegment 来发送
func (s *sender) sendData() {
limit := s.maxPayloadSize
// Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
// "A TCP SHOULD set cwnd to no more than RW before beginning
// transmission if the TCP has not sent data in the interval exceeding
// the retrasmission timeout."
// 根据RFC 5681第10页将拥塞窗口减少到minIWcwnd
// 如果TCP在超过重新传输超时的时间间隔内没有发送数据TCP应该在开始传输之前将cwnd设置为不超过RW。
if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
if s.sndCwnd > InitialCwnd {
s.sndCwnd = InitialCwnd
}
}
// TODO: We currently don't merge multiple send buffers
// into one segment if they happen to fit. We should do that
// eventually.
var seg *segment
end := s.sndUna.Add(s.sndWnd)
var dataSent bool
// 遍历发送链表,发送数据
// tcp拥塞控制s.outstanding < s.sndCwnd 判断正在发送的数据量不能超过拥塞窗口。
for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
// We abuse the flags field to determine if we have already
// assigned a sequence number to this segment.
// 如果seg的flags是0将flags改为psh|ack
if seg.flags == 0 {
seg.sequenceNumber = s.sndNxt
seg.flags = flagAck | flagPsh
}
var segEnd seqnum.Value
if seg.data.Size() == 0 { // 数据段没有负载,表示要结束连接
if s.writeList.Back() != seg {
panic("FIN segments must be the final segment in the write list.")
}
// 发送 fin 报文
seg.flags = flagAck | flagFin
// fin 报文需要确认,且消耗一个字节序列号
segEnd = seg.sequenceNumber.Add(1)
} else {
// We're sending a non-FIN segment.
if seg.flags&flagFin != 0 {
panic("Netstack queues FIN segments without data.")
}
if !seg.sequenceNumber.LessThan(end) {
break
}
// tcp流量控制计算最多一次发送多大数据
available := int(seg.sequenceNumber.Size(end))
if available > limit {
available = limit
}
// 如果seg的payload字节数大于available
// 将seg进行分段并且插入到该seg的后面
if seg.data.Size() > available {
// Split this segment up.
nSeg := seg.clone()
nSeg.data.TrimFront(available)
nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
s.writeList.InsertAfter(seg, nSeg)
seg.data.CapLength(available)
}
s.outstanding++
segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
}
if !dataSent {
dataSent = true
// We are sending data, so we should stop the keepalive timer to
// ensure that no keepalives are sent while there is pending data.
s.ep.disableKeepaliveTimer()
}
s.sendSegment(seg.data, seg.flags, seg.sequenceNumber)
// Update sndNxt if we actually sent new data (as opposed to
// retransmitting some previously sent data).
// 发送一个数据段后更新sndNxt
if s.sndNxt.LessThan(segEnd) {
s.sndNxt = segEnd
}
}
// Remember the next segment we'll write.
s.writeNext = seg
// Enable the timer if we have pending data and it's not enabled yet.
// tcp的可靠性如果重发定时器没有启动 且 snduna 不等于 sndNxt启动定时器
if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
// 启动定时器并且设定定时器的间隔为s.rto
s.resendTimer.enable(s.rto)
}
// If we have no more pending data, start the keepalive timer.
if s.sndUna == s.sndNxt {
s.ep.resetKeepaliveTimer(false)
}
}
// tcp拥塞控制进入快速恢复状态和相应的处理
func (s *sender) enterFastRecovery() {
s.fr.active = true
// Save state to reflect we're now in fast recovery.
// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
// We inflat the cwnd by 3 to account for the 3 packets which triggered
// the 3 duplicate ACKs and are now not in flight.
s.sndCwnd = s.sndSsthresh + 3
s.fr.first = s.sndUna
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = s.sndCwnd + s.outstanding
}
// tcp拥塞控制退出快速恢复状态和相应的处理
func (s *sender) leaveFastRecovery() {
s.fr.active = false
s.fr.first = 0
s.fr.last = s.sndNxt - 1
s.fr.maxCwnd = 0
s.dupAckCount = 0
// Deflate cwnd. It had been artificially inflated when new dups arrived.
s.sndCwnd = s.sndSsthresh
s.cc.PostRecovery()
}
// checkDuplicateAck is called when an ack is received. It manages the state
// related to duplicate acks and determines if a retransmit is needed according
// to the rules in RFC 6582 (NewReno).
// tcp拥塞控制收到确认时调用 checkDuplicateAck。它管理与重复确认相关的状态
// 并根据RFC 6582NewReno中的规则确定是否需要重新传输
func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
ack := seg.ackNumber
// 已经启动了快速恢复
if s.fr.active {
// We are in fast recovery mode. Ignore the ack if it's out of
// range.
if !ack.InRange(s.sndUna, s.sndNxt+1) {
return false
}
// Leave fast recovery if it acknowledges all the data covered by
// this fast recovery session.
// 如果它确认此快速恢复涵盖的所有数据,退出快速恢复。
if s.fr.last.LessThan(ack) {
s.leaveFastRecovery()
return false
}
// Don't count this as a duplicate if it is carrying data or
// updating the window.
// 如果该tcp段有负载或者正在更新窗口那么不计算这个ack
if seg.logicalLen() != 0 || s.sndWnd != seg.window {
return false
}
// Inflate the congestion window if we're getting duplicate acks
// for the packet we retransmitted.
if ack == s.fr.first {
// We received a dup, inflate the congestion window by 1
// packet if we're not at the max yet.
if s.sndCwnd < s.fr.maxCwnd {
s.sndCwnd++
}
return false
}
// A partial ack was received. Retransmit this packet and
// remember it so that we don't retransmit it again. We don't
// inflate the window because we're putting the same packet back
// onto the wire.
//
// N.B. The retransmit timer will be reset by the caller.
s.fr.first = ack
s.dupAckCount = 0
return true
}
// We're not in fast recovery yet. A segment is considered a duplicate
// only if it doesn't carry any data and doesn't update the send window,
// because if it does, it wasn't sent in response to an out-of-order
// segment.
// 我们还没有进入快速恢复状态,只有当段不携带任何数据并且不更新发送窗口时,才认为该段是重复的。
if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
s.dupAckCount = 0
return false
}
// 到这表示收到一个重复的ack
s.dupAckCount++
// 收到三次的重复ack才会进入快速恢复。
if s.dupAckCount < nDupAckThreshold {
return false
}
// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
//
// We only do the check here, the incrementing of last to the highest
// sequence number transmitted till now is done when enterFastRecovery
// is invoked.
// 我们只在这里进行检查当调用enterFastRecovery时最后一次到最高序列号的递增完成。
if !s.fr.last.LessThan(seg.ackNumber) {
s.dupAckCount = 0
return false
}
// 调用拥塞控制的 HandleNDupAcks 处理三次重复ack
s.cc.HandleNDupAcks()
// 进入快速恢复状态
s.enterFastRecovery()
s.dupAckCount = 0
return true
}
// handleRcvdSegment is called when a segment is received; it is responsible for
// updating the send-related state.
//收到段时调用 handleRcvdSegment; 它负责更新与发送相关的状态。
func (s *sender) handleRcvdSegment(seg *segment) {
// Check if we can extract an RTT measurement from this ack.
// 如果rtt测量seq小于ack num更新rto
if !s.ep.sendTSOk && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
s.updateRTO(time.Now().Sub(s.rttMeasureTime))
s.rttMeasureSeqNum = s.sndNxt
}
// Update Timestamp if required. See RFC7323, section-4.3.
s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
// Count the duplicates and do the fast retransmit if needed.
// tcp的拥塞控制检查是否有重复的ack是否进入快速重传和快速恢复状态
rtx := s.checkDuplicateAck(seg)
// Stash away the current window size.
// 存放当前窗口大小。
s.sndWnd = seg.window
// Ignore ack if it doesn't acknowledge any new data.
// 获取确认号
ack := seg.ackNumber
// 如果ack在最小未确认的seq和下一seg的seq之间
if (ack - 1).InRange(s.sndUna, s.sndNxt) {
s.dupAckCount = 0
// When an ack is received we must reset the timer. We stop it
// here and it will be restarted later if needed.
s.resendTimer.disable()
// See : https://tools.ietf.org/html/rfc1323#section-3.3.
// Specifically we should only update the RTO using TSEcr if the
// following condition holds:
//
// A TSecr value received in a segment is used to update the
// averaged RTT measurement only if the segment acknowledges
// some new data, i.e., only if it advances the left edge of
// the send window.
if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
// TSVal/Ecr values sent by Netstack are at a millisecond
// granularity.
elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
s.updateRTO(elapsed)
}
// 获取这次确认的字节数,即 ack - snaUna
acked := s.sndUna.Size(ack)
// 更新下一个未确认的序列号
s.sndUna = ack
ackLeft := acked
originalOutstanding := s.outstanding
// 从发送链表中删除已经确认的数据,发送窗口的滑动。
for ackLeft > 0 {
// We use logicalLen here because we can have FIN
// segments (which are always at the end of list) that
// have no data, but do consume a sequence number.
seg := s.writeList.Front()
datalen := seg.logicalLen()
if datalen > ackLeft {
seg.data.TrimFront(int(ackLeft))
break
}
if s.writeNext == seg {
s.writeNext = seg.Next()
}
// 从发送链表中删除已确认的tcp段。
s.writeList.Remove(seg)
// 因为有一个tcp段确认了所以 outstanding 减1
s.outstanding--
seg.decRef()
ackLeft -= datalen
}
// Update the send buffer usage and notify potential waiters.
// 当收到ack确认时需要更新发送缓冲占用
s.ep.updateSndBufferUsage(int(acked))
// If we are not in fast recovery then update the congestion
// window based on the number of acknowledged packets.
// tcp拥塞控制如果没有进入快速恢复状态那么根据确认的数据包的数量更新拥塞窗口。
if !s.fr.active {
// 调用相应拥塞控制算法的 Update
s.cc.Update(originalOutstanding - s.outstanding)
}
// It is possible for s.outstanding to drop below zero if we get
// a retransmit timeout, reset outstanding to zero but later
// get an ack that cover previously sent data.
// 如果发生超时重传时s.outstanding可能会降到零以下
// 重置为零但后来得到一个覆盖先前发送数据的确认。
if s.outstanding < 0 {
s.outstanding = 0
}
}
// Now that we've popped all acknowledged data from the retransmit
// queue, retransmit if needed.
// tcp拥塞控制如果需要快速重传则重传数据但是只是重传下一个数据
if rtx {
// tcp拥塞控制快速重传
s.resendSegment()
}
// Send more data now that some of the pending data has been ack'd, or
// that the window opened up, or the congestion window was inflated due
// to a duplicate ack during fast recovery. This will also re-enable
// the retransmit timer if needed.
// 现在某些待处理数据已被确认或者窗口打开或者由于快速恢复期间出现重复的ack而导致拥塞窗口膨胀
// 因此发送更多数据。如果需要,这也将重新启用重传计时器。
s.sendData()
}
// sendSegment sends a new segment containing the given payload, flags and
// sequence number.
// 根据给定的参数负载数据、flags标记和序列号来发送数据
func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
s.lastSendTime = time.Now()
if seq == s.rttMeasureSeqNum {
s.rttMeasureTime = s.lastSendTime
}
rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
// Remember the max sent ack.
s.maxSentAck = rcvNxt
return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
}

View File

@@ -0,0 +1,350 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp_test
import (
"fmt"
"reflect"
"testing"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/transport/tcp"
"tcpip/netstack/tcpip/transport/tcp/testing/context"
)
// createConnectWithSACKPermittedOption creates and connects c.ep with the
// SACKPermitted option enabled if the stack in the context has the SACK support
// enabled.
func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
}
func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
t.Helper()
if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
}
}
// TestSackPermittedConnect establishes a connection with the SACK option
// enabled.
func TestSackPermittedConnect(t *testing.T) {
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
rep := createConnectedWithSACKPermittedOption(c)
data := []byte{1, 2, 3}
rep.SendPacket(data, nil)
savedSeqNum := rep.NextSeqNum
rep.VerifyACKNoSACK()
// Make an out of order packet and send it.
rep.NextSeqNum += 3
sackBlocks := []header.SACKBlock{
{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
}
rep.SendPacket(data, nil)
// Restore the saved sequence number so that the
// VerifyXXX calls use the right sequence number for
// checking ACK numbers.
rep.NextSeqNum = savedSeqNum
if sackEnabled {
rep.VerifyACKHasSACK(sackBlocks)
} else {
rep.VerifyACKNoSACK()
}
// Send the missing segment.
rep.SendPacket(data, nil)
// The ACK should contain the cumulative ACK for all 9
// bytes sent and no SACK blocks.
rep.NextSeqNum += 3
// Check that no SACK block is returned in the ACK.
rep.VerifyACKNoSACK()
})
}
}
// TestSackDisabledConnect establishes a connection with the SACK option
// disabled and verifies that no SACKs are sent for out of order segments.
func TestSackDisabledConnect(t *testing.T) {
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
data := []byte{1, 2, 3}
rep.SendPacket(data, nil)
savedSeqNum := rep.NextSeqNum
rep.VerifyACKNoSACK()
// Make an out of order packet and send it.
rep.NextSeqNum += 3
rep.SendPacket(data, nil)
// The ACK should contain the older sequence number and
// no SACK blocks.
rep.NextSeqNum = savedSeqNum
rep.VerifyACKNoSACK()
// Send the missing segment.
rep.SendPacket(data, nil)
// The ACK should contain the cumulative ACK for all 9
// bytes sent and no SACK blocks.
rep.NextSeqNum += 3
// Check that no SACK block is returned in the ACK.
rep.VerifyACKNoSACK()
})
}
}
// TestSackPermittedAccept accepts and establishes a connection with the
// SACKPermitted option enabled if the connection request specifies the
// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
// don't encode the SACK information in the cookie.
func TestSackPermittedAccept(t *testing.T) {
type testCase struct {
cookieEnabled bool
sackPermitted bool
wndScale int
wndSize uint16
}
testCases := []testCase{
// When cookie is used window scaling is disabled.
{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
savedSynCountThreshold := tcp.SynRcvdCountThreshold
defer func() {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}()
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
if tc.cookieEnabled {
tcp.SynRcvdCountThreshold = 0
} else {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
rep.SendPacket(data, nil)
rep.VerifyACKNoSACK()
savedSeqNum := rep.NextSeqNum
// Make an out of order packet and send
// it.
rep.NextSeqNum += 3
sackBlocks := []header.SACKBlock{
{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
}
rep.SendPacket(data, nil)
// The ACK should contain the older
// sequence number.
rep.NextSeqNum = savedSeqNum
if sackEnabled && tc.sackPermitted {
rep.VerifyACKHasSACK(sackBlocks)
} else {
rep.VerifyACKNoSACK()
}
// Send the missing segment.
rep.SendPacket(data, nil)
// The ACK should contain the cumulative
// ACK for all 9 bytes sent and no SACK
// blocks.
rep.NextSeqNum += 3
// Check that no SACK block is returned
// in the ACK.
rep.VerifyACKNoSACK()
})
}
})
}
}
// TestSackDisabledAccept accepts and establishes a connection with
// the SACKPermitted option disabled and verifies that no SACKs are
// sent for out of order packets.
func TestSackDisabledAccept(t *testing.T) {
type testCase struct {
cookieEnabled bool
wndScale int
wndSize uint16
}
testCases := []testCase{
// When cookie is used window scaling is disabled.
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
savedSynCountThreshold := tcp.SynRcvdCountThreshold
defer func() {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}()
for _, tc := range testCases {
t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
if tc.cookieEnabled {
tcp.SynRcvdCountThreshold = 0
} else {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}
for _, sackEnabled := range []bool{false, true} {
t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
setStackSACKPermitted(t, c, sackEnabled)
rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now verify no SACK blocks are
// received when sack is disabled.
data := []byte{1, 2, 3}
rep.SendPacket(data, nil)
rep.VerifyACKNoSACK()
savedSeqNum := rep.NextSeqNum
// Make an out of order packet and send
// it.
rep.NextSeqNum += 3
rep.SendPacket(data, nil)
// The ACK should contain the older
// sequence number and no SACK blocks.
rep.NextSeqNum = savedSeqNum
rep.VerifyACKNoSACK()
// Send the missing segment.
rep.SendPacket(data, nil)
// The ACK should contain the cumulative
// ACK for all 9 bytes sent and no SACK
// blocks.
rep.NextSeqNum += 3
// Check that no SACK block is returned
// in the ACK.
rep.VerifyACKNoSACK()
})
}
})
}
}
func TestUpdateSACKBlocks(t *testing.T) {
testCases := []struct {
segStart seqnum.Value
segEnd seqnum.Value
rcvNxt seqnum.Value
sackBlocks []header.SACKBlock
updated []header.SACKBlock
}{
// Trivial cases where current SACK block list is empty and we
// have an out of order delivery.
{10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
{10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
{10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
// Cases where current SACK block list is not empty and we have
// an out of order delivery. Tests that the updated SACK block
// list has the first block as the one that contains the new
// SACK block representing the segment that was just delivered.
{10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
{24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
{24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
// Ensure that we only retain header.MaxSACKBlocks and drop the
// oldest one if adding a new block exceeds
// header.MaxSACKBlocks.
{24, 30, 9,
[]header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
[]header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
// Cases where segment extends an existing SACK block.
{10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
{15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
{15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
{11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
{10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
{15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
{15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
{11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
// Cases where segment contains rcvNxt.
{10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
}
for _, tc := range testCases {
var sack tcp.SACKInfo
copy(sack.Blocks[:], tc.sackBlocks)
sack.NumBlocks = len(tc.sackBlocks)
tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
}
}
}
func TestTrimSackBlockList(t *testing.T) {
testCases := []struct {
rcvNxt seqnum.Value
sackBlocks []header.SACKBlock
trimmed []header.SACKBlock
}{
// Simple cases where we trim whole entries.
{2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
{21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
{31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
{40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
// Cases where we need to update a block.
{12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
{23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
{33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
{41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
}
for _, tc := range testCases {
var sack tcp.SACKInfo
copy(sack.Blocks[:], tc.sackBlocks)
sack.NumBlocks = len(tc.sackBlocks)
tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
}
}
}

View File

@@ -0,0 +1,173 @@
package tcp
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type segmentElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
type segmentList struct {
head *segment
tail *segment
}
// Reset resets list l to the empty state.
func (l *segmentList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
func (l *segmentList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
func (l *segmentList) Front() *segment {
return l.head
}
// Back returns the last element of list l or nil.
func (l *segmentList) Back() *segment {
return l.tail
}
// PushFront inserts the element e at the front of list l.
func (l *segmentList) PushFront(e *segment) {
segmentElementMapper{}.linkerFor(e).SetNext(l.head)
segmentElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushBack inserts the element e at the back of list l.
func (l *segmentList) PushBack(e *segment) {
segmentElementMapper{}.linkerFor(e).SetNext(nil)
segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
func (l *segmentList) PushBackList(m *segmentList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *segmentList) InsertAfter(b, e *segment) {
a := segmentElementMapper{}.linkerFor(b).Next()
segmentElementMapper{}.linkerFor(e).SetNext(a)
segmentElementMapper{}.linkerFor(e).SetPrev(b)
segmentElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
segmentElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
func (l *segmentList) InsertBefore(a, e *segment) {
b := segmentElementMapper{}.linkerFor(a).Prev()
segmentElementMapper{}.linkerFor(e).SetNext(a)
segmentElementMapper{}.linkerFor(e).SetPrev(b)
segmentElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
segmentElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
func (l *segmentList) Remove(e *segment) {
prev := segmentElementMapper{}.linkerFor(e).Prev()
next := segmentElementMapper{}.linkerFor(e).Next()
if prev != nil {
segmentElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
segmentElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type segmentEntry struct {
next *segment
prev *segment
}
// Next returns the entry that follows e in the list.
func (e *segmentEntry) Next() *segment {
return e.next
}
// Prev returns the entry that precedes e in the list.
func (e *segmentEntry) Prev() *segment {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
func (e *segmentEntry) SetNext(elem *segment) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *segmentEntry) SetPrev(elem *segment) {
e.prev = elem
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,316 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp_test
import (
"bytes"
"math/rand"
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/checker"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/transport/tcp"
"tcpip/netstack/tcpip/transport/tcp/testing/context"
"tcpip/netstack/waiter"
)
// createConnectedWithTimestampOption creates and connects c.ep with the
// timestamp option enabled.
func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
}
// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
// an active connect and sets the TS Echo Reply fields correctly when the
// SYN-ACK also indicates support for the TS option and provides a TSVal.
func TestTimeStampEnabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
rep := createConnectedWithTimestampOption(c)
// Register for read and validate that we have data to read.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
// The following tests ensure that TS option once enabled behaves
// correctly as described in
// https://tools.ietf.org/html/rfc7323#section-4.3.
//
// We are not testing delayed ACKs here, but we do test out of order
// packet delivery and filling the sequence number hole created due to
// the out of order packet.
//
// The test also verifies that the sequence numbers and timestamps are
// as expected.
data := []byte{1, 2, 3}
// First we increment tsVal by a small amount.
tsVal := rep.TSVal + 100
rep.SendPacketWithTS(data, tsVal)
rep.VerifyACKWithTS(tsVal)
// Next we send an out of order packet.
rep.NextSeqNum += 3
tsVal += 200
rep.SendPacketWithTS(data, tsVal)
// The ACK should contain the original sequenceNumber and an older TS.
rep.NextSeqNum -= 6
rep.VerifyACKWithTS(tsVal - 200)
// Next we fill the hole and the returned ACK should contain the
// cumulative sequence number acking all data sent till now and have the
// latest timestamp sent below in its TSEcr field.
tsVal -= 100
rep.SendPacketWithTS(data, tsVal)
rep.NextSeqNum += 3
rep.VerifyACKWithTS(tsVal)
// Increment tsVal by a large value that doesn't result in a wrap around.
tsVal += 0x7fffffff
rep.SendPacketWithTS(data, tsVal)
rep.VerifyACKWithTS(tsVal)
// Increment tsVal again by a large value which should cause the
// timestamp value to wrap around. The returned ACK should contain the
// wrapped around timestamp in its tsEcr field and not the tsVal from
// the previous packet sent above.
tsVal += 0x7fffffff
rep.SendPacketWithTS(data, tsVal)
rep.VerifyACKWithTS(tsVal)
select {
case <-ch:
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for data to arrive")
}
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
got, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
if want := data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
}
// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an
// active connect but if the SYN-ACK doesn't specify the TS option then
// timestamp option is not enabled and future packets do not contain a
// timestamp.
func TestTimeStampDisabledConnect(t *testing.T) {
c := context.New(t, defaultMTU)
defer c.Cleanup()
c.CreateConnectedWithOptions(header.TCPSynOptions{})
}
func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
savedSynCountThreshold := tcp.SynRcvdCountThreshold
defer func() {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}()
if cookieEnabled {
tcp.SynRcvdCountThreshold = 0
}
c := context.New(t, defaultMTU)
defer c.Cleanup()
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
tsVal := rand.Uint32()
c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
// Now send some data and validate that timestamp is echoed correctly in the ACK.
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
// Check that data is received and that the timestamp option TSEcr field
// matches the expected value.
b := c.GetPacket()
checker.IPv4(t, b,
// Add 12 bytes for the timestamp option + 2 NOPs to align at 4
// byte boundary.
checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(790),
checker.Window(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(true, 0, tsVal+1),
),
)
}
// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
// and echoes the tsVal field of the original SYN in the tcEcr field of the
// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
// that Timestamp option is enabled in both cases if requested in the original
// SYN.
func TestTimeStampEnabledAccept(t *testing.T) {
testCases := []struct {
cookieEnabled bool
wndScale int
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
for _, tc := range testCases {
timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
}
}
func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
savedSynCountThreshold := tcp.SynRcvdCountThreshold
defer func() {
tcp.SynRcvdCountThreshold = savedSynCountThreshold
}()
if cookieEnabled {
tcp.SynRcvdCountThreshold = 0
}
c := context.New(t, defaultMTU)
defer c.Cleanup()
t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
// Now send some data with the accepted connection endpoint and validate
// that no timestamp option is sent in the TCP segment.
data := []byte{1, 2, 3}
view := buffer.NewView(len(data))
copy(view, data)
if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
t.Fatalf("Unexpected error from Write: %v", err)
}
// Check that data is received and that the timestamp option is disabled
// when SYN cookies are enabled/disabled.
b := c.GetPacket()
checker.IPv4(t, b,
checker.PayloadLen(len(data)+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(context.TestPort),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(790),
checker.Window(wndSize),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
checker.TCPTimestampChecker(false, 0, 0),
),
)
}
// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
// peer doesn't advertise it and connection is established with Accept().
func TestTimeStampDisabledAccept(t *testing.T) {
testCases := []struct {
cookieEnabled bool
wndScale int
wndSize uint16
}{
{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
}
for _, tc := range testCases {
timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
}
}
func TestSendGreaterThanMTUWithOptions(t *testing.T) {
const maxPayload = 100
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
createConnectedWithTimestampOption(c)
testBrokenUpWrite(t, c, maxPayload)
}
func TestSegmentDropWhenTimestampMissing(t *testing.T) {
const maxPayload = 100
c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
defer c.Cleanup()
rep := createConnectedWithTimestampOption(c)
// Register for read.
we, ch := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
droppedPacketsStat := c.Stack().Stats().DroppedPackets
droppedPackets := droppedPacketsStat.Value()
data := []byte{1, 2, 3}
// Save the sequence number as we will reset it later down
// in the test.
savedSeqNum := rep.NextSeqNum
rep.SendPacket(data, nil)
select {
case <-ch:
t.Fatalf("Got data to read when we expect packet to be dropped")
case <-time.After(1 * time.Second):
// We expect that no data will be available to read.
}
// Assert that DroppedPackets was incremented by 1.
if got, want := droppedPacketsStat.Value(), droppedPackets+1; got != want {
t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
}
droppedPackets = droppedPacketsStat.Value()
// Reset the sequence number so that the other endpoint accepts
// this segment and does not treat it like an out of order delivery.
rep.NextSeqNum = savedSeqNum
// Now send a packet with timestamp and we should get the data.
rep.SendPacketWithTS(data, rep.TSVal+1)
select {
case <-ch:
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for data to arrive")
}
// Assert that DroppedPackets was not incremented by 1.
if got, want := droppedPacketsStat.Value(), droppedPackets; got != want {
t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
}
// Issue a read and we should data.
got, _, err := c.EP.Read(nil)
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
if want := data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}

View File

@@ -0,0 +1,960 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package context provides a test context for use in tcp tests. It also
// provides helper methods to assert/check certain behaviours.
package context
import (
"bytes"
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/checker"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/network/ipv6"
"tcpip/netstack/tcpip/seqnum"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/tcp"
"tcpip/netstack/waiter"
)
const (
// StackAddr is the IPv4 address assigned to the stack.
StackAddr = "\x0a\x00\x00\x01"
// StackPort is used as the listening port in tests for passive
// connects.
StackPort = 1234
// TestAddr is the source address for packets sent to the stack via the
// link layer endpoint.
TestAddr = "\x0a\x00\x00\x02"
// TestPort is the TCP port used for packets sent to the stack
// via the link layer endpoint.
TestPort = 4096
// StackV6Addr is the IPv6 address assigned to the stack.
StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
// TestV6Addr is the source address for packets sent to the stack via
// the link layer endpoint.
TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
// StackV4MappedAddr is StackAddr as a mapped v6 address.
StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
// TestV4MappedAddr is TestAddr as a mapped v6 address.
TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
// V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
// testInitialSequenceNumber is the initial sequence number sent in packets that
// are sent in response to a SYN or in the initial SYN sent to the stack.
testInitialSequenceNumber = 789
)
// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize
// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is
// used in tcp.newHandshake to determine the window scale to use when sending a
// SYN/SYN-ACK.
var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize)
// Headers is used to represent the TCP header fields when building a
// new packet.
type Headers struct {
// SrcPort holds the src port value to be used in the packet.
SrcPort uint16
// DstPort holds the destination port value to be used in the packet.
DstPort uint16
// SeqNum is the value of the sequence number field in the TCP header.
SeqNum seqnum.Value
// AckNum represents the acknowledgement number field in the TCP header.
AckNum seqnum.Value
// Flags are the TCP flags in the TCP header.
Flags int
// RcvWnd is the window to be advertised in the ReceiveWindow field of
// the TCP header.
RcvWnd seqnum.Size
// TCPOpts holds the options to be sent in the option field of the TCP
// header.
TCPOpts []byte
}
// Context provides an initialized Network stack and a link layer endpoint
// for use in TCP tests.
type Context struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
// IRS holds the initial sequence number in the SYN sent by endpoint in
// case of an active connect or the sequence number sent by the endpoint
// in the SYN-ACK sent in response to a SYN when listening in passive
// mode.
IRS seqnum.Value
// Port holds the port bound by EP below in case of an active connect or
// the listening port number in case of a passive connect.
Port uint16
// EP is the test endpoint in the stack owned by this context. This endpoint
// is used in various tests to either initiate an active connect or is used
// as a passive listening endpoint to accept inbound connections.
EP tcpip.Endpoint
// Wq is the wait queue associated with EP and is used to block for events
// on EP.
WQ waiter.Queue
// TimeStampEnabled is true if ep is connected with the timestamp option
// enabled.
TimeStampEnabled bool
}
// New allocates and initializes a test context containing a new
// stack and a link-layer endpoint.
func New(t *testing.T, mtu uint32) *Context {
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
// Allow minimum send/receive buffer sizes to be 1 during tests.
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
t.Fatalf("SetTransportProtocolOption failed: %v", err)
}
// Some of the congestion control tests send up to 640 packets, we so
// set the channel size to 1000.
id, linkEP := channel.New(1000, mtu, "")
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: "\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
},
{
Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
},
})
return &Context{
t: t,
s: s,
linkEP: linkEP,
}
}
// Cleanup closes the context endpoint if required.
func (c *Context) Cleanup() {
if c.EP != nil {
c.EP.Close()
}
}
// Stack returns a reference to the stack in the Context.
func (c *Context) Stack() *stack.Stack {
return c.s
}
// CheckNoPacketTimeout verifies that no packet is received during the time
// specified by wait.
func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
c.t.Helper()
select {
case <-c.linkEP.C:
c.t.Fatal(errMsg)
case <-time.After(wait):
}
}
// CheckNoPacket verifies that no packet is received for 1 second.
func (c *Context) CheckNoPacket(errMsg string) {
c.CheckNoPacketTimeout(errMsg, 1*time.Second)
}
// GetPacket reads a packet from the link layer endpoint and verifies
// that it is an IPv4 packet with the expected source and destination
// addresses. It will fail with an error if no packet is received for
// 2 seconds.
func (c *Context) GetPacket() []byte {
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
case <-time.After(2 * time.Second):
c.t.Fatalf("Packet wasn't written out")
}
return nil
}
// GetPacketNonBlocking reads a packet from the link layer endpoint
// and verifies that it is an IPv4 packet with the expected source
// and destination address. If no packet is available it will return
// nil immediately.
func (c *Context) GetPacketNonBlocking() []byte {
select {
case p := <-c.linkEP.C:
if p.Proto != ipv4.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
return b
default:
return nil
}
}
// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint.
func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
// Allocate a buffer data and headers.
buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2))
if len(buf) > maxTotalSize {
buf = buf[:maxTotalSize]
}
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: TestAddr,
DstAddr: StackAddr,
})
ip.SetChecksum(^ip.CalculateChecksum())
icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
icmp.SetType(typ)
icmp.SetCode(code)
copy(icmp[header.ICMPv4MinimumSize:], p1)
copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2)
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
}
// BuildSegment builds a TCP segment based on the given Headers and payload.
func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
copy(buf[len(buf)-len(payload):], payload)
copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(tcp.ProtocolNumber),
SrcAddr: TestAddr,
DstAddr: StackAddr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the TCP header.
t := header.TCP(buf[header.IPv4MinimumSize:])
t.Encode(&header.TCPFields{
SrcPort: h.SrcPort,
DstPort: h.DstPort,
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
Flags: uint8(h.Flags),
WindowSize: uint16(h.RcvWnd),
})
// Calculate the TCP pseudo-header checksum.
xsum := header.Checksum([]byte(TestAddr), 0)
xsum = header.Checksum([]byte(StackAddr), xsum)
xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
// Calculate the TCP checksum and set it.
length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload))
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum, length))
// Inject packet.
return buf.ToVectorisedView()
}
// SendSegment sends a TCP segment that has already been built and written to a
// buffer.VectorisedView.
func (c *Context) SendSegment(s buffer.VectorisedView) {
c.linkEP.Inject(ipv4.ProtocolNumber, s)
}
// SendPacket builds and sends a TCP segment(with the provided payload & TCP
// headers) in an IPv4 packet via the link layer endpoint.
func (c *Context) SendPacket(payload []byte, h *Headers) {
c.linkEP.Inject(ipv4.ProtocolNumber, c.BuildSegment(payload, h))
}
// SendAck sends an ACK packet.
func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: c.Port,
Flags: header.TCPFlagAck,
SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1),
AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
RcvWnd: 30000,
})
}
// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
// verifies that the packet packet payload of packet matches the slice
// of data indicated by offset & size.
func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
b := c.GetPacket()
checker.IPv4(c.t, b,
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
pdata := data[offset:][:size]
if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
}
}
// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
// and verifies that the packet packet payload of packet matches the slice of
// data indicated by offset & size. It returns true if a packet was received and
// processed.
func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
b := c.GetPacketNonBlocking()
if b == nil {
return false
}
checker.IPv4(c.t, b,
checker.PayloadLen(size+header.TCPMinimumSize),
checker.TCP(
checker.DstPort(TestPort),
checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
),
)
pdata := data[offset:][:size]
if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
}
return true
}
// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
// only endpoint instead of a default dual stack socket.
func (c *Context) CreateV6Endpoint(v6only bool) {
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
var v tcpip.V6OnlyOption
if v6only {
v = 1
}
if err := c.EP.SetSockOpt(v); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
// GetV6Packet reads a single packet from the link layer endpoint of the context
// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
func (c *Context) GetV6Packet() []byte {
select {
case p := <-c.linkEP.C:
if p.Proto != ipv6.ProtocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
return b
case <-time.After(2 * time.Second):
c.t.Fatalf("Packet wasn't written out")
}
return nil
}
// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
// the context.
func (c *Context) SendV6Packet(payload []byte, h *Headers) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
NextHeader: uint8(tcp.ProtocolNumber),
HopLimit: 65,
SrcAddr: TestV6Addr,
DstAddr: StackV6Addr,
})
// Initialize the TCP header.
t := header.TCP(buf[header.IPv6MinimumSize:])
t.Encode(&header.TCPFields{
SrcPort: h.SrcPort,
DstPort: h.DstPort,
SeqNum: uint32(h.SeqNum),
AckNum: uint32(h.AckNum),
DataOffset: header.TCPMinimumSize,
Flags: uint8(h.Flags),
WindowSize: uint16(h.RcvWnd),
})
// Calculate the TCP pseudo-header checksum.
xsum := header.Checksum([]byte(TestV6Addr), 0)
xsum = header.Checksum([]byte(StackV6Addr), xsum)
xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
// Calculate the TCP checksum and set it.
length := uint16(header.TCPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
t.SetChecksum(^t.CalculateChecksum(xsum, length))
// Inject packet.
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
// CreateConnected creates a connected TCP endpoint.
func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
}
// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
// the specified option bytes as the Option field in the initial SYN packet.
//
// It also sets the receive buffer for the endpoint to the specified
// value in epRcvBuf.
func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
// Create TCP endpoint.
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
if epRcvBuf != nil {
if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
if err != tcpip.ErrConnectStarted {
c.t.Fatalf("Unexpected return value from Connect: %v", err)
}
// Receive SYN packet.
b := c.GetPacket()
checker.IPv4(c.t, b,
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagSyn),
),
)
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
c.SendPacket(nil, &Headers{
SrcPort: tcp.DestinationPort(),
DstPort: tcp.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: rcvWnd,
TCPOpts: options,
})
// Receive ACK packet.
checker.IPv4(c.t, c.GetPacket(),
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(c.IRS)+1),
checker.AckNum(uint32(iss)+1),
),
)
// Wait for connection to be established.
select {
case <-notifyCh:
err = c.EP.GetSockOpt(tcpip.ErrorOption{})
if err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
c.Port = tcp.SourcePort()
}
// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
// sending data and ACK packets easy while being able to manipulate the sequence
// numbers and timestamp values as needed.
type RawEndpoint struct {
C *Context
SrcPort uint16
DstPort uint16
Flags int
NextSeqNum seqnum.Value
AckNum seqnum.Value
WndSize seqnum.Size
RecentTS uint32 // Stores the latest timestamp to echo back.
TSVal uint32 // TSVal stores the last timestamp sent by this endpoint.
// SackPermitted is true if SACKPermitted option was negotiated for this endpoint.
SACKPermitted bool
}
// SendPacketWithTS embeds the provided tsVal in the Timestamp option
// for the packet to be sent out.
func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
r.TSVal = tsVal
tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
r.SendPacket(payload, tsOpt[:])
}
// SendPacket is a small wrapper function to build and send packets.
func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
packetHeaders := &Headers{
SrcPort: r.SrcPort,
DstPort: r.DstPort,
Flags: r.Flags,
SeqNum: r.NextSeqNum,
AckNum: r.AckNum,
RcvWnd: r.WndSize,
TCPOpts: opts,
}
r.C.SendPacket(payload, packetHeaders)
r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
}
// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
// tsVal.
func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
// Read ACK and verify that tsEcr of ACK packet is [1,2,3,4]
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(r.AckNum)),
checker.AckNum(uint32(r.NextSeqNum)),
checker.TCPTimestampChecker(true, 0, tsVal),
),
)
// Store the parsed TSVal from the ack as recentTS.
tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
opts := tcpSeg.ParsedOptions()
r.RecentTS = opts.TSVal
}
// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
func (r *RawEndpoint) VerifyACKNoSACK() {
r.VerifyACKHasSACK(nil)
}
// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
// Read ACK and verify that the TCP options in the segment do
// not contain a SACK block.
ackPacket := r.C.GetPacket()
checker.IPv4(r.C.t, ackPacket,
checker.TCP(
checker.DstPort(r.SrcPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(r.AckNum)),
checker.AckNum(uint32(r.NextSeqNum)),
checker.TCPSACKBlockChecker(sackBlocks),
),
)
}
// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
// options enabled and returns a RawEndpoint which represents the other end of
// the connection.
//
// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
// does not carry an option that was not requested.
func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
var err *tcpip.Error
c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
if err != nil {
c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
}
// Start connection attempt.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
c.WQ.EventRegister(&waitEntry, waiter.EventOut)
defer c.WQ.EventUnregister(&waitEntry)
testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
err = c.EP.Connect(testFullAddr)
if err != tcpip.ErrConnectStarted {
c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
}
// Receive SYN packet.
b := c.GetPacket()
// Validate that the syn has the timestamp option and a valid
// TS value.
checker.IPv4(c.t, b,
checker.TCP(
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagSyn),
checker.TCPSynOptions(header.TCPSynOptions{
MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
TS: true,
WS: defaultWindowScale,
SACKPermitted: c.SACKEnabled(),
}),
),
)
tcpSeg := header.TCP(header.IPv4(b).Payload())
synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
// Build options w/ tsVal to be sent in the SYN-ACK.
synAckOptions := make([]byte, 40)
offset := 0
if wantOptions.TS {
offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
}
if wantOptions.SACKPermitted {
offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
}
offset += header.AddTCPOptionPadding(synAckOptions, offset)
// Build SYN-ACK.
c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
iss := seqnum.Value(testInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
Flags: header.TCPFlagSyn | header.TCPFlagAck,
SeqNum: iss,
AckNum: c.IRS.Add(1),
RcvWnd: 30000,
TCPOpts: synAckOptions[:offset],
})
// Read ACK.
ackPacket := c.GetPacket()
// Verify TCP header fields.
tcpCheckers := []checker.TransportChecker{
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck),
checker.SeqNum(uint32(c.IRS) + 1),
checker.AckNum(uint32(iss) + 1),
}
// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
// timestamp option was enabled, if not then we verify that
// there is no timestamp in the ACK packet.
if wantOptions.TS {
tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
} else {
tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
}
checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
ackOptions := ackSeg.ParsedOptions()
// Wait for connection to be established.
select {
case <-notifyCh:
err = c.EP.GetSockOpt(tcpip.ErrorOption{})
if err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for connection")
}
// Store the source port in use by the endpoint.
c.Port = tcpSeg.SourcePort()
// Mark in context that timestamp option is enabled for this endpoint.
c.TimeStampEnabled = true
return &RawEndpoint{
C: c,
SrcPort: tcpSeg.DestinationPort(),
DstPort: tcpSeg.SourcePort(),
Flags: header.TCPFlagAck | header.TCPFlagPsh,
NextSeqNum: iss + 1,
AckNum: c.IRS.Add(1),
WndSize: 30000,
RecentTS: ackOptions.TSVal,
TSVal: wantOptions.TSVal,
SACKPermitted: wantOptions.SACKPermitted,
}
}
// AcceptWithOptions initializes a listening endpoint and connects to it with the
// provided options enabled. It also verifies that the SYN-ACK has the expected
// values for the provided options.
//
// The function returns a RawEndpoint representing the other end of the accepted
// endpoint.
func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
// Create EP and start listening.
wq := &waiter.Queue{}
ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: StackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
if err := ep.Listen(10); err != nil {
c.t.Fatalf("Listen failed: %v", err)
}
rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
// Try to accept the connection.
we, ch := waiter.NewChannelEntry(nil)
wq.EventRegister(&we, waiter.EventIn)
defer wq.EventUnregister(&we)
c.EP, _, err = ep.Accept()
if err == tcpip.ErrWouldBlock {
// Wait for connection to be established.
select {
case <-ch:
c.EP, _, err = ep.Accept()
if err != nil {
c.t.Fatalf("Accept failed: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for accept")
}
}
return rep
}
// PassiveConnect just disables WindowScaling and delegates the call to
// PassiveConnectWithOptions.
func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
synOptions.WS = -1
c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
}
// PassiveConnectWithOptions initiates a new connection (with the specified TCP
// options enabled) to the port on which the Context.ep is listening for new
// connections. It also validates that the SYN-ACK has the expected values for
// the enabled options.
//
// NOTE: MSS is not a negotiated option and it can be asymmetric
// in each direction. This function uses the maxPayload to set the MSS to be
// sent to the peer on a connect and validates that the MSS in the SYN-ACK
// response is equal to the MTU - (tcphdr len + iphdr len).
//
// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
// value of the window scaling option to be sent in the SYN. If synOptions.WS >
// 0 then we send the WindowScale option.
func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
opts := make([]byte, 40)
offset := 0
offset += header.EncodeMSSOption(uint32(maxPayload), opts)
if synOptions.WS >= 0 {
offset += header.EncodeWSOption(3, opts[offset:])
}
if synOptions.TS {
offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
}
if synOptions.SACKPermitted {
offset += header.EncodeSACKPermittedOption(opts[offset:])
}
paddingToAdd := 4 - offset%4
// Now add any padding bytes that might be required to quad align the
// options.
for i := offset; i < offset+paddingToAdd; i++ {
opts[i] = header.TCPOptionNOP
}
offset += paddingToAdd
// Send a SYN request.
iss := seqnum.Value(testInitialSequenceNumber)
c.SendPacket(nil, &Headers{
SrcPort: TestPort,
DstPort: StackPort,
Flags: header.TCPFlagSyn,
SeqNum: iss,
RcvWnd: 30000,
TCPOpts: opts[:offset],
})
// Receive the SYN-ACK reply. Make sure MSS and other expected options
// are present.
b := c.GetPacket()
tcp := header.TCP(header.IPv4(b).Payload())
c.IRS = seqnum.Value(tcp.SequenceNumber())
tcpCheckers := []checker.TransportChecker{
checker.SrcPort(StackPort),
checker.DstPort(TestPort),
checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
checker.AckNum(uint32(iss) + 1),
checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
}
// If TS option was enabled in the original SYN then add a checker to
// validate the Timestamp option in the SYN-ACK.
if synOptions.TS {
tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
} else {
tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
}
checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
rcvWnd := seqnum.Size(30000)
ackHeaders := &Headers{
SrcPort: TestPort,
DstPort: StackPort,
Flags: header.TCPFlagAck,
SeqNum: iss + 1,
AckNum: c.IRS + 1,
RcvWnd: rcvWnd,
}
// If WS was expected to be in effect then scale the advertised window
// correspondingly.
if synOptions.WS > 0 {
ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
}
parsedOpts := tcp.ParsedOptions()
if synOptions.TS {
// Echo the tsVal back to the peer in the tsEcr field of the
// timestamp option.
// Increment TSVal by 1 from the value sent in the SYN and echo
// the TSVal in the SYN-ACK in the TSEcr field.
opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
ackHeaders.TCPOpts = opts[:]
}
// Send ACK.
c.SendPacket(nil, ackHeaders)
c.Port = StackPort
return &RawEndpoint{
C: c,
SrcPort: TestPort,
DstPort: StackPort,
Flags: header.TCPFlagPsh | header.TCPFlagAck,
NextSeqNum: iss + 1,
AckNum: c.IRS + 1,
WndSize: rcvWnd,
SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
RecentTS: parsedOpts.TSVal,
TSVal: synOptions.TSVal + 1,
}
}
// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
// for the Stack in the context.
func (c *Context) SACKEnabled() bool {
var v tcp.SACKEnabled
if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return.
return false
}
return bool(v)
}

View File

@@ -0,0 +1,142 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcp
import (
"time"
"tcpip/netstack/sleep"
)
type timerState int
const (
timerStateDisabled timerState = iota
timerStateEnabled
timerStateOrphaned
)
// timer is a timer implementation that reduces the interactions with the
// runtime timer infrastructure by letting timers run (and potentially
// eventually expire) even if they are stopped. It makes it cheaper to
// disable/reenable timers at the expense of spurious wakes. This is useful for
// cases when the same timer is disabled/reenabled repeatedly with relatively
// long timeouts farther into the future.
//
// TCP retransmit timers benefit from this because they the timeouts are long
// (currently at least 200ms), and get disabled when acks are received, and
// reenabled when new pending segments are sent.
//
// It is advantageous to avoid interacting with the runtime because it acquires
// a global mutex and performs O(log n) operations, where n is the global number
// of timers, whenever a timer is enabled or disabled, and may make a syscall.
//
// This struct is thread-compatible.
// 定时器的实现
type timer struct {
// state is the current state of the timer, it can be one of the
// following values:
// disabled - the timer is disabled.
// orphaned - the timer is disabled, but the runtime timer is
// enabled, which means that it will evetually cause a
// spurious wake (unless it gets enabled again before
// then).
// enabled - the timer is enabled, but the runtime timer may be set
// to an earlier expiration time due to a previous
// orphaned state.
state timerState
// target is the expiration time of the current timer. It is only
// meaningful in the enabled state.
target time.Time
// runtimeTarget is the expiration time of the runtime timer. It is
// meaningful in the enabled and orphaned states.
runtimeTarget time.Time
// timer is the runtime timer used to wait on.
timer *time.Timer
}
// init initializes the timer. Once it expires, it the given waker will be
// asserted.
func (t *timer) init(w *sleep.Waker) {
t.state = timerStateDisabled
// Initialize a runtime timer that will assert the waker, then
// immediately stop it.
t.timer = time.AfterFunc(time.Hour, func() {
w.Assert()
})
t.timer.Stop()
}
// cleanup frees all resources associated with the timer.
func (t *timer) cleanup() {
t.timer.Stop()
}
// checkExpiration checks if the given timer has actually expired, it should be
// called whenever a sleeper wakes up due to the waker being asserted, and is
// used to check if it's a supurious wake (due to a previously orphaned timer)
// or a legitimate one.
func (t *timer) checkExpiration() bool {
// Transition to fully disabled state if we're just consuming an
// orphaned timer.
if t.state == timerStateOrphaned {
t.state = timerStateDisabled
return false
}
// The timer is enabled, but it may have expired early. Check if that's
// the case, and if so, reset the runtime timer to the correct time.
now := time.Now()
if now.Before(t.target) {
t.runtimeTarget = t.target
t.timer.Reset(t.target.Sub(now))
return false
}
// The timer has actually expired, disable it for now and inform the
// caller.
t.state = timerStateDisabled
return true
}
// disable disables the timer, leaving it in an orphaned state if it wasn't
// already disabled.
func (t *timer) disable() {
if t.state != timerStateDisabled {
t.state = timerStateOrphaned
}
}
// enabled returns true if the timer is currently enabled, false otherwise.
func (t *timer) enabled() bool {
return t.state == timerStateEnabled
}
// enable enables the timer, programming the runtime timer if necessary.
func (t *timer) enable(d time.Duration) {
t.target = time.Now().Add(d)
// Check if we need to set the runtime timer.
if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
t.runtimeTarget = t.target
t.timer.Reset(d)
}
t.state = timerStateEnabled
}

View File

@@ -0,0 +1,348 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package tcpconntrack implements a TCP connection tracking object. It allows
// users with access to a segment stream to figure out when a connection is
// established, reset, and closed (and in the last case, who closed first).
package tcpconntrack
import (
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/seqnum"
)
// Result is returned when the state of a TCB is updated in response to an
// inbound or outbound segment.
type Result int
const (
// ResultDrop indicates that the segment should be dropped.
ResultDrop Result = iota
// ResultConnecting indicates that the connection remains in a
// connecting state.
ResultConnecting
// ResultAlive indicates that the connection remains alive (connected).
ResultAlive
// ResultReset indicates that the connection was reset.
ResultReset
// ResultClosedByPeer indicates that the connection was gracefully
// closed, and the inbound stream was closed first.
ResultClosedByPeer
// ResultClosedBySelf indicates that the connection was gracefully
// closed, and the outbound stream was closed first.
ResultClosedBySelf
)
// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP
// connection and inform the caller when the connection has been closed.
type TCB struct {
inbound stream
outbound stream
// State handlers.
handlerInbound func(*TCB, header.TCP) Result
handlerOutbound func(*TCB, header.TCP) Result
// firstFin holds a pointer to the first stream to send a FIN.
firstFin *stream
// state is the current state of the stream.
state Result
}
// Init initializes the state of the TCB according to the initial SYN.
func (t *TCB) Init(initialSyn header.TCP) {
t.handlerInbound = synSentStateInbound
t.handlerOutbound = synSentStateOutbound
iss := seqnum.Value(initialSyn.SequenceNumber())
t.outbound.una = iss
t.outbound.nxt = iss.Add(logicalLen(initialSyn))
t.outbound.end = t.outbound.nxt
// Even though "end" is a sequence number, we don't know the initial
// receive sequence number yet, so we store the window size until we get
// a SYN from the peer.
t.inbound.una = 0
t.inbound.nxt = 0
t.inbound.end = seqnum.Value(initialSyn.WindowSize())
t.state = ResultConnecting
}
// UpdateStateInbound updates the state of the TCB based on the supplied inbound
// segment.
func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
st := t.handlerInbound(t, tcp)
if st != ResultDrop {
t.state = st
}
return st
}
// UpdateStateOutbound updates the state of the TCB based on the supplied
// outbound segment.
func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
st := t.handlerOutbound(t, tcp)
if st != ResultDrop {
t.state = st
}
return st
}
// IsAlive returns true as long as the connection is established(Alive)
// or connecting state.
func (t *TCB) IsAlive() bool {
return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
}
// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream.
func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
return t.outbound.nxt
}
// InboundSendSequenceNumber returns the snd.NXT for the inbound stream.
func (t *TCB) InboundSendSequenceNumber() seqnum.Value {
return t.inbound.nxt
}
// adapResult modifies the supplied "Result" according to the state of the TCB;
// if r is anything other than "Alive", or if one of the streams isn't closed
// yet, it is returned unmodified. Otherwise it's converted to either
// ClosedBySelf or ClosedByPeer depending on which stream was closed first.
func (t *TCB) adaptResult(r Result) Result {
// Check the unmodified case.
if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() {
return r
}
// Find out which was closed first.
if t.firstFin == &t.outbound {
return ResultClosedBySelf
}
return ResultClosedByPeer
}
// synSentStateInbound is the state handler for inbound segments when the
// connection is in SYN-SENT state.
func synSentStateInbound(t *TCB, tcp header.TCP) Result {
flags := tcp.Flags()
ackPresent := flags&header.TCPFlagAck != 0
ack := seqnum.Value(tcp.AckNumber())
// Ignore segment if ack is present but not acceptable.
if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) {
return ResultConnecting
}
// If reset is specified, we will let the packet through no matter what
// but we will also destroy the connection if the ACK is present (and
// implicitly acceptable).
if flags&header.TCPFlagRst != 0 {
if ackPresent {
t.inbound.rstSeen = true
return ResultReset
}
return ResultConnecting
}
// Ignore segment if SYN is not set.
if flags&header.TCPFlagSyn == 0 {
return ResultConnecting
}
// Update state informed by this SYN.
irs := seqnum.Value(tcp.SequenceNumber())
t.inbound.una = irs
t.inbound.nxt = irs.Add(logicalLen(tcp))
t.inbound.end += irs
t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize()))
// If the ACK was set (it is acceptable), update our unacknowledgement
// tracking.
if ackPresent {
// Advance the "una" and "end" indices of the outbound stream.
if t.outbound.una.LessThan(ack) {
t.outbound.una = ack
}
if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) {
t.outbound.end = end
}
}
// Update handlers so that new calls will be handled by new state.
t.handlerInbound = allOtherInbound
t.handlerOutbound = allOtherOutbound
return ResultAlive
}
// synSentStateOutbound is the state handler for outbound segments when the
// connection is in SYN-SENT state.
func synSentStateOutbound(t *TCB, tcp header.TCP) Result {
// Drop outbound segments that aren't retransmits of the original one.
if tcp.Flags() != header.TCPFlagSyn ||
tcp.SequenceNumber() != uint32(t.outbound.una) {
return ResultDrop
}
// Update the receive window. We only remember the largest value seen.
if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end {
t.inbound.end = wnd
}
return ResultConnecting
}
// update updates the state of inbound and outbound streams, given the supplied
// inbound segment. For outbound segments, this same function can be called with
// swapped inbound/outbound streams.
func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result {
// Ignore segments out of the window.
s := seqnum.Value(tcp.SequenceNumber())
if !inbound.acceptable(s, dataLen(tcp)) {
return ResultAlive
}
flags := tcp.Flags()
if flags&header.TCPFlagRst != 0 {
inbound.rstSeen = true
return ResultReset
}
// Ignore segments that don't have the ACK flag, and those with the SYN
// flag.
if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 {
return ResultAlive
}
// Ignore segments that acknowledge not yet sent data.
ack := seqnum.Value(tcp.AckNumber())
if outbound.nxt.LessThan(ack) {
return ResultAlive
}
// Advance the "una" and "end" indices of the outbound stream.
if outbound.una.LessThan(ack) {
outbound.una = ack
}
if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) {
outbound.end = end
}
// Advance the "nxt" index of the inbound stream.
end := s.Add(logicalLen(tcp))
if inbound.nxt.LessThan(end) {
inbound.nxt = end
}
// Note the index of the FIN segment. And stash away a pointer to the
// first stream to see a FIN.
if flags&header.TCPFlagFin != 0 && !inbound.finSeen {
inbound.finSeen = true
inbound.fin = end - 1
if *firstFin == nil {
*firstFin = inbound
}
}
return ResultAlive
}
// allOtherInbound is the state handler for inbound segments in all states
// except SYN-SENT.
func allOtherInbound(t *TCB, tcp header.TCP) Result {
return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin))
}
// allOtherOutbound is the state handler for outbound segments in all states
// except SYN-SENT.
func allOtherOutbound(t *TCB, tcp header.TCP) Result {
return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin))
}
// streams holds the state of a TCP unidirectional stream.
type stream struct {
// The interval [una, end) is the allowed interval as defined by the
// receiver, i.e., anything less than una has already been acknowledged
// and anything greater than or equal to end is beyond the receiver
// window. The interval [una, nxt) is the acknowledgable range, whose
// right edge indicates the sequence number of the next byte to be sent
// by the sender, i.e., anything greater than or equal to nxt hasn't
// been sent yet.
una seqnum.Value
nxt seqnum.Value
end seqnum.Value
// finSeen indicates if a FIN has already been sent on this stream.
finSeen bool
// fin is the sequence number of the FIN. It is only valid after finSeen
// is set to true.
fin seqnum.Value
// rstSeen indicates if a RST has already been sent on this stream.
rstSeen bool
}
// acceptable determines if the segment with the given sequence number and data
// length is acceptable, i.e., if it's within the [una, end) window or, in case
// the window is zero, if it's a packet with no payload and sequence number
// equal to una.
func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
wnd := s.una.Size(s.end)
if wnd == 0 {
return segLen == 0 && segSeq == s.una
}
// Make sure [segSeq, seqSeq+segLen) is non-empty.
if segLen == 0 {
segLen = 1
}
return seqnum.Overlap(s.una, wnd, segSeq, segLen)
}
// closed determines if the stream has already been closed. This happens when
// a FIN has been set by the sender and acknowledged by the receiver.
func (s *stream) closed() bool {
return s.finSeen && s.fin.LessThan(s.una)
}
// dataLen returns the length of the TCP segment payload.
func dataLen(tcp header.TCP) seqnum.Size {
return seqnum.Size(len(tcp) - int(tcp.DataOffset()))
}
// logicalLen calculates the logical length of the TCP segment.
func logicalLen(tcp header.TCP) seqnum.Size {
l := dataLen(tcp)
flags := tcp.Flags()
if flags&header.TCPFlagSyn != 0 {
l++
}
if flags&header.TCPFlagFin != 0 {
l++
}
return l
}

View File

@@ -0,0 +1,511 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tcpconntrack_test
import (
"testing"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/transport/tcpconntrack"
)
// connected creates a connection tracker TCB and sets it to a connected state
// by performing a 3-way handshake.
func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
// Send SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: iss,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: irw,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive SYN-ACK.
tcp.Encode(&header.TCPFields{
SeqNum: irs,
AckNum: iss + 1,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn | header.TCPFlagAck,
WindowSize: isw,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send ACK.
tcp.Encode(&header.TCPFields{
SeqNum: iss + 1,
AckNum: irs + 1,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: irw,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
return &tcb
}
func TestConnectionRefused(t *testing.T) {
// Send SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive RST.
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: 1235,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst | header.TCPFlagAck,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
}
}
func TestConnectionRefusedInSynRcvd(t *testing.T) {
// Send SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive SYN.
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Receive RST with no ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 790,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
}
}
func TestConnectionResetInSynRcvd(t *testing.T) {
// Send SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive SYN.
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send RST with no ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 1235,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagRst,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
}
}
func TestRetransmitOnSynSent(t *testing.T) {
// Send initial SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Retransmit the same SYN.
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
}
}
func TestRetransmitOnSynRcvd(t *testing.T) {
// Send initial SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive SYN. This will cause the state to go to SYN-RCVD.
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Retransmit the original SYN.
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Transmit a SYN-ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 790,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn | header.TCPFlagAck,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
}
func TestClosedBySelf(t *testing.T) {
tcb := connected(t, 1234, 789, 30000, 50000)
// Send FIN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1235,
AckNum: 790,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Receive FIN/ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 790,
AckNum: 1236,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 1236,
AckNum: 791,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
}
}
func TestClosedByPeer(t *testing.T) {
tcb := connected(t, 1234, 789, 30000, 50000)
// Receive FIN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 790,
AckNum: 1235,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send FIN/ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 1235,
AckNum: 791,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Receive ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 791,
AckNum: 1236,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
}
}
func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
sseq := uint32(1234)
rseq := uint32(789)
tcb := connected(t, sseq, rseq, 30000, 50000)
sseq++
rseq++
// Send some data.
tcp := make(header.TCP, header.TCPMinimumSize+1024)
for i := uint32(0); i < 10; i++ {
// Send some data.
tcp.Encode(&header.TCPFields{
SeqNum: sseq,
AckNum: rseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 30000,
})
sseq += uint32(len(tcp)) - header.TCPMinimumSize
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Receive ack for data.
tcp.Encode(&header.TCPFields{
SeqNum: rseq,
AckNum: sseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
}
for i := uint32(0); i < 10; i++ {
// Receive some data.
tcp.Encode(&header.TCPFields{
SeqNum: rseq,
AckNum: sseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 50000,
})
rseq += uint32(len(tcp)) - header.TCPMinimumSize
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send ack for data.
tcp.Encode(&header.TCPFields{
SeqNum: sseq,
AckNum: rseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
}
// Send FIN.
tcp = tcp[:header.TCPMinimumSize]
tcp.Encode(&header.TCPFields{
SeqNum: sseq,
AckNum: rseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 30000,
})
sseq++
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Receive FIN/ACK.
tcp.Encode(&header.TCPFields{
SeqNum: rseq,
AckNum: sseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck | header.TCPFlagFin,
WindowSize: 50000,
})
rseq++
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send ACK.
tcp.Encode(&header.TCPFields{
SeqNum: sseq,
AckNum: rseq,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
}
}
func TestIgnoreBadResetOnSynSent(t *testing.T) {
// Send SYN.
tcp := make(header.TCP, header.TCPMinimumSize)
tcp.Encode(&header.TCPFields{
SeqNum: 1234,
AckNum: 0,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn,
WindowSize: 30000,
})
tcb := tcpconntrack.TCB{}
tcb.Init(tcp)
// Receive a RST with a bad ACK, it should not cause the connection to
// be reset.
acks := []uint32{1234, 1236, 1000, 5000}
flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
for _, a := range acks {
for _, f := range flags {
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: a,
DataOffset: header.TCPMinimumSize,
Flags: f,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
}
}
// Complete the handshake.
// Receive SYN-ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 789,
AckNum: 1235,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagSyn | header.TCPFlagAck,
WindowSize: 50000,
})
if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
// Send ACK.
tcp.Encode(&header.TCPFields{
SeqNum: 1235,
AckNum: 790,
DataOffset: header.TCPMinimumSize,
Flags: header.TCPFlagAck,
WindowSize: 30000,
})
if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
}
}

View File

@@ -0,0 +1,980 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udp
import (
"log"
"math"
"sync"
"tcpip/netstack/sleep"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
// +stateify savable
// udp报文结构当接收到udp报文时会用这个结构来保存udp报文数据
// 并插入到接收链表中。
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
data buffer.VectorisedView
timestamp int64
hasTimestamp bool
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
}
type endpointState int
// 表示UDP端的状态参数
const (
stateInitial endpointState = iota
stateBound
stateConnected
stateClosed
)
// endpoint represents a UDP endpoint. This struct serves as the interface
// between users of the endpoint and the protocol implementation; it is legal to
// have concurrent goroutines make calls into the endpoint, they are properly
// synchronized.
//
// +stateify savable
// 表示UDP协议端的结构
type endpoint struct {
// The following fields are initialized at creation time and do not
// change throughout the lifetime of the endpoint.
stack *stack.Stack
netProto tcpip.NetworkProtocolNumber
waiterQueue *waiter.Queue
// The following fields are used to manage the receive queue, and are
// protected by rcvMu.
rcvMu sync.Mutex
rcvReady bool
rcvList udpPacketList
rcvBufSizeMax int
rcvBufSize int
rcvClosed bool
rcvTimestamp bool
// The following fields are protected by the mu mutex.
mu sync.RWMutex
sndBufSize int
id stack.TransportEndpointID
state endpointState
bindNICID tcpip.NICID
regNICID tcpip.NICID
route stack.Route
dstPort uint16
v6only bool
multicastTTL uint8
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
// multicastMemberships that need to be remvoed when the endpoint is
// closed. Protected by the mu mutex.
multicastMemberships []multicastMembership
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address).
effectiveNetProtos []tcpip.NetworkProtocolNumber
}
// 多播的成员关系包括多播地址和网卡ID
type multicastMembership struct {
nicID tcpip.NICID
multicastAddr tcpip.Address
}
// newEndpoint 新建一个UDP端
// 默认多播的TTL为1且接收和发送的缓存为32k。
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
return &endpoint{
stack: stack,
netProto: netProto,
waiterQueue: waiterQueue,
// RFC 1075 section 5.4 recommends a TTL of 1 for membership
// requests.
//
// RFC 5135 4.2.1 appears to assume that IGMP messages have a
// TTL of 1.
//
// RFC 5135 Appendix A defines TTL=1: A multicast source that
// wants its traffic to not traverse a router (e.g., leave a
// home network) may find it useful to send traffic with IP
// TTL=1.
//
// Linux defaults to TTL=1.
multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024,
}
}
// NewConnectedEndpoint creates a new endpoint in the connected state using the
// provided route.
func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
ep := newEndpoint(stack, r.NetProto, waiterQueue)
// Register new endpoint so that packets are routed to it.
if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
ep.Close()
return nil, err
}
ep.id = id
ep.route = r.Clone()
ep.dstPort = id.RemotePort
ep.regNICID = r.NICID()
ep.state = stateConnected
return ep, nil
}
// Close puts the endpoint in a closed state and frees all resources
// associated with it.
// UDP端的关闭释放相应的资源
func (e *endpoint) Close() {
e.mu.Lock()
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
// 释放在协议栈中注册的UDP端
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
// 释放端口占用
e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
}
for _, mem := range e.multicastMemberships {
e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
}
e.multicastMemberships = nil
// Close the receive list and drain it.
e.rcvMu.Lock()
e.rcvClosed = true
e.rcvBufSize = 0
// 清空接收链表
for !e.rcvList.Empty() {
p := e.rcvList.Front()
e.rcvList.Remove(p)
}
e.rcvMu.Unlock()
e.route.Release()
// Update the state.
e.state = stateClosed
e.mu.Unlock()
e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
// 从UDP端读取消息即使没有消息也不会阻塞。
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.rcvMu.Lock()
// 如果接收链表为空,即没有任何数据
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
return buffer.View{}, tcpip.ControlMessages{}, err
}
// 从接收链表中取出最前面的数据报,接着从链表中删除该数据报
// 然后减少接收缓存的大小
p := e.rcvList.Front()
e.rcvList.Remove(p)
e.rcvBufSize -= p.data.Size()
ts := e.rcvTimestamp
e.rcvMu.Unlock()
if addr != nil {
// 赋值发送地址
*addr = p.senderAddress
}
if ts && !p.hasTimestamp {
// Linux uses the current time.
p.timestamp = e.stack.NowNanoseconds()
}
// 返回数据报的内容
return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
// binds it if it's still in the initial state. To do so, it must first
// reacquire the mutex in exclusive mode.
//
// Returns true for retry if preparation should be retried.
// 写数据之前的准备,如果还是初始状态需要先进性绑定操作。
func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
switch e.state {
case stateInitial:
case stateConnected:
return false, nil
case stateBound:
if to == nil {
return false, tcpip.ErrDestinationRequired
}
return false, nil
default:
return false, tcpip.ErrInvalidEndpointState
}
e.mu.RUnlock()
defer e.mu.RLock()
e.mu.Lock()
defer e.mu.Unlock()
// The state changed when we released the shared locked and re-acquired
// it in exclusive mode. Try again.
if e.state != stateInitial {
return true, nil
}
// The state is still 'initial', so try to bind the endpoint.
if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
return false, err
}
return true, nil
}
// Write writes data to the endpoint's peer. This method does not block
// if the data cannot be written.
// 用户层最终调用该函数,发送数据包给对端,即使数据写失败,也不会阻塞。
func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
if opts.More {
return 0, nil, tcpip.ErrInvalidOptionValue
}
// 如果报文长度超过65535将会超过UDP最大的长度表示这是不允许的。
if p.Size() > math.MaxUint16 {
// Payload can't possibly fit in a packet.
return 0, nil, tcpip.ErrMessageTooLong
}
to := opts.To
e.mu.RLock()
defer e.mu.RUnlock()
// If we've shutdown with SHUT_WR we are in an invalid state for sending.
// 如果设置了关闭写数据,那返回错误
if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
return 0, nil, tcpip.ErrClosedForSend
}
// Prepare for write.
// 准备发送数据
for {
retry, err := e.prepareForWrite(to)
if err != nil {
return 0, nil, err
}
if !retry {
break
}
}
var route *stack.Route
var dstPort uint16
if to == nil {
// 如果没有指定发送的地址用UDP端 Connect 得到的路由和目的端口
route = &e.route
dstPort = e.dstPort
if route.IsResolutionRequired() {
// Promote lock to exclusive if using a shared route, given that it may need to
// change in Route.Resolve() call below.
// 如果使用共享路由则将锁定提升为独占路由因为它可能需要在下面的Route.Resolve调用中进行更改。
e.mu.RUnlock()
defer e.mu.RLock()
e.mu.Lock()
defer e.mu.Unlock()
// Recheck state after lock was re-acquired.
// 锁定后重新检查状态。
if e.state != stateConnected {
return 0, nil, tcpip.ErrInvalidEndpointState
}
}
} else { // 如果指定了发送的地址
nicid := to.NIC
// 如果绑定了网卡,使用该网卡
if e.bindNICID != 0 {
// 如果绑定的网卡和参数指定的网卡不同,那么返回错误
if nicid != 0 && nicid != e.bindNICID {
return 0, nil, tcpip.ErrNoRoute
}
nicid = e.bindNICID
}
// 得到目的IP+端口
toCopy := *to
to = &toCopy
netProto, err := e.checkV4Mapped(to, false)
if err != nil {
return 0, nil, err
}
log.Printf("netProto: 0x%x", netProto)
// Find the enpoint.
// 根据目的地址和协议找到相关路由信息
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto)
if err != nil {
return 0, nil, err
}
defer r.Release()
route = &r
dstPort = to.Port
}
// 如果路由没有下一跳的链路MAC地址那么触发相应的机制来填充该路由信息。
// 比如IPV4协议如果没有目的IP对应的MAC信息从从ARP缓存中查找信息找到了直接返回
// 若没找到那么发送ARP请求得到对应的MAC地址。
if route.IsResolutionRequired() {
waker := &sleep.Waker{}
if ch, err := route.Resolve(waker); err != nil {
if err == tcpip.ErrWouldBlock {
// Link address needs to be resolved. Resolution was triggered the background.
// Better luck next time.
route.RemoveWaker(waker)
return 0, ch, tcpip.ErrNoLinkAddress
}
return 0, nil, err
}
}
// 得到要发送的数据内容
v, err := p.Get(p.Size())
if err != nil {
return 0, nil, err
}
ttl := route.DefaultTTL()
// 如果是多播地址设置ttl
if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
ttl = e.multicastTTL
}
// 增加UDP头部信息并发送出去
if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
return 0, nil, err
}
return uintptr(len(v)), nil, nil
}
// Peek only returns data from a single datagram, so do nothing here.
func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
// SetSockOpt sets a socket option.
// 给 UDP 套接字设置选项
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
switch v := opt.(type) {
case tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
return tcpip.ErrInvalidEndpointState
}
e.mu.Lock()
defer e.mu.Unlock()
// We only allow this to be set when we're in the initial state.
if e.state != stateInitial {
return tcpip.ErrInvalidEndpointState
}
e.v6only = v != 0
case tcpip.TimestampOption:
e.rcvMu.Lock()
e.rcvTimestamp = v != 0
e.rcvMu.Unlock()
case tcpip.MulticastTTLOption:
e.mu.Lock()
e.multicastTTL = uint8(v)
e.mu.Unlock()
case tcpip.AddMembershipOption:
nicID := v.NIC
if v.InterfaceAddr != header.IPv4Any {
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrNoRoute
}
// TODO: check that v.MulticastAddr is a multicast address.
if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
e.mu.Lock()
defer e.mu.Unlock()
e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr})
case tcpip.RemoveMembershipOption:
nicID := v.NIC
if v.InterfaceAddr != header.IPv4Any {
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
}
if nicID == 0 {
return tcpip.ErrNoRoute
}
// TODO: check that v.MulticastAddr is a multicast address.
if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
return err
}
e.mu.Lock()
defer e.mu.Unlock()
for i, mem := range e.multicastMemberships {
if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr {
// Only remove the first match, so that each added membership above is
// paired with exactly 1 removal.
e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1]
e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
break
}
}
}
return nil
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
// 从 UDP 端获取选项参数
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
case tcpip.ErrorOption:
return nil
case *tcpip.SendBufferSizeOption:
e.mu.Lock()
*o = tcpip.SendBufferSizeOption(e.sndBufSize)
e.mu.Unlock()
return nil
case *tcpip.ReceiveBufferSizeOption:
e.rcvMu.Lock()
*o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
e.rcvMu.Unlock()
return nil
case *tcpip.V6OnlyOption:
// We only recognize this option on v6 endpoints.
if e.netProto != header.IPv6ProtocolNumber {
return tcpip.ErrUnknownProtocolOption
}
e.mu.Lock()
v := e.v6only
e.mu.Unlock()
*o = 0
if v {
*o = 1
}
return nil
case *tcpip.ReceiveQueueSizeOption:
e.rcvMu.Lock()
if e.rcvList.Empty() {
*o = 0
} else {
p := e.rcvList.Front()
*o = tcpip.ReceiveQueueSizeOption(p.data.Size())
}
e.rcvMu.Unlock()
return nil
case *tcpip.TimestampOption:
e.rcvMu.Lock()
*o = 0
if e.rcvTimestamp {
*o = 1
}
e.rcvMu.Unlock()
case *tcpip.MulticastTTLOption:
e.mu.Lock()
*o = tcpip.MulticastTTLOption(e.multicastTTL)
e.mu.Unlock()
return nil
}
return tcpip.ErrUnknownProtocolOption
}
// sendUDP sends a UDP segment via the provided network endpoint and under the
// provided identity.
// 增加UDP头部信息并发送给给网络层
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
// Allocate a buffer for the UDP header.
hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
// Initialize the header.
udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
// 得到报文的长度
length := uint16(hdr.UsedLength() + data.Size())
// UDP首部的编码
udp.Encode(&header.UDPFields{
SrcPort: localPort,
DstPort: remotePort,
Length: length,
})
// Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
// 检验和的计算
xsum := r.PseudoHeaderChecksum(ProtocolNumber)
for _, v := range data.Views() {
xsum = header.Checksum(v, xsum)
}
udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
}
// Track count of packets sent.
r.Stats().UDP.PacketsSent.Increment()
// 将准备好的UDP首部和数据写给网络层
log.Printf("send udp %d bytes", hdr.UsedLength()+data.Size())
return r.WritePacket(hdr, data, ProtocolNumber, ttl)
}
// IPV6于IPV4地址的映射
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only {
return 0, tcpip.ErrNoRoute
}
netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
if addr.Addr == "\x00\x00\x00\x00" {
addr.Addr = ""
}
// Fail if we are bound to an IPv6 address.
if !allowMismatch && len(e.id.LocalAddress) == 16 {
return 0, tcpip.ErrNetworkUnreachable
}
}
// Fail if we're bound to an address length different from the one we're
// checking.
if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState
}
return netProto, nil
}
// Connect connects the endpoint to its peer. Specifying a NIC is optional.
// 连接UDP对端当发送数据的时候用这个地址来填充目标地址
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// 目标端口为0是错误的
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
}
e.mu.Lock()
defer e.mu.Unlock()
nicid := addr.NIC
var localPort uint16
// 判断UDP端的状态
switch e.state {
case stateInitial:
// 如果是初始状态,直接下一步
case stateBound, stateConnected:
// 如果已经绑定或者已连接状态
localPort = e.id.LocalPort
if e.bindNICID == 0 {
break
}
if nicid != 0 && nicid != e.bindNICID {
return tcpip.ErrInvalidEndpointState
}
nicid = e.bindNICID
default:
return tcpip.ErrInvalidEndpointState
}
// 检查地址的映射,得到相应的协议
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return err
}
// Find a route to the desired destination.
// 在全局协议栈中查找路由
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
if err != nil {
return err
}
defer r.Release()
// 新建一个传输端的标识包括源IP、源端口、目的IP、目的端口
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
// Even if we're connected, this endpoint can still be used to send
// packets on a different network protocol, so we register both even if
// v6only is set to false and this is an ipv6 endpoint.
// 设置网络层协议IPV4或IPV6或两者都有
netProtos := []tcpip.NetworkProtocolNumber{netProto}
if netProto == header.IPv6ProtocolNumber && !e.v6only {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
}
}
// 将该UDP端注册到协议栈中
id, err = e.registerWithStack(nicid, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
// 如果源端口不为0则尝试在传输层端中删除老的UDP端
if e.id.LocalPort != 0 {
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
}
// 赋值UDP端的属性
e.id = id
e.route = r.Clone()
e.dstPort = addr.Port
e.regNICID = nicid
e.effectiveNetProtos = netProtos
// 更改该UDP端的状态为已连接
e.state = stateConnected
// 标志该UDP端可以接收数据了
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
// ConnectEndpoint is not supported.
func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
// Shutdown closes the read and/or write end of the endpoint connection
// to its peer.
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// A socket in the bound state can still receive multicast messages,
// so we need to notify waiters on shutdown.
if e.state != stateBound && e.state != stateConnected {
return tcpip.ErrNotConnected
}
e.shutdownFlags |= flags
if flags&tcpip.ShutdownRead != 0 {
e.rcvMu.Lock()
wasClosed := e.rcvClosed
e.rcvClosed = true
e.rcvMu.Unlock()
if !wasClosed {
e.waiterQueue.Notify(waiter.EventIn)
}
}
return nil
}
// Listen is not supported by UDP, it just fails.
// UDP是没有Listen的。
func (*endpoint) Listen(int) *tcpip.Error {
return tcpip.ErrNotSupported
}
// Accept is not supported by UDP, it just fails.
// 没有Listen自然没有Accept。
func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
return nil, nil, tcpip.ErrNotSupported
}
// 在协议栈中注册该UDP端并且分配源端口
func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
if e.id.LocalPort == 0 {
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
if err != nil {
return id, err
}
id.LocalPort = port
}
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
}
return id, err
}
// 根据addr参数绑定
func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
// Don't allow binding once endpoint is not in the initial state
// anymore.
// 不是初始状态的UDP端不让绑定
if e.state != stateInitial {
return tcpip.ErrInvalidEndpointState
}
netProto, err := e.checkV4Mapped(&addr, true)
if err != nil {
return err
}
// Expand netProtos to include v4 and v6 if the caller is binding to a
// wildcard (empty) address, and this is an IPv6 endpoint with v6only
// set to false.
netProtos := []tcpip.NetworkProtocolNumber{netProto}
if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber,
}
}
if len(addr.Addr) != 0 {
// A local address was specified, verify that it's valid.
if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
}
// 绑定的时候传输端ID是源IP+源端口
id := stack.TransportEndpointID{
LocalPort: addr.Port,
LocalAddress: addr.Addr,
}
// 在协议中注册该UDP端
id, err = e.registerWithStack(addr.NIC, netProtos, id)
if err != nil {
return err
}
// 如果指定了 commit 函数,执行并处理错误
if commit != nil {
if err := commit(); err != nil {
// Unregister, the commit failed.
e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
return err
}
}
e.id = id
e.regNICID = addr.NIC
e.effectiveNetProtos = netProtos
// Mark endpoint as bound.
// 标记状态为已绑定
e.state = stateBound
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
// Bind 将该UDP端绑定本地的一个IP+端口
// 例如绑定本地0.0.0.0的9000端口那么其他机器给这台机器9000端口发消息该UDP端就能收到消息了。
func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
// 执行绑定IP+端口操作
err := e.bindLocked(addr, commit)
if err != nil {
return err
}
// 绑定的网卡ID
e.bindNICID = addr.NIC
return nil
}
// GetLocalAddress returns the address to which the endpoint is bound.
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
return tcpip.FullAddress{
NIC: e.regNICID,
Addr: e.id.LocalAddress,
Port: e.id.LocalPort,
}, nil
}
// GetRemoteAddress returns the address to which the endpoint is connected.
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock()
defer e.mu.RUnlock()
if e.state != stateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected
}
return tcpip.FullAddress{
NIC: e.regNICID,
Addr: e.id.RemoteAddress,
Port: e.id.RemotePort,
}, nil
}
// Readiness returns the current readiness of the endpoint. For example, if
// waiter.EventIn is set, the endpoint is immediately readable.
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// The endpoint is always writable.
result := waiter.EventOut & mask
// Determine if the endpoint is readable if requested.
if (mask & waiter.EventIn) != 0 {
e.rcvMu.Lock()
if !e.rcvList.Empty() || e.rcvClosed {
result |= waiter.EventIn
}
e.rcvMu.Unlock()
}
return result
}
// HandlePacket is called by the stack when new packets arrive to this transport
// endpoint.
// 从网络层接收到UDP数据报时的处理函数
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
// Malformed packet.
// 错误报文
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return
}
// 去除UDP首部
vv.TrimFront(header.UDPMinimumSize)
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
// Drop the packet if our buffer is currently full.
// 如果UDP的接收缓存已经满了那么丢弃报文。
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.rcvMu.Unlock()
return
}
// 接收缓存是否为空
wasEmpty := e.rcvBufSize == 0
// Push new packet into receive list and increment the buffer size.
// 新建一个UDP 数据报结构,插入到接收链表里
pkt := &udpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
},
}
// 复制UDP数据的用户数据
pkt.data = vv.Clone(pkt.views[:])
// 插入到接收链表里,并增加已使用的缓存
e.rcvList.PushBack(pkt)
e.rcvBufSize += vv.Size()
if e.rcvTimestamp {
pkt.timestamp = e.stack.NowNanoseconds()
pkt.hasTimestamp = true
}
e.rcvMu.Unlock()
// Notify any waiters that there's data to be read now.
// 通知程序现在可以读取到数据了
if wasEmpty {
e.waiterQueue.Notify(waiter.EventIn)
}
log.Printf("recv udp %d bytes", hdr.Length())
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
}

View File

@@ -0,0 +1,85 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package udp contains the implementation of the UDP transport protocol. To use
// it in the networking stack, this package must be added to the project, and
// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
// transport protocols when calling stack.New(). Then endpoints can be created
// by passing udp.ProtocolNumber as the transport protocol number when calling
// Stack.NewEndpoint().
package udp
import (
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/waiter"
)
const (
// ProtocolName is the string representation of the udp protocol name.
ProtocolName = "udp"
// ProtocolNumber is the udp protocol number.
ProtocolNumber = header.UDPProtocolNumber
)
// tcpip.Endpoint 接口的UDP协议实现
type protocol struct{}
// Number returns the udp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
return ProtocolNumber
}
// NewEndpoint creates a new udp endpoint.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil
}
// MinimumPacketSize returns the minimum valid udp packet size.
func (*protocol) MinimumPacketSize() int {
return header.UDPMinimumSize
}
// ParsePorts returns the source and destination ports stored in the given udp
// packet.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.UDP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
// that don't match any existing endpoint.
func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
return true
}
// SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{}
})
}

View File

@@ -0,0 +1,174 @@
package udp
// ElementMapper provides an identity mapping by default.
//
// This can be replaced to provide a struct that maps elements to linker
// objects, if they are not the same. An ElementMapper is not typically
// required if: Linker is left as is, Element is left as is, or Linker and
// Element are the same type.
type udpPacketElementMapper struct{}
// linkerFor maps an Element to a Linker.
//
// This default implementation should be inlined.
//
//go:nosplit
func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
// List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations.
//
// The zero value for List is an empty list ready to use.
//
// To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() {
// // do something with e.
// }
//
// +stateify savable
// udp数据报的双向链表结构
type udpPacketList struct {
head *udpPacket
tail *udpPacket
}
// Reset resets list l to the empty state.
func (l *udpPacketList) Reset() {
l.head = nil
l.tail = nil
}
// Empty returns true iff the list is empty.
func (l *udpPacketList) Empty() bool {
return l.head == nil
}
// Front returns the first element of list l or nil.
func (l *udpPacketList) Front() *udpPacket {
return l.head
}
// Back returns the last element of list l or nil.
func (l *udpPacketList) Back() *udpPacket {
return l.tail
}
// PushFront inserts the element e at the front of list l.
func (l *udpPacketList) PushFront(e *udpPacket) {
udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil {
udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
} else {
l.tail = e
}
l.head = e
}
// PushBack inserts the element e at the back of list l.
func (l *udpPacketList) PushBack(e *udpPacket) {
udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil {
udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
} else {
l.head = e
}
l.tail = e
}
// PushBackList inserts list m at the end of list l, emptying m.
func (l *udpPacketList) PushBackList(m *udpPacketList) {
if l.head == nil {
l.head = m.head
l.tail = m.tail
} else if m.head != nil {
udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail
}
m.head = nil
m.tail = nil
}
// InsertAfter inserts e after b.
func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
a := udpPacketElementMapper{}.linkerFor(b).Next()
udpPacketElementMapper{}.linkerFor(e).SetNext(a)
udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
udpPacketElementMapper{}.linkerFor(b).SetNext(e)
if a != nil {
udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
} else {
l.tail = e
}
}
// InsertBefore inserts e before a.
func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
b := udpPacketElementMapper{}.linkerFor(a).Prev()
udpPacketElementMapper{}.linkerFor(e).SetNext(a)
udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil {
udpPacketElementMapper{}.linkerFor(b).SetNext(e)
} else {
l.head = e
}
}
// Remove removes e from l.
func (l *udpPacketList) Remove(e *udpPacket) {
prev := udpPacketElementMapper{}.linkerFor(e).Prev()
next := udpPacketElementMapper{}.linkerFor(e).Next()
if prev != nil {
udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
} else {
l.head = next
}
if next != nil {
udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
} else {
l.tail = prev
}
}
// Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the
// methods needed by List.
//
// +stateify savable
type udpPacketEntry struct {
next *udpPacket
prev *udpPacket
}
// Next returns the entry that follows e in the list.
func (e *udpPacketEntry) Next() *udpPacket {
return e.next
}
// Prev returns the entry that precedes e in the list.
func (e *udpPacketEntry) Prev() *udpPacket {
return e.prev
}
// SetNext assigns 'entry' as the entry that follows e in the list.
func (e *udpPacketEntry) SetNext(elem *udpPacket) {
e.next = elem
}
// SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
e.prev = elem
}

View File

@@ -0,0 +1,928 @@
// Copyright 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package udp_test
import (
"bytes"
"math/rand"
"testing"
"time"
"tcpip/netstack/tcpip"
"tcpip/netstack/tcpip/buffer"
"tcpip/netstack/tcpip/checker"
"tcpip/netstack/tcpip/header"
"tcpip/netstack/tcpip/link/channel"
"tcpip/netstack/tcpip/link/sniffer"
"tcpip/netstack/tcpip/network/ipv4"
"tcpip/netstack/tcpip/network/ipv6"
"tcpip/netstack/tcpip/stack"
"tcpip/netstack/tcpip/transport/udp"
"tcpip/netstack/waiter"
)
const (
stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
multicastV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + multicastAddr
V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
stackAddr = "\x0a\x00\x00\x01"
stackPort = 1234
testAddr = "\x0a\x00\x00\x02"
testPort = 4096
multicastAddr = "\xe8\x2b\xd3\xea"
multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
multicastPort = 1234
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
// of loopback interfaces on linux systems.
defaultMTU = 65536
)
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
ep tcpip.Endpoint
wq waiter.Queue
}
type headers struct {
srcPort uint16
dstPort uint16
}
func newDualTestContext(t *testing.T, mtu uint32) *testContext {
s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}, stack.Options{})
id, linkEP := channel.New(256, mtu, "")
if testing.Verbose() {
id = sniffer.New(id)
}
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
t.Fatalf("AddAddress failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: "\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
},
{
Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
},
})
return &testContext{
t: t,
s: s,
linkEP: linkEP,
}
}
func (c *testContext) cleanup() {
if c.ep != nil {
c.ep.Close()
}
}
func (c *testContext) createV6Endpoint(v6only bool) {
var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
var v tcpip.V6OnlyOption
if v6only {
v = 1
}
if err := c.ep.SetSockOpt(v); err != nil {
c.t.Fatalf("SetSockOpt failed failed: %v", err)
}
}
func (c *testContext) getPacket(protocolNumber tcpip.NetworkProtocolNumber, multicast bool) []byte {
select {
case p := <-c.linkEP.C:
if p.Proto != protocolNumber {
c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, protocolNumber)
}
b := make([]byte, len(p.Header)+len(p.Payload))
copy(b, p.Header)
copy(b[len(p.Header):], p.Payload)
var checkerFn func(*testing.T, []byte, ...checker.NetworkChecker)
var srcAddr, dstAddr tcpip.Address
switch protocolNumber {
case ipv4.ProtocolNumber:
checkerFn = checker.IPv4
srcAddr, dstAddr = stackAddr, testAddr
if multicast {
dstAddr = multicastAddr
}
case ipv6.ProtocolNumber:
checkerFn = checker.IPv6
srcAddr, dstAddr = stackV6Addr, testV6Addr
if multicast {
dstAddr = multicastV6Addr
}
default:
c.t.Fatalf("unknown protocol %d", protocolNumber)
}
checkerFn(c.t, b, checker.SrcAddr(srcAddr), checker.DstAddr(dstAddr))
return b
case <-time.After(2 * time.Second):
c.t.Fatalf("Packet wasn't written out")
}
return nil
}
func (c *testContext) sendV6Packet(payload []byte, h *headers) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
// Initialize the IP header.
ip := header.IPv6(buf)
ip.Encode(&header.IPv6Fields{
PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
NextHeader: uint8(udp.ProtocolNumber),
HopLimit: 65,
SrcAddr: testV6Addr,
DstAddr: stackV6Addr,
})
// Initialize the UDP header.
u := header.UDP(buf[header.IPv6MinimumSize:])
u.Encode(&header.UDPFields{
SrcPort: h.srcPort,
DstPort: h.dstPort,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
xsum := header.Checksum([]byte(testV6Addr), 0)
xsum = header.Checksum([]byte(stackV6Addr), xsum)
xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
// Calculate the UDP checksum and set it.
length := uint16(header.UDPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum, length))
// Inject packet.
c.linkEP.Inject(ipv6.ProtocolNumber, buf.ToVectorisedView())
}
func (c *testContext) sendPacket(payload []byte, h *headers) {
// Allocate a buffer for data and headers.
buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
copy(buf[len(buf)-len(payload):], payload)
// Initialize the IP header.
ip := header.IPv4(buf)
ip.Encode(&header.IPv4Fields{
IHL: header.IPv4MinimumSize,
TotalLength: uint16(len(buf)),
TTL: 65,
Protocol: uint8(udp.ProtocolNumber),
SrcAddr: testAddr,
DstAddr: stackAddr,
})
ip.SetChecksum(^ip.CalculateChecksum())
// Initialize the UDP header.
u := header.UDP(buf[header.IPv4MinimumSize:])
u.Encode(&header.UDPFields{
SrcPort: h.srcPort,
DstPort: h.dstPort,
Length: uint16(header.UDPMinimumSize + len(payload)),
})
// Calculate the UDP pseudo-header checksum.
xsum := header.Checksum([]byte(testAddr), 0)
xsum = header.Checksum([]byte(stackAddr), xsum)
xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
// Calculate the UDP checksum and set it.
length := uint16(header.UDPMinimumSize + len(payload))
xsum = header.Checksum(payload, xsum)
u.SetChecksum(^u.CalculateChecksum(xsum, length))
// Inject packet.
c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView())
}
func newPayload() []byte {
b := make([]byte, 30+rand.Intn(100))
for i := range b {
b[i] = byte(rand.Intn(256))
}
return b
}
func testV4Read(c *testContext) {
// Send a packet.
payload := newPayload()
c.sendPacket(payload, &headers{
srcPort: testPort,
dstPort: stackPort,
})
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
v, _, err = c.ep.Read(&addr)
if err != nil {
c.t.Fatalf("Read failed: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for data")
}
}
// Check the peer address.
if addr.Addr != testAddr {
c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
}
}
func TestBindEphemeralPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
if err := c.ep.Bind(tcpip.FullAddress{}, nil); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
}
}
func TestBindReservedPort(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
addr, err := c.ep.GetLocalAddress()
if err != nil {
t.Fatalf("GetLocalAddress failed: %v", err)
}
// We can't bind the address reserved by the connected endpoint above.
{
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
if got, want := ep.Bind(addr, nil), tcpip.ErrPortInUse; got != want {
t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
}
}
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
// We can't bind ipv4-any on the port reserved by the connected endpoint
// above, since the endpoint is dual-stack.
if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil), tcpip.ErrPortInUse; got != want {
t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
}
// We can bind an ipv4 address on this port, though.
if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}, nil); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
}
}()
// Once the connected endpoint releases its port reservation, we are able to
// bind ipv4-any once again.
c.ep.Close()
func() {
ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
t.Fatalf("NewEndpoint failed: %v", err)
}
defer ep.Close()
if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}, nil); err != nil {
t.Fatalf("ep.Bind(...) failed: %v", err)
}
}()
}
func TestV4ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Read(c)
}
func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to v4 mapped wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Read(c)
}
func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to local address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Read(c)
}
func TestV6ReadOnV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Send a packet.
payload := newPayload()
c.sendV6Packet(payload, &headers{
srcPort: testPort,
dstPort: stackPort,
})
// Try to receive the data.
we, ch := waiter.NewChannelEntry(nil)
c.wq.EventRegister(&we, waiter.EventIn)
defer c.wq.EventUnregister(&we)
var addr tcpip.FullAddress
v, _, err := c.ep.Read(&addr)
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
v, _, err = c.ep.Read(&addr)
if err != nil {
c.t.Fatalf("Read failed: %v", err)
}
case <-time.After(1 * time.Second):
c.t.Fatalf("Timed out waiting for data")
}
}
// Check the peer address.
if addr.Addr != testV6Addr {
c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
}
// Check the payload.
if !bytes.Equal(payload, v) {
c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
}
}
func TestV4ReadOnV4(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
// Create v4 UDP endpoint.
var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Test acceptance.
testV4Read(c)
}
func testV4Write(c *testContext) uint16 {
// Write to V4 mapped address.
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
if n != uintptr(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
// Check that we received the packet.
b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
checker.DstPort(testPort),
),
)
// Check the payload.
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
return udp.SourcePort()
}
func testV6Write(c *testContext) uint16 {
// Write to v6 address.
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
if n != uintptr(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
// Check that we received the packet.
b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
checker.DstPort(testPort),
),
)
// Check the payload.
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
return udp.SourcePort()
}
func testDualWrite(c *testContext) uint16 {
v4Port := testV4Write(c)
v6Port := testV6Write(c)
if v4Port != v6Port {
c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
}
return v4Port
}
func TestDualWriteUnbound(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
testDualWrite(c)
}
func TestDualWriteBoundToWildcard(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
p := testDualWrite(c)
if p != stackPort {
c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
}
}
func TestDualWriteConnectedToV6(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
testV6Write(c)
// Write to V4 mapped address.
payload := buffer.View(newPayload())
_, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != tcpip.ErrNetworkUnreachable {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable)
}
}
func TestDualWriteConnectedToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
testV4Write(c)
// Write to v6 address.
payload := buffer.View(newPayload())
_, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != tcpip.ErrInvalidEndpointState {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
}
}
func TestV4WriteOnV6Only(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(true)
// Write to V4 mapped address.
payload := buffer.View(newPayload())
_, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
})
if err != tcpip.ErrNoRoute {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
}
}
func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Bind to v4 mapped address.
if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
// Write to v6 address.
payload := buffer.View(newPayload())
_, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
})
if err != tcpip.ErrInvalidEndpointState {
c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState)
}
}
func TestV6WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Connect to v6 address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
// Write without destination.
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
if n != uintptr(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
// Check that we received the packet.
b := c.getPacket(ipv6.ProtocolNumber, false)
udp := header.UDP(header.IPv6(b).Payload())
checker.IPv6(c.t, b,
checker.UDP(
checker.DstPort(testPort),
),
)
// Check the payload.
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
}
func TestV4WriteOnConnected(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
// Connect to v4 mapped address.
if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
c.t.Fatalf("Connect failed: %v", err)
}
// Write without destination.
payload := buffer.View(newPayload())
n, _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
if n != uintptr(len(payload)) {
c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
}
// Check that we received the packet.
b := c.getPacket(ipv4.ProtocolNumber, false)
udp := header.UDP(header.IPv4(b).Payload())
checker.IPv4(c.t, b,
checker.UDP(
checker.DstPort(testPort),
),
)
// Check the payload.
if !bytes.Equal(payload, udp.Payload()) {
c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
}
}
func TestReadIncrementsPacketsReceived(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
// Create IPv4 UDP endpoint
var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
// Bind to wildcard.
if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
c.t.Fatalf("Bind failed: %v", err)
}
testV4Read(c)
var want uint64 = 1
if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want {
c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want)
}
}
func TestWriteIncrementsPacketsSent(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
c.createV6Endpoint(false)
testDualWrite(c)
var want uint64 = 2
if got := c.s.Stats().UDP.PacketsSent.Value(); got != want {
c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want)
}
}
func TestTTL(t *testing.T) {
payload := tcpip.SlicePayload(buffer.View(newPayload()))
for _, name := range []string{"v4", "v6", "dual"} {
t.Run(name, func(t *testing.T) {
var networkProtocolNumber tcpip.NetworkProtocolNumber
switch name {
case "v4":
networkProtocolNumber = ipv4.ProtocolNumber
case "v6", "dual":
networkProtocolNumber = ipv6.ProtocolNumber
default:
t.Fatal("unknown test variant")
}
var variants []string
switch name {
case "v4":
variants = []string{"v4"}
case "v6":
variants = []string{"v6"}
case "dual":
variants = []string{"v6", "mapped"}
}
for _, variant := range variants {
t.Run(variant, func(t *testing.T) {
for _, typ := range []string{"unicast", "multicast"} {
t.Run(typ, func(t *testing.T) {
var addr tcpip.Address
var port uint16
switch typ {
case "unicast":
port = testPort
switch variant {
case "v4":
addr = testAddr
case "mapped":
addr = testV4MappedAddr
case "v6":
addr = testV6Addr
default:
t.Fatal("unknown test variant")
}
case "multicast":
port = multicastPort
switch variant {
case "v4":
addr = multicastAddr
case "mapped":
addr = multicastV4MappedAddr
case "v6":
addr = multicastV6Addr
default:
t.Fatal("unknown test variant")
}
default:
t.Fatal("unknown test variant")
}
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
var err *tcpip.Error
c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, networkProtocolNumber, &c.wq)
if err != nil {
c.t.Fatalf("NewEndpoint failed: %v", err)
}
switch name {
case "v4":
case "v6":
if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(1)); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
case "dual":
if err := c.ep.SetSockOpt(tcpip.V6OnlyOption(0)); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
default:
t.Fatal("unknown test variant")
}
const multicastTTL = 42
if err := c.ep.SetSockOpt(tcpip.MulticastTTLOption(multicastTTL)); err != nil {
c.t.Fatalf("SetSockOpt failed: %v", err)
}
n, _, err := c.ep.Write(payload, tcpip.WriteOptions{To: &tcpip.FullAddress{Addr: addr, Port: port}})
if err != nil {
c.t.Fatalf("Write failed: %v", err)
}
if n != uintptr(len(payload)) {
c.t.Fatalf("got c.ep.Write(...) = %d, want = %d", n, len(payload))
}
checkerFn := checker.IPv4
switch variant {
case "v4", "mapped":
case "v6":
checkerFn = checker.IPv6
default:
t.Fatal("unknown test variant")
}
var wantTTL uint8
var multicast bool
switch typ {
case "unicast":
multicast = false
switch variant {
case "v4", "mapped":
ep, err := ipv4.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
if err != nil {
t.Fatal(err)
}
wantTTL = ep.DefaultTTL()
ep.Close()
case "v6":
ep, err := ipv6.NewProtocol().NewEndpoint(0, "", nil, nil, nil)
if err != nil {
t.Fatal(err)
}
wantTTL = ep.DefaultTTL()
ep.Close()
default:
t.Fatal("unknown test variant")
}
case "multicast":
wantTTL = multicastTTL
multicast = true
default:
t.Fatal("unknown test variant")
}
var networkProtocolNumber tcpip.NetworkProtocolNumber
switch variant {
case "v4", "mapped":
networkProtocolNumber = ipv4.ProtocolNumber
case "v6":
networkProtocolNumber = ipv6.ProtocolNumber
default:
t.Fatal("unknown test variant")
}
b := c.getPacket(networkProtocolNumber, multicast)
checkerFn(c.t, b,
checker.TTL(wantTTL),
checker.UDP(
checker.DstPort(port),
),
)
})
}
})
}
})
}
}

Some files were not shown because too many files have changed in this diff Show More