This commit is contained in:
impact-eintr
2022-12-07 19:43:27 +08:00
parent d4d5c61a83
commit de9a9295b5
15 changed files with 2709 additions and 2709 deletions

View File

@@ -1,280 +1,280 @@
package main package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net" "net"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/link/fdbased" "netstack/tcpip/link/fdbased"
"netstack/tcpip/link/tuntap" "netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp" "netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4" "netstack/tcpip/network/ipv4"
"netstack/tcpip/network/ipv6" "netstack/tcpip/network/ipv6"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"netstack/tcpip/transport/tcp" "netstack/tcpip/transport/tcp"
"netstack/tcpip/transport/udp" "netstack/tcpip/transport/udp"
"netstack/waiter" "netstack/waiter"
"os" "os"
"os/signal" "os/signal"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
) )
var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
func main() { func main() {
flag.Parse() flag.Parse()
if len(flag.Args()) != 4 { if len(flag.Args()) != 4 {
log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask> <ip-address> <local-port>") log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask> <ip-address> <local-port>")
} }
log.SetFlags(log.Lshortfile | log.LstdFlags) log.SetFlags(log.Lshortfile | log.LstdFlags)
tapName := flag.Arg(0) tapName := flag.Arg(0)
cidrName := flag.Arg(1) cidrName := flag.Arg(1)
addrName := flag.Arg(2) addrName := flag.Arg(2)
portName := flag.Arg(3) portName := flag.Arg(3)
log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName) log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName)
maddr, err := net.ParseMAC(*mac) maddr, err := net.ParseMAC(*mac)
if err != nil { if err != nil {
log.Fatalf("Bad MAC address: %v", *mac) log.Fatalf("Bad MAC address: %v", *mac)
} }
parsedAddr := net.ParseIP(addrName) parsedAddr := net.ParseIP(addrName)
if err != nil { if err != nil {
log.Fatalf("Bad addrress: %v", addrName) log.Fatalf("Bad addrress: %v", addrName)
} }
// 解析地址ip地址ipv4或者ipv6地址都支持 // 解析地址ip地址ipv4或者ipv6地址都支持
var addr tcpip.Address var addr tcpip.Address
var proto tcpip.NetworkProtocolNumber var proto tcpip.NetworkProtocolNumber
if parsedAddr.To4() != nil { if parsedAddr.To4() != nil {
addr = tcpip.Address(parsedAddr.To4()) addr = tcpip.Address(parsedAddr.To4())
proto = ipv4.ProtocolNumber proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil { } else if parsedAddr.To16() != nil {
addr = tcpip.Address(parsedAddr.To16()) addr = tcpip.Address(parsedAddr.To16())
proto = ipv6.ProtocolNumber proto = ipv6.ProtocolNumber
} else { } else {
log.Fatalf("Unknown IP type: %v", parsedAddr) log.Fatalf("Unknown IP type: %v", parsedAddr)
} }
localPort, err := strconv.Atoi(portName) localPort, err := strconv.Atoi(portName)
if err != nil { if err != nil {
log.Fatalf("Unable to convert port %v: %v", portName, err) log.Fatalf("Unable to convert port %v: %v", portName, err)
} }
// 虚拟网卡配置 // 虚拟网卡配置
conf := &tuntap.Config{ conf := &tuntap.Config{
Name: tapName, Name: tapName,
Mode: tuntap.TAP, Mode: tuntap.TAP,
} }
var fd int var fd int
// 新建虚拟网卡 // 新建虚拟网卡
fd, err = tuntap.NewNetDev(conf) fd, err = tuntap.NewNetDev(conf)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 启动tap网卡 // 启动tap网卡
_ = tuntap.SetLinkUp(tapName) _ = tuntap.SetLinkUp(tapName)
// 设置路由 // 设置路由
_ = tuntap.SetRoute(tapName, cidrName) _ = tuntap.SetRoute(tapName, cidrName)
// 抽象的文件接口 // 抽象的文件接口
linkID := fdbased.New(&fdbased.Options{ linkID := fdbased.New(&fdbased.Options{
FD: fd, FD: fd,
MTU: 1500, MTU: 1500,
Address: tcpip.LinkAddress(maddr), Address: tcpip.LinkAddress(maddr),
ResolutionRequired: true, ResolutionRequired: true,
}) })
// 新建相关协议的协议栈 // 新建相关协议的协议栈
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
[]string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})
// 新建抽象的网卡 // 新建抽象的网卡
if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil { if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 在该协议栈上添加和注册相应的网络层 // 在该协议栈上添加和注册相应的网络层
if err := s.AddAddress(1, proto, addr); err != nil { if err := s.AddAddress(1, proto, addr); err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 在该协议栈上添加和注册ARP协议 // 在该协议栈上添加和注册ARP协议
if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 添加默认路由 // 添加默认路由
s.SetRouteTable([]tcpip.Route{ s.SetRouteTable([]tcpip.Route{
{ {
Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
Gateway: "", Gateway: "",
NIC: 1, NIC: 1,
}, },
}) })
//go func() { // echo server //go func() { // echo server
// // 监听udp localPort端口 // // 监听udp localPort端口
// conn := udpListen(s, proto, addr, localPort) // conn := udpListen(s, proto, addr, localPort)
// for { // for {
// buf := make([]byte, 1024) // buf := make([]byte, 1024)
// n, err := conn.Read(buf) // n, err := conn.Read(buf)
// if err != nil { // if err != nil {
// log.Println(err) // log.Println(err)
// break // break
// } // }
// log.Println("接收到数据", string(buf[:n])) // log.Println("接收到数据", string(buf[:n]))
// conn.Write([]byte("server echo")) // conn.Write([]byte("server echo"))
// } // }
// // 关闭监听服务,此时会释放端口 // // 关闭监听服务,此时会释放端口
// conn.Close() // conn.Close()
//}() //}()
go func() { // echo server go func() { // echo server
listener := tcpListen(s, proto, addr, localPort) listener := tcpListen(s, proto, addr, localPort)
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
continue continue
} }
conn.Read(nil) conn.Read(nil)
} }
}() }()
c := make(chan os.Signal) c := make(chan os.Signal)
signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2)
<-c <-c
} }
type TcpConn struct { type TcpConn struct {
raddr tcpip.FullAddress raddr tcpip.FullAddress
ep tcpip.Endpoint ep tcpip.Endpoint
wq *waiter.Queue wq *waiter.Queue
we *waiter.Entry we *waiter.Entry
notifyCh chan struct{} notifyCh chan struct{}
} }
// Accept 封装tcp的accept操作 // Accept 封装tcp的accept操作
func (conn *TcpConn) Accept() (tcpip.Endpoint, error) { func (conn *TcpConn) Accept() (tcpip.Endpoint, error) {
conn.wq.EventRegister(conn.we, waiter.EventIn) conn.wq.EventRegister(conn.we, waiter.EventIn)
defer conn.wq.EventUnregister(conn.we) defer conn.wq.EventUnregister(conn.we)
for { for {
ep, _, err := conn.ep.Accept() ep, _, err := conn.ep.Accept()
if err != nil { if err != nil {
if err == tcpip.ErrWouldBlock { if err == tcpip.ErrWouldBlock {
<-conn.notifyCh <-conn.notifyCh
continue continue
} }
return nil, fmt.Errorf("%s", err.String()) return nil, fmt.Errorf("%s", err.String())
} }
return ep, nil return ep, nil
} }
} }
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn { func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn {
var wq waiter.Queue var wq waiter.Queue
// 新建一个tcp端 // 新建一个tcp端
ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 绑定IP和端口这里的IP地址为空表示绑定任何IP // 绑定IP和端口这里的IP地址为空表示绑定任何IP
// 此时就会调用端口管理器 // 此时就会调用端口管理器
if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil { if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil {
log.Fatal("Bind failed: ", err) log.Fatal("Bind failed: ", err)
} }
// 开始监听 // 开始监听
if err := ep.Listen(10); err != nil { if err := ep.Listen(10); err != nil {
log.Fatal("Listen failed: ", err) log.Fatal("Listen failed: ", err)
} }
waitEntry, notifyCh := waiter.NewChannelEntry(nil) waitEntry, notifyCh := waiter.NewChannelEntry(nil)
return &TcpConn{ return &TcpConn{
ep: ep, ep: ep,
wq: &wq, wq: &wq,
we: &waitEntry, we: &waitEntry,
notifyCh: notifyCh} notifyCh: notifyCh}
} }
type UdpConn struct { type UdpConn struct {
raddr tcpip.FullAddress raddr tcpip.FullAddress
ep tcpip.Endpoint ep tcpip.Endpoint
wq *waiter.Queue wq *waiter.Queue
we *waiter.Entry we *waiter.Entry
notifyCh chan struct{} notifyCh chan struct{}
} }
func (conn *UdpConn) Close() { func (conn *UdpConn) Close() {
conn.ep.Close() conn.ep.Close()
} }
func (conn *UdpConn) Read(rcv []byte) (int, error) { func (conn *UdpConn) Read(rcv []byte) (int, error) {
conn.wq.EventRegister(conn.we, waiter.EventIn) conn.wq.EventRegister(conn.we, waiter.EventIn)
defer conn.wq.EventUnregister(conn.we) defer conn.wq.EventUnregister(conn.we)
for { for {
buf, _, err := conn.ep.Read(&conn.raddr) buf, _, err := conn.ep.Read(&conn.raddr)
if err != nil { if err != nil {
if err == tcpip.ErrWouldBlock { if err == tcpip.ErrWouldBlock {
<-conn.notifyCh <-conn.notifyCh
continue continue
} }
return 0, fmt.Errorf("%s", err.String()) return 0, fmt.Errorf("%s", err.String())
} }
n := len(buf) n := len(buf)
if n > cap(rcv) { if n > cap(rcv) {
n = cap(rcv) n = cap(rcv)
} }
rcv = append(rcv[:0], buf[:n]...) rcv = append(rcv[:0], buf[:n]...)
return n, nil return n, nil
} }
} }
func (conn *UdpConn) Write(snd []byte) error { func (conn *UdpConn) Write(snd []byte) error {
for { for {
_, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr}) _, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr})
if err != nil { if err != nil {
if err == tcpip.ErrNoLinkAddress { if err == tcpip.ErrNoLinkAddress {
<-notifyCh <-notifyCh
continue continue
} }
return fmt.Errorf("%s", err.String()) return fmt.Errorf("%s", err.String())
} }
return nil return nil
} }
} }
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn { func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn {
var wq waiter.Queue var wq waiter.Queue
// 新建一个udp端 // 新建一个udp端
ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
// 绑定IP和端口这里的IP地址为空表示绑定任何IP // 绑定IP和端口这里的IP地址为空表示绑定任何IP
// 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现
// 此时就会调用端口管理器 // 此时就会调用端口管理器
if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil { if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil {
log.Fatal("Bind failed: ", err) log.Fatal("Bind failed: ", err)
} }
waitEntry, notifyCh := waiter.NewChannelEntry(nil) waitEntry, notifyCh := waiter.NewChannelEntry(nil)
return &UdpConn{ return &UdpConn{
ep: ep, ep: ep,
wq: &wq, wq: &wq,
we: &waitEntry, we: &waitEntry,
notifyCh: notifyCh} notifyCh: notifyCh}
} }

View File

@@ -1,101 +1,101 @@
package main package main
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"netstack/tcpip/transport/udp" "netstack/tcpip/transport/udp"
"netstack/waiter" "netstack/waiter"
"runtime" "runtime"
"strings" "strings"
) )
// TCPHandler processes a single accepted connection.
type TCPHandler interface {
	Handle(net.Conn)
}

// TCPServer accepts connections from listener and hands each one to handler
// on its own goroutine.
//
// Temporary accept errors are logged and retried. An error whose text contains
// "use of closed network connection" ends the loop with a nil result; any
// other error is returned wrapped in a descriptive message.
func TCPServer(listener net.Listener, handler TCPHandler) error {
	log.Printf("netstack 网络解析地址: %s", listener.Addr())

	for {
		clientConn, err := listener.Accept()
		if err == nil {
			go handler.Handle(clientConn)
			continue
		}

		if netErr, isNetErr := err.(net.Error); isNetErr && netErr.Temporary() {
			log.Printf("temporary Accept() failure - %s", err)
			// Yield so other goroutines can run before retrying.
			runtime.Gosched()
			continue
		}

		// The closed-listener condition is recognised by its message text.
		if strings.Contains(err.Error(), "use of closed network connection") {
			log.Printf("TCP: closing %s", listener.Addr())
			return nil
		}
		return fmt.Errorf("listener.Accept() error - %s", err)
	}
}
var transportPool = make(map[uint64]tcpip.Endpoint) var transportPool = make(map[uint64]tcpip.Endpoint)
type RCV struct { type RCV struct {
*stack.Stack *stack.Stack
ep tcpip.Endpoint ep tcpip.Endpoint
addr tcpip.FullAddress addr tcpip.FullAddress
rcvBuf []byte rcvBuf []byte
} }
func (r *RCV) Handle(conn net.Conn) { func (r *RCV) Handle(conn net.Conn) {
var err error var err error
r.rcvBuf, err = io.ReadAll(conn) r.rcvBuf, err = io.ReadAll(conn)
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
panic(err) panic(err)
} }
switch string(r.rcvBuf[:3]) { switch string(r.rcvBuf[:3]) {
case "udp": case "udp":
var wq waiter.Queue var wq waiter.Queue
// 新建一个udp端 // 新建一个udp端
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
r.ep = ep r.ep = ep
r.Bind() r.Bind()
r.Connect() r.Connect()
r.Close() r.Close()
case "tcp": case "tcp":
default: default:
return return
} }
} }
func (r *RCV) Bind() { func (r *RCV) Bind() {
if len(r.rcvBuf) < 9 { // udp ip port if len(r.rcvBuf) < 9 { // udp ip port
log.Println("Error: too few arg") log.Println("Error: too few arg")
return return
} }
port := binary.BigEndian.Uint16(r.rcvBuf[7:9]) port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
r.addr = tcpip.FullAddress{ r.addr = tcpip.FullAddress{
NIC: 1, NIC: 1,
Addr: tcpip.Address(r.rcvBuf[3:7]), Addr: tcpip.Address(r.rcvBuf[3:7]),
Port: port, Port: port,
} }
r.ep.Bind(r.addr, nil) r.ep.Bind(r.addr, nil)
} }
func (r *RCV) Connect() { func (r *RCV) Connect() {
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888}) r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
} }
func (r *RCV) Close() { func (r *RCV) Close() {
r.ep.Close() r.ep.Close()
} }

View File

@@ -1,41 +1,41 @@
package main package main
import ( import (
"flag" "flag"
"log" "log"
"net" "net"
) )
// main sends one 1600-byte UDP datagram to the address given by -a and logs
// the first reply it receives.
func main() {
	addr := flag.String("a", "192.168.1.1:9999", "udp dst address")
	// Without flag.Parse the -a flag is silently ignored and the default
	// address is always used.
	flag.Parse()

	log.SetFlags(log.Lshortfile | log.LstdFlags)

	udpAddr, err := net.ResolveUDPAddr("udp", *addr)
	if err != nil {
		panic(err)
	}

	// "Dialing" UDP only records the destination IP and port; no connection
	// is actually established.
	conn, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	//send := []byte("hello world")
	// NOTE(review): 1600 bytes is larger than the 1500-byte MTU used by the
	// server side, presumably to exercise fragmentation - confirm.
	send := make([]byte, 1600)
	if _, err := conn.Write(send); err != nil {
		panic(err)
	}
	log.Printf("send: %s", string(send))

	recv := make([]byte, 32)
	rn, _, err := conn.ReadFrom(recv)
	if err != nil {
		panic(err)
	}
	log.Printf("recv: %s", string(recv[:rn]))
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,50 +1,50 @@
package seqnum package seqnum
// Value represents the value of a sequence number. All arithmetic on it is
// performed modulo 2^32, so comparisons must be wrap-around aware.
type Value uint32

// Size represents the size (length) of a sequence number window.
type Size uint32

// LessThan reports whether v is before w in sequence space, i.e. v < w when
// the difference is interpreted as a signed 32-bit number.
func (v Value) LessThan(w Value) bool {
	return int32(v-w) < 0
}

// LessThanEq reports whether v == w or v is before w.
func (v Value) LessThanEq(w Value) bool {
	return v == w || v.LessThan(w)
}

// InRange reports whether v is in the half-open range [a, b).
//
// The check is done on offsets relative to a so that it stays correct when
// the range straddles the 2^32 wrap-around point (for example a = 0xfffffff0,
// b = 0x10). A plain "a <= v && v < b" rejects every value in such a range.
func (v Value) InRange(a, b Value) bool {
	return v-a < b-a
}

// InWindows reports whether v is in the window [first, first+size).
func (v Value) InWindows(first Value, size Size) bool {
	return v.InRange(first, first.Add(size))
}

// Add returns v + s (modulo 2^32).
func (v Value) Add(s Size) Value {
	return v + Value(s)
}

// Size returns the size of the range [v, w).
func (v Value) Size(w Value) Size {
	return Size(w - v)
}

// UpdateForward advances the value in place by s.
func (v *Value) UpdateForward(s Size) {
	*v += Value(s)
}

// Overlap reports whether the window [a, a+b) overlaps the window [x, x+y).
func Overlap(a Value, b Size, x Value, y Size) bool {
	return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
}

View File

@@ -1,77 +1,77 @@
# TCP 协议 # TCP 协议
## tcp特点 ## tcp特点
1. tcp 是面向连接的传输协议。 1. tcp 是面向连接的传输协议。
2. tcp 的连接是端到端的。 2. tcp 的连接是端到端的。
3. tcp 提供可靠的传输。 3. tcp 提供可靠的传输。
4. tcp 的传输以字节流的方式。 4. tcp 的传输以字节流的方式。
5. tcp 提供全双工的通信。 5. tcp 提供全双工的通信。
6. tcp 有拥塞控制。 6. tcp 有拥塞控制。
![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png) ![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png)
``` sh ``` sh
0 1 2 3 0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Source Port | Destination Port | | Source Port | Destination Port |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Sequence Number | | Sequence Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Acknowledgment Number | | Acknowledgment Number |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Data | |U|A|P|R|S|F| | | Data | |U|A|P|R|S|F| |
| Offset| Reserved |R|C|S|S|Y|I| Window | | Offset| Reserved |R|C|S|S|Y|I| Window |
| | |G|K|H|T|N|N| | | | |G|K|H|T|N|N| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Checksum | Urgent Pointer | | Checksum | Urgent Pointer |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Options | Padding | | Options | Padding |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| data | | data |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
``` ```
1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。 1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。
2. 序号 占 4 字节，序号范围是 [0, 2^32 - 1]，共 2^32（即 4294967296）个序号。序号增加到 2^32-1 后，下一个序号就又回到 0。TCP 是面向字节流的，在一个 TCP 连接中传送的字节流中的每一个字节都按顺序编号。整个要传送的字节流的起始序号（ISN）必须在连接建立时设置。首部中的序号字段值指的是本报文段所发送的数据的第一个字节的序号。例如，一报文段的序号是 301，而携带的数据共有 100 字节。这就表明：本报文段的数据的第一个字节的序号是 301，最后一个字节的序号是 400。显然，下一个报文段（如果还有的话）的数据序号应当从 401 开始，即下一个报文段的序号字段值应为 401。
3. 确认号 占 4 字节是期望收到对方下一个报文段的第一个数据字节的序号。例如B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501而数据长度是 200 字节(序号 501~700这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此B 期望收到 A 的下一个数据序号是 701于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意现在确认号不是 501也不是 700而是 701。 总之:若确认号为 N则表明到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。 3. 确认号 占 4 字节是期望收到对方下一个报文段的第一个数据字节的序号。例如B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501而数据长度是 200 字节(序号 501~700这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此B 期望收到 A 的下一个数据序号是 701于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意现在确认号不是 501也不是 700而是 701。 总之:若确认号为 N则表明到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。
4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15因此数据偏移的最大值是 60 字节。 4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15因此数据偏移的最大值是 60 字节。
5. 保留 占 6 位,保留为今后使用,但目前应置为 0。 5. 保留 占 6 位,保留为今后使用,但目前应置为 0。
6. 控制报文标志 6. 控制报文标志
- **紧急URGURGent** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针Urgent Pointer字段配合使用。 - **紧急URGURGent** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针Urgent Pointer字段配合使用。
- **确认ACKACKnowledgment** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。 - **确认ACKACKnowledgment** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。
- **推送 PSHPuSH** 当两个应用进程进行交互式的通信时有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下TCP 就可以使用推送push操作。这时发送方 TCP 把 PSH 置为 1并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。 - **推送 PSHPuSH** 当两个应用进程进行交互式的通信时有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下TCP 就可以使用推送push操作。这时发送方 TCP 把 PSH 置为 1并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。
- **复位 RST（ReSeT）** 当 RST=1 时，表明 TCP 连接中出现了严重错误（如由于主机崩溃或其他原因），必须释放连接，然后再重新建立传输连接。RST 置为 1 还用来拒绝一个非法的报文段或拒绝打开一个连接。
- **同步SYNSYNchronization** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。 - **同步SYNSYNchronization** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。
- **终止 FIN（FINis，意思是“完”“终”）** 用来释放一个连接。当 FIN=1 时，表明此报文段的发送方的数据已发送完毕，并要求释放运输连接。
7. 窗口 占 2 字节,窗口值是[02^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。 7. 窗口 占 2 字节,窗口值是[02^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。
8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6TCP 的协议号是 6把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6则相应的伪首部也要改变。 8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6TCP 的协议号是 6把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6则相应的伪首部也要改变。
9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。 9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。
10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。 10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。
## tcp选项 ## tcp选项
TCP 最初只规定了一种选项，即最大报文段长度 MSS（Maximum Segment Size）。后来又增加了几个选项，如窗口扩大选项、时间戳选项等，下面说明常用的选项。
1. kind=0 是选项表结束选项。 1. kind=0 是选项表结束选项。
2. kind=1 是空操作nop选项 2. kind=1 是空操作nop选项
没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。 没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。
3. kind=2 是最大报文段长度选项 3. kind=2 是最大报文段长度选项
TCP 连接初始化时，通信双方使用该选项来协商最大报文段长度（Max Segment Size，MSS）。TCP 模块通常将 MSS 设置为（MTU-40）字节（减掉的这 40 字节包括 20 字节的 TCP 头部和 20 字节的 IP 头部）。这样携带 TCP 报文段的 IP 数据报的长度就不会超过 MTU（假设 TCP 头部和 IP 头部都不包含选项字段，并且这也是一般情况），从而避免本机发生 IP 分片。对以太网而言，MSS 值是 1460（1500-40）字节。
4. kind=3 是窗口扩大因子选项 4. kind=3 是窗口扩大因子选项
TCP 连接初始化时，通信双方使用该选项来协商接收通告窗口的扩大因子。在 TCP 的头部中，接收通告窗口大小是用 16 位表示的，故最大为 65535 字节，但实际上 TCP 模块允许的接收通告窗口大小远不止这个数（为了提高 TCP 通信的吞吐量）。窗口扩大因子解决了这个问题。假设 TCP 头部中的接收通告窗口大小是 N，窗口扩大因子（移位数）是 M，那么 TCP 报文段的实际接收通告窗口大小是 N 乘 2^M，或者说 N 左移 M 位。注意，M 的取值范围是 0～14。
和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。 和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。
5. kind=4 是选择性确认Selective AcknowledgmentSACK选项 5. kind=4 是选择性确认Selective AcknowledgmentSACK选项
TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。 TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。
6. kind=5 是 SACK 实际工作的选项 6. kind=5 是 SACK 实际工作的选项
该选项的参数告诉发送方本端已经收到并缓存的不连续的数据块，从而让发送端可以据此检查并重发丢失的数据块。每个块边沿（edge of block）参数包含一个 4 字节的序号。其中块左边沿表示不连续块的第一个数据的序号，而块右边沿则表示不连续块的最后一个数据的序号的下一个序号。这样一对参数（块左边沿和块右边沿）之间的数据是已经收到的，相邻两个块之间的数据才是尚未收到的。因为一个块信息占用 8 字节，所以 TCP 头部选项中实际上最多可以包含 4 个这样的不连续数据块（考虑选项类型和长度占用的 2 字节）。
7. kind=8 是时间戳选项 7. kind=8 是时间戳选项
该选项提供了较为准确的计算通信双方之间的回路时间Round Trip TimeRTT的方法从而为 TCP 流量控制提供重要信息。 该选项提供了较为准确的计算通信双方之间的回路时间Round Trip TimeRTT的方法从而为 TCP 流量控制提供重要信息。

View File

@@ -1,330 +1,330 @@
package tcp package tcp
import ( import (
"crypto/rand" "crypto/rand"
"crypto/sha1" "crypto/sha1"
"encoding/binary" "encoding/binary"
"hash" "hash"
"io" "io"
"log" "log"
"netstack/sleep" "netstack/sleep"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/seqnum" "netstack/tcpip/seqnum"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"sync" "sync"
"time" "time"
) )
// Layout of the SYN cookie: the timestamp occupies the top tsLen bits and the
// hash the low tsOffset bits.
const (
	// tsLen is the width of the cookie timestamp, in bits.
	tsLen = 8

	// tsMask keeps the low tsLen bits of a timestamp value.
	tsMask = 1<<tsLen - 1

	// tsOffset is the bit position at which the timestamp starts inside the
	// SYN cookie.
	tsOffset = 24

	// hashMask keeps the low tsOffset bits of a hash value.
	hashMask = 1<<tsOffset - 1

	// maxTSDiff is the largest difference tolerated between the timestamp
	// carried by a received cookie and the current timestamp; a cookie with
	// a larger difference is treated as expired.
	maxTSDiff = 2
)
var (
	// SynRcvdCountThreshold caps, process-wide, how many connections may sit
	// in the SYN-RCVD state before TCP switches to SYN cookies for accepting
	// connections.
	//
	// It is exported for tests only; importers of this package should not
	// otherwise touch it.
	SynRcvdCountThreshold uint64 = 1000

	// mssTable lists the MSS values that can be represented by the two bits
	// reserved for the MSS in a SYN cookie.
	mssTable = []uint16{536, 1300, 1440, 1460}
)

// encodeMSS returns the highest index i >= 1 of mssTable whose entry does not
// exceed mss, or 0 when there is no such entry.
func encodeMSS(mss uint16) uint32 {
	var code uint32
	// Walk upwards and remember the last matching index; index 0 is the
	// fallback and is therefore never compared.
	for i := 1; i < len(mssTable); i++ {
		if mssTable[i] <= mss {
			code = uint32(i)
		}
	}
	return code
}
// synRcvdCount tracks how many endpoints are currently in the SYN-RCVD state.
// The counter is guarded by the embedded mutex so that it is only incremented
// when doing so cannot push it above the threshold.
var synRcvdCount struct {
	sync.Mutex

	// value is the current number of endpoints in SYN-RCVD.
	value uint64

	// pending has one outstanding unit per counted endpoint.
	pending sync.WaitGroup
}
// listenContext is used by a listening endpoint to store state used while // listenContext is used by a listening endpoint to store state used while
// listening for connections. This struct is allocated by the listen goroutine // listening for connections. This struct is allocated by the listen goroutine
// and must not be accessed or have its methods called concurrently as they // and must not be accessed or have its methods called concurrently as they
// may mutate the stored objects. // may mutate the stored objects.
type listenContext struct { type listenContext struct {
stack *stack.Stack stack *stack.Stack
rcvWnd seqnum.Size rcvWnd seqnum.Size
nonce [2][sha1.BlockSize]byte // nonce 随机数 nonce [2][sha1.BlockSize]byte // nonce 随机数
hasherMu sync.Mutex hasherMu sync.Mutex
hasher hash.Hash // 散列实现 hasher hash.Hash // 散列实现
v6only bool v6only bool
netProto tcpip.NetworkProtocolNumber netProto tcpip.NetworkProtocolNumber
} }
// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. // timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
func timeStamp() uint32 { func timeStamp() uint32 {
return uint32(time.Now().Unix()>>6) & tsMask // 00 00 00 FF return uint32(time.Now().Unix()>>6) & tsMask // 00 00 00 FF
} }
// 增加一个任务 最多1000个 // 增加一个任务 最多1000个
func incSynRcvdCount() bool { func incSynRcvdCount() bool {
synRcvdCount.Mutex.Lock() synRcvdCount.Mutex.Lock()
defer synRcvdCount.Unlock() defer synRcvdCount.Unlock()
if synRcvdCount.value >= SynRcvdCountThreshold { if synRcvdCount.value >= SynRcvdCountThreshold {
return false return false
} }
synRcvdCount.pending.Add(1) synRcvdCount.pending.Add(1)
synRcvdCount.value++ synRcvdCount.value++
return true return true
} }
// 结束一个任务 // 结束一个任务
func decSynRcvdCount() { func decSynRcvdCount() {
synRcvdCount.Mutex.Lock() synRcvdCount.Mutex.Lock()
defer synRcvdCount.Unlock() defer synRcvdCount.Unlock()
synRcvdCount.value-- synRcvdCount.value--
synRcvdCount.pending.Done() synRcvdCount.pending.Done()
} }
// newListenContext creates a new listen context. // newListenContext creates a new listen context.
func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
l := &listenContext{ l := &listenContext{
stack: stack, stack: stack,
rcvWnd: rcvWnd, rcvWnd: rcvWnd,
hasher: sha1.New(), hasher: sha1.New(),
v6only: v6only, v6only: v6only,
netProto: netProto, netProto: netProto,
} }
rand.Read(l.nonce[0][:]) rand.Read(l.nonce[0][:])
rand.Read(l.nonce[1][:]) rand.Read(l.nonce[1][:])
return l return l
} }
// cookieHash calculates the cookieHash for the given id, timestamp and nonce // cookieHash calculates the cookieHash for the given id, timestamp and nonce
// index. The hash is used to create and validate cookies. // index. The hash is used to create and validate cookies.
func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
// Initialize block with fixed-size data: local ports and v. // Initialize block with fixed-size data: local ports and v.
var payload [8]byte var payload [8]byte
binary.BigEndian.PutUint16(payload[0:], id.LocalPort) binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
binary.BigEndian.PutUint16(payload[2:], id.RemotePort) binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
binary.BigEndian.PutUint32(payload[4:], ts) binary.BigEndian.PutUint32(payload[4:], ts)
// Feed everything to the hasher. // Feed everything to the hasher.
l.hasherMu.Lock() l.hasherMu.Lock()
l.hasher.Reset() l.hasher.Reset()
l.hasher.Write(payload[:]) l.hasher.Write(payload[:])
l.hasher.Write(l.nonce[nonceIndex][:]) l.hasher.Write(l.nonce[nonceIndex][:])
io.WriteString(l.hasher, string(id.LocalAddress)) io.WriteString(l.hasher, string(id.LocalAddress))
io.WriteString(l.hasher, string(id.RemoteAddress)) io.WriteString(l.hasher, string(id.RemoteAddress))
// Finalize the calculation of the hash and return the first 4 bytes. // Finalize the calculation of the hash and return the first 4 bytes.
h := make([]byte, 0, sha1.Size) h := make([]byte, 0, sha1.Size)
h = l.hasher.Sum(h) h = l.hasher.Sum(h)
l.hasherMu.Unlock() l.hasherMu.Unlock()
return binary.BigEndian.Uint32(h[:]) return binary.BigEndian.Uint32(h[:])
} }
// createCookie creates a SYN cookie for the given id and incoming sequence // createCookie creates a SYN cookie for the given id and incoming sequence
// number. // number.
func (l *listenContext) createCookie(id stack.TransportEndpointID, func (l *listenContext) createCookie(id stack.TransportEndpointID,
seq seqnum.Value, data uint32) seqnum.Value { seq seqnum.Value, data uint32) seqnum.Value {
ts := timeStamp() ts := timeStamp()
v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
v += (l.cookieHash(id, ts, 1) + data) & hashMask v += (l.cookieHash(id, ts, 1) + data) & hashMask
return seqnum.Value(v) return seqnum.Value(v)
} }
// isCookieValid checks if the supplied cookie is valid for the given id and // isCookieValid checks if the supplied cookie is valid for the given id and
// sequence number. If it is, it also returns the data originally encoded in the // sequence number. If it is, it also returns the data originally encoded in the
// cookie when createCookie was called. // cookie when createCookie was called.
func (l *listenContext) isCookieValid(id stack.TransportEndpointID, func (l *listenContext) isCookieValid(id stack.TransportEndpointID,
cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
ts := timeStamp() ts := timeStamp()
v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
cookieTS := v >> tsOffset cookieTS := v >> tsOffset
if ((ts - cookieTS) & tsMask) > maxTSDiff { if ((ts - cookieTS) & tsMask) > maxTSDiff {
return 0, false return 0, false
} }
return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
} }
// 新建一个tcp端 这个tcp端与segment同属一个tcp连接 但属于不同阶段 用于写回远端 // 新建一个tcp端 这个tcp端与segment同属一个tcp连接 但属于不同阶段 用于写回远端
func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value,
irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// Create a new endpoint. // Create a new endpoint.
netProto := l.netProto netProto := l.netProto
if netProto == 0 { if netProto == 0 {
netProto = s.route.NetProto netProto = s.route.NetProto
} }
n := newEndpoint(l.stack, netProto, nil) n := newEndpoint(l.stack, netProto, nil)
n.v6only = l.v6only n.v6only = l.v6only
n.id = s.id n.id = s.id
n.boundNICID = s.route.NICID() n.boundNICID = s.route.NICID()
n.route = s.route.Clone() n.route = s.route.Clone()
n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
n.rcvBufSize = int(l.rcvWnd) n.rcvBufSize = int(l.rcvWnd)
n.maybeEnableTimestamp(rcvdSynOpts) n.maybeEnableTimestamp(rcvdSynOpts)
n.maybeEnableSACKPermitted(rcvdSynOpts) n.maybeEnableSACKPermitted(rcvdSynOpts)
// Register new endpoint so that packets are routed to it. // Register new endpoint so that packets are routed to it.
// 在网络协议栈中去注册这个tcp端 // 在网络协议栈中去注册这个tcp端
if err := n.stack.RegisterTransportEndpoint(n.boundNICID, if err := n.stack.RegisterTransportEndpoint(n.boundNICID,
n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
n.Close() n.Close()
return nil, err return nil, err
} }
n.isRegistered = true n.isRegistered = true
n.state = stateConnected n.state = stateConnected
// Create sender and receiver. // Create sender and receiver.
// The receiver at least temporarily has a zero receive window scale, // The receiver at least temporarily has a zero receive window scale,
// but the caller may change it (before starting the protocol loop). // but the caller may change it (before starting the protocol loop).
n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
n.rcv = newReceiver(n, irs, l.rcvWnd, 0) n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
return n, nil return n, nil
} }
func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
// create new endpoint // create new endpoint
irs := s.sequenceNumber irs := s.sequenceNumber
cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
log.Println("收到一个远端握手申请", irs, "标记cookie", cookie) log.Println("收到一个远端握手申请", irs, "标记cookie", cookie)
ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) ep, err := l.createConnectedEndpoint(s, cookie, irs, opts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 以下执行三次握手 // 以下执行三次握手
// 构建handshake管理器 // 构建handshake管理器
h, err := newHandshake(ep, l.rcvWnd) h, err := newHandshake(ep, l.rcvWnd)
if err != nil { if err != nil {
ep.Close() ep.Close()
return nil, err return nil, err
} }
// 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack // 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack
h.resetToSynRcvd(cookie, irs, opts) h.resetToSynRcvd(cookie, irs, opts)
if err := h.execute(); err != nil { if err := h.execute(); err != nil {
ep.Close() ep.Close()
return nil, err return nil, err
} }
// 更新接收窗口扩张因子 // 更新接收窗口扩张因子
return ep, nil return ep, nil
} }
// 一旦侦听端点收到SYN段handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。 // 一旦侦听端点收到SYN段handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。
// 在TCP开始使用SYN cookie接受连接之前允许使用有限数量的这些goroutine。 // 在TCP开始使用SYN cookie接受连接之前允许使用有限数量的这些goroutine。
func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
defer decSynRcvdCount() defer decSynRcvdCount()
defer s.decRef() defer s.decRef()
_, err := ctx.createEndpointAndPerformHandshake(s, opts) _, err := ctx.createEndpointAndPerformHandshake(s, opts)
if err != nil { if err != nil {
return return
} }
// 到这里,三次握手已经完成,那么分发一个新的连接 // 到这里,三次握手已经完成,那么分发一个新的连接
//e.deliverAccepted(n) //e.deliverAccepted(n)
} }
// handleListenSegment is called when a listening endpoint receives a segment // handleListenSegment is called when a listening endpoint receives a segment
// and needs to handle it. // and needs to handle it.
func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
switch s.flags { switch s.flags {
case flagSyn: // syn报文处理 case flagSyn: // syn报文处理
// 分析tcp选项 // 分析tcp选项
opts := parseSynSegmentOptions(s) opts := parseSynSegmentOptions(s)
if incSynRcvdCount() { if incSynRcvdCount() {
s.incRef() s.incRef()
go e.handleSynSegment(ctx, s, &opts) go e.handleSynSegment(ctx, s, &opts)
} else { } else {
log.Println("暂时不处理") log.Println("暂时不处理")
} }
// 返回一个syn+ack报文 // 返回一个syn+ack报文
case flagFin: // fin报文处理 case flagFin: // fin报文处理
// 三次握手最后一次 ack 报文 // 三次握手最后一次 ack 报文
} }
} }
func parseSynSegmentOptions(s *segment) header.TCPSynOptions { func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck)) synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck))
if synOpts.TS { if synOpts.TS {
s.parsedOptions.TSVal = synOpts.TSVal s.parsedOptions.TSVal = synOpts.TSVal
s.parsedOptions.TSEcr = synOpts.TSEcr s.parsedOptions.TSEcr = synOpts.TSEcr
} }
return synOpts return synOpts
} }
// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行负责处理连接请求 // protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行负责处理连接请求
func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
defer func() { defer func() {
// TODO 后置处理 // TODO 后置处理
}() }()
e.mu.Lock() e.mu.Lock()
v6only := e.v6only v6only := e.v6only
e.mu.Unlock() e.mu.Unlock()
ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto)
// 初始化事件触发器 并添加事件 // 初始化事件触发器 并添加事件
s := sleep.Sleeper{} s := sleep.Sleeper{}
s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
s.AddWaker(&e.notificationWaker, wakerForNotification) s.AddWaker(&e.notificationWaker, wakerForNotification)
for { for {
switch index, _ := s.Fetch(true); index { // Fetch(true) 阻塞获取 switch index, _ := s.Fetch(true); index { // Fetch(true) 阻塞获取
case wakerForNewSegment: case wakerForNewSegment:
mayRequeue := true mayRequeue := true
// 接收和处理tcp报文 // 接收和处理tcp报文
for i := 0; i < maxSegmentsPerWake; i++ { for i := 0; i < maxSegmentsPerWake; i++ {
s := e.segmentQueue.dequeue() s := e.segmentQueue.dequeue()
if s == nil { if s == nil {
mayRequeue = false mayRequeue = false
break break
} }
e.handleListenSegment(ctx, s) e.handleListenSegment(ctx, s)
s.decRef() s.decRef()
} }
// If the queue is not empty, make sure we'll wake up // If the queue is not empty, make sure we'll wake up
// in the next iteration. // in the next iteration.
if mayRequeue && !e.segmentQueue.empty() { // 主协程又添加了新数据 if mayRequeue && !e.segmentQueue.empty() { // 主协程又添加了新数据
e.newSegmentWaker.Assert() // 重新尝试获取数据 e.newSegmentWaker.Assert() // 重新尝试获取数据
} }
case wakerForNotification: case wakerForNotification:
// TODO 触发其他事件 // TODO 触发其他事件
log.Println("其他事件?") log.Println("其他事件?")
} }
} }
} }

View File

@@ -1,364 +1,364 @@
package tcp package tcp
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"log" "log"
"netstack/sleep" "netstack/sleep"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/buffer" "netstack/tcpip/buffer"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/seqnum" "netstack/tcpip/seqnum"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"sync" "sync"
"time" "time"
) )
// maxSegmentsPerWake bounds how many queued segments are handled per wake-up
// of a protocol loop before the loop goes back to waiting.
const maxSegmentsPerWake = 100
// handshakeState identifies the current stage of a TCP three-way handshake.
type handshakeState int

// Stages of the three-way handshake, in declaration order.
const (
	// handshakeSynSent is the state set by handshake.resetState.
	handshakeSynSent handshakeState = iota
	// handshakeSynRcvd is the state set by handshake.resetToSynRcvd.
	handshakeSynRcvd
	// handshakeCompleted ends the loop in handshake.execute.
	handshakeCompleted
)
// Identifiers passed to Sleeper.AddWaker; Sleeper.Fetch returns one of them to
// tell which waker fired.
const (
	// wakerForNotification is used with an endpoint's notificationWaker.
	wakerForNotification = iota
	// wakerForNewSegment is used with an endpoint's newSegmentWaker.
	wakerForNewSegment
	// wakerForResend is used with the handshake retransmission waker.
	wakerForResend
	// wakerForResolution is used with the route resolution waker.
	wakerForResolution
)
// handshake holds the state used during a TCP 3-way handshake. // handshake holds the state used during a TCP 3-way handshake.
// tcp三次握手时候使用的对象 // tcp三次握手时候使用的对象
type handshake struct { type handshake struct {
ep *endpoint ep *endpoint
// 握手的状态 // 握手的状态
state handshakeState state handshakeState
active bool active bool
flags uint8 flags uint8
ackNum seqnum.Value ackNum seqnum.Value
// iss is the initial send sequence number, as defined in RFC 793. // iss is the initial send sequence number, as defined in RFC 793.
// 初始序列号 // 初始序列号
iss seqnum.Value iss seqnum.Value
// rcvWnd is the receive window, as defined in RFC 793. // rcvWnd is the receive window, as defined in RFC 793.
// 接收窗口 // 接收窗口
rcvWnd seqnum.Size rcvWnd seqnum.Size
// sndWnd is the send window, as defined in RFC 793. // sndWnd is the send window, as defined in RFC 793.
// 发送窗口 // 发送窗口
sndWnd seqnum.Size sndWnd seqnum.Size
// mss is the maximum segment size received from the peer. // mss is the maximum segment size received from the peer.
// 最大报文段大小 // 最大报文段大小
mss uint16 mss uint16
// sndWndScale is the send window scale, as defined in RFC 1323. A // sndWndScale is the send window scale, as defined in RFC 1323. A
// negative value means no scaling is supported by the peer. // negative value means no scaling is supported by the peer.
// 发送窗口扩展因子 // 发送窗口扩展因子
sndWndScale int sndWndScale int
// rcvWndScale is the receive window scale, as defined in RFC 1323. // rcvWndScale is the receive window scale, as defined in RFC 1323.
// 接收窗口扩展因子 // 接收窗口扩展因子
rcvWndScale int rcvWndScale int
} }
// maxOptionSize is the maximum space available for TCP options, in bytes. It
// is also the size of the buffers handed out by optionPool.
const maxOptionSize = 40
func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) {
h := handshake{ h := handshake{
ep: ep, ep: ep,
active: true, // 激活这个管理器 active: true, // 激活这个管理器
rcvWnd: rcvWnd, // 初始接收窗口 rcvWnd: rcvWnd, // 初始接收窗口
// TODO // TODO
} }
if err := h.resetState(); err != nil { if err := h.resetState(); err != nil {
return handshake{}, err return handshake{}, err
} }
return h, nil return h, nil
} }
func (h *handshake) resetState() *tcpip.Error { func (h *handshake) resetState() *tcpip.Error {
// 随机一个iss(对方将收到的序号) 防止黑客搞事 // 随机一个iss(对方将收到的序号) 防止黑客搞事
b := make([]byte, 4) b := make([]byte, 4)
if _, err := rand.Read(b); err != nil { if _, err := rand.Read(b); err != nil {
panic(err) panic(err)
} }
// 初始化状态为 SynSent // 初始化状态为 SynSent
h.state = handshakeSynSent h.state = handshakeSynSent
log.Println("收到 syn 同步报文 设置tcp状态为 [sent]") log.Println("收到 syn 同步报文 设置tcp状态为 [sent]")
h.flags = flagSyn h.flags = flagSyn
h.ackNum = 0 h.ackNum = 0
h.mss = 0 h.mss = 0
h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
return nil return nil
} }
// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD // resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
// state. // state.
func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
h.active = false h.active = false
h.state = handshakeSynRcvd h.state = handshakeSynRcvd
log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]") log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]")
h.flags = flagSyn | flagAck h.flags = flagSyn | flagAck
h.iss = iss h.iss = iss
h.ackNum = irs + 1 // NOTE ACK = synNum + 1 h.ackNum = irs + 1 // NOTE ACK = synNum + 1
h.mss = opts.MSS h.mss = opts.MSS
h.sndWndScale = opts.WS h.sndWndScale = opts.WS
} }
func (h *handshake) resolveRoute() *tcpip.Error { func (h *handshake) resolveRoute() *tcpip.Error {
log.Printf("tcp resolveRoute") log.Printf("tcp resolveRoute")
// Set up the wakers. // Set up the wakers.
s := sleep.Sleeper{} s := sleep.Sleeper{}
resolutionWaker := &sleep.Waker{} resolutionWaker := &sleep.Waker{}
s.AddWaker(resolutionWaker, wakerForResolution) s.AddWaker(resolutionWaker, wakerForResolution)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification) s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
defer s.Done() defer s.Done()
// Initial action is to resolve route. // Initial action is to resolve route.
index := wakerForResolution index := wakerForResolution
for { for {
log.Println(index) log.Println(index)
switch index { switch index {
case wakerForResolution: case wakerForResolution:
if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
// Either success (err == nil) or failure. // Either success (err == nil) or failure.
return err return err
} }
// Resolution not completed. Keep trying... // Resolution not completed. Keep trying...
case wakerForNotification: case wakerForNotification:
// TODO // TODO
//n := h.ep.fetchNotifications() //n := h.ep.fetchNotifications()
//if n&notifyClose != 0 { //if n&notifyClose != 0 {
// h.ep.route.RemoveWaker(resolutionWaker) // h.ep.route.RemoveWaker(resolutionWaker)
// return tcpip.ErrAborted // return tcpip.ErrAborted
//} //}
//if n&notifyDrain != 0 { //if n&notifyDrain != 0 {
// close(h.ep.drainDone) // close(h.ep.drainDone)
// <-h.ep.undrain // <-h.ep.undrain
//} //}
} }
// Wait for notification. // Wait for notification.
index, _ = s.Fetch(true) index, _ = s.Fetch(true)
} }
} }
// execute executes the TCP 3-way handshake. // execute executes the TCP 3-way handshake.
// 执行tcp 3次握手客户端和服务端都是调用该函数来实现三次握手 // 执行tcp 3次握手客户端和服务端都是调用该函数来实现三次握手
/* /*
c flag s c flag s
| | | |
sync_sent|------sync---->|sync_rcvd sync_sent|------sync---->|sync_rcvd
| | | |
| | | |
established|<--sync|ack----| established|<--sync|ack----|
| | | |
| | | |
|------ack----->|established |------ack----->|established
*/ */
func (h *handshake) execute() *tcpip.Error { func (h *handshake) execute() *tcpip.Error {
// 是否需要拿到下一条地址 // 是否需要拿到下一条地址
if h.ep.route.IsResolutionRequired() { if h.ep.route.IsResolutionRequired() {
if err := h.resolveRoute(); err != nil { if err := h.resolveRoute(); err != nil {
return err return err
} }
} }
// Initialize the resend timer. // Initialize the resend timer.
// 初始化重传定时器 // 初始化重传定时器
resendWaker := sleep.Waker{} resendWaker := sleep.Waker{}
// 设置1s超时 // 设置1s超时
timeOut := time.Duration(time.Second) timeOut := time.Duration(time.Second)
rt := time.AfterFunc(timeOut, func() { rt := time.AfterFunc(timeOut, func() {
resendWaker.Assert() resendWaker.Assert()
}) })
defer rt.Stop() defer rt.Stop()
// Set up the wakers. // Set up the wakers.
s := sleep.Sleeper{} s := sleep.Sleeper{}
s.AddWaker(&resendWaker, wakerForResend) s.AddWaker(&resendWaker, wakerForResend)
s.AddWaker(&h.ep.notificationWaker, wakerForNotification) s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
defer s.Done() defer s.Done()
// sync报文的选项参数 // sync报文的选项参数
synOpts := header.TCPSynOptions{} synOpts := header.TCPSynOptions{}
// 如果是客户端发送 syn 报文,如果是服务端发送 syn+ack 报文 // 如果是客户端发送 syn 报文,如果是服务端发送 syn+ack 报文
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
for h.state != handshakeCompleted { for h.state != handshakeCompleted {
// 获取事件id // 获取事件id
switch index, _ := s.Fetch(true); index { switch index, _ := s.Fetch(true); index {
case wakerForResend: // NOTE tcp超时重传机制 case wakerForResend: // NOTE tcp超时重传机制
// 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传 // 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传
// 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传 // 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传
// 超时时间变为上次的2倍 // 超时时间变为上次的2倍
timeOut *= 2 timeOut *= 2
if timeOut > 60*time.Second { if timeOut > 60*time.Second {
return tcpip.ErrTimeout return tcpip.ErrTimeout
} }
rt.Reset(timeOut) rt.Reset(timeOut)
// 重新发送syn报文 // 重新发送syn报文
sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
case wakerForNotification: case wakerForNotification:
case wakerForNewSegment: case wakerForNewSegment:
// 处理握手报文 // 处理握手报文
} }
} }
return nil return nil
} }
var optionPool = sync.Pool{ var optionPool = sync.Pool{
New: func() interface{} { New: func() interface{} {
return make([]byte, maxOptionSize) return make([]byte, maxOptionSize)
}, },
} }
// 减少资源浪费 // 减少资源浪费
func getOptions() []byte { func getOptions() []byte {
return optionPool.Get().([]byte) return optionPool.Get().([]byte)
} }
func putOptions(options []byte) { func putOptions(options []byte) {
// Reslice to full capacity. // Reslice to full capacity.
optionPool.Put(options[0:cap(options)]) optionPool.Put(options[0:cap(options)])
} }
// tcp选项的编码 将一个TCPSyncOptions编码到 []byte 中 // tcp选项的编码 将一个TCPSyncOptions编码到 []byte 中
func makeSynOptions(opts header.TCPSynOptions) []byte { func makeSynOptions(opts header.TCPSynOptions) []byte {
// Emulate linux option order. This is as follows: // Emulate linux option order. This is as follows:
// //
// if md5: NOP NOP MD5SIG 18 md5sig(16) // if md5: NOP NOP MD5SIG 18 md5sig(16)
// if mss: MSS 4 mss(2) // if mss: MSS 4 mss(2)
// if ts and sack_advertise: // if ts and sack_advertise:
// SACK 2 TIMESTAMP 2 timestamp(8) // SACK 2 TIMESTAMP 2 timestamp(8)
// elif ts: NOP NOP TIMESTAMP 10 timestamp(8) // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
// elif sack: NOP NOP SACK 2 // elif sack: NOP NOP SACK 2
// if wscale: NOP WINDOW 3 ws(1) // if wscale: NOP WINDOW 3 ws(1)
// if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
// [for each block] start_seq(4) end_seq(4) // [for each block] start_seq(4) end_seq(4)
// if fastopen_cookie: // if fastopen_cookie:
// if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
// else: FASTOPEN (2 + len(cookie)) // else: FASTOPEN (2 + len(cookie))
// cookie(variable) [padding to four bytes] // cookie(variable) [padding to four bytes]
// //
options := getOptions() options := getOptions()
// Always encode the mss. // Always encode the mss.
offset := header.EncodeMSSOption(uint32(opts.MSS), options) offset := header.EncodeMSSOption(uint32(opts.MSS), options)
// Special ordering is required here. If both TS and SACK are enabled, // Special ordering is required here. If both TS and SACK are enabled,
// then the SACK option precedes TS, with no padding. If they are // then the SACK option precedes TS, with no padding. If they are
// enabled individually, then we see padding before the option. // enabled individually, then we see padding before the option.
if opts.TS && opts.SACKPermitted { if opts.TS && opts.SACKPermitted {
offset += header.EncodeSACKPermittedOption(options[offset:]) offset += header.EncodeSACKPermittedOption(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.TS { } else if opts.TS {
offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:])
offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
} else if opts.SACKPermitted { } else if opts.SACKPermitted {
offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:])
offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:])
offset += header.EncodeSACKPermittedOption(options[offset:]) offset += header.EncodeSACKPermittedOption(options[offset:])
} }
// Initialize the WS option. // Initialize the WS option.
if opts.WS >= 0 { if opts.WS >= 0 {
offset += header.EncodeNOP(options[offset:]) offset += header.EncodeNOP(options[offset:])
offset += header.EncodeWSOption(opts.WS, options[offset:]) offset += header.EncodeWSOption(opts.WS, options[offset:])
} }
// Padding to the end; note that this never apply unless we add a // Padding to the end; note that this never apply unless we add a
// fastopen option, we always expect the offset to remain the same. // fastopen option, we always expect the offset to remain the same.
if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
panic("unexpected option encoding") panic("unexpected option encoding")
} }
return options[:offset] return options[:offset]
} }
// 封装 sendTCP ,发送 syn 报文 // 封装 sendTCP ,发送 syn 报文
func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte,
seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
if opts.MSS == 0 { if opts.MSS == 0 {
opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
} }
options := makeSynOptions(opts) options := makeSynOptions(opts)
err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options)
return err return err
} }
// sendTCP sends a TCP segment with the provided options via the provided // sendTCP sends a TCP segment with the provided options via the provided
// network endpoint and under the provided identity. // network endpoint and under the provided identity.
// 发送一个tcp段数据封装 tcp 首部,并写入网路层 // 发送一个tcp段数据封装 tcp 首部,并写入网路层
func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte,
seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
log.Println("进行一个报文的发送") log.Println("进行一个报文的发送")
optLen := len(opts) optLen := len(opts)
// Allocate a buffer for the TCP header. // Allocate a buffer for the TCP header.
hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
if rcvWnd > 0xffff { if rcvWnd > 0xffff {
rcvWnd = 0xffff rcvWnd = 0xffff
} }
// Initialize the header. // Initialize the header.
tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
tcp.Encode(&header.TCPFields{ tcp.Encode(&header.TCPFields{
SrcPort: id.LocalPort, SrcPort: id.LocalPort,
DstPort: id.RemotePort, DstPort: id.RemotePort,
SeqNum: uint32(seq), SeqNum: uint32(seq),
AckNum: uint32(ack), AckNum: uint32(ack),
DataOffset: uint8(header.TCPMinimumSize + optLen), DataOffset: uint8(header.TCPMinimumSize + optLen),
Flags: flags, Flags: flags,
WindowSize: uint16(rcvWnd), WindowSize: uint16(rcvWnd),
}) })
copy(tcp[header.TCPMinimumSize:], opts) copy(tcp[header.TCPMinimumSize:], opts)
// Only calculate the checksum if offloading isn't supported. // Only calculate the checksum if offloading isn't supported.
if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
length := uint16(hdr.UsedLength() + data.Size()) length := uint16(hdr.UsedLength() + data.Size())
// tcp伪首部校验和的计算 // tcp伪首部校验和的计算
xsum := r.PseudoHeaderChecksum(ProtocolNumber) xsum := r.PseudoHeaderChecksum(ProtocolNumber)
for _, v := range data.Views() { for _, v := range data.Views() {
xsum = header.Checksum(v, xsum) xsum = header.Checksum(v, xsum)
} }
// tcp的可靠性校验和的计算用于检测损伤的报文段 // tcp的可靠性校验和的计算用于检测损伤的报文段
tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
} }
r.Stats().TCP.SegmentsSent.Increment() r.Stats().TCP.SegmentsSent.Increment()
if (flags & flagRst) != 0 { if (flags & flagRst) != 0 {
r.Stats().TCP.ResetsSent.Increment() r.Stats().TCP.ResetsSent.Increment()
} }
log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d", log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d",
flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort),
seq, ack, rcvWnd) seq, ack, rcvWnd)
return r.WritePacket(hdr, data, ProtocolNumber, ttl) return r.WritePacket(hdr, data, ProtocolNumber, ttl)
} }
// protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行负责握手、发送段和处理收到的段 // protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行负责握手、发送段和处理收到的段
func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
for { for {
log.Println("三次握手机制在这里实现") log.Println("三次握手机制在这里实现")
select {} select {}
} }
} }

View File

@@ -1,403 +1,403 @@
package tcp package tcp
import ( import (
"fmt" "fmt"
"log" "log"
"netstack/sleep" "netstack/sleep"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/buffer" "netstack/tcpip/buffer"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/seqnum" "netstack/tcpip/seqnum"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"netstack/waiter" "netstack/waiter"
"sync" "sync"
) )
// endpointState is the current state of a TCP endpoint's state machine.
type endpointState int
// tcp 状态机的各种状态 // tcp 状态机的各种状态
const ( const (
stateInitial endpointState = iota stateInitial endpointState = iota
stateBound stateBound
stateListen stateListen
stateConnecting stateConnecting
stateConnected stateConnected
stateClosed stateClosed
stateError stateError
) )
// endpoint 表示TCP端点。该结构用作端点用户和协议实现之间的接口;让并发goroutine调用端点是合法的 // endpoint 表示TCP端点。该结构用作端点用户和协议实现之间的接口;让并发goroutine调用端点是合法的
// 它们是正确同步的。然而协议实现在单个goroutine中运行。 // 它们是正确同步的。然而协议实现在单个goroutine中运行。
type endpoint struct { type endpoint struct {
stack *stack.Stack // 网络协议栈 stack *stack.Stack // 网络协议栈
netProto tcpip.NetworkProtocolNumber // 网络协议号 ipv4 ipv6 netProto tcpip.NetworkProtocolNumber // 网络协议号 ipv4 ipv6
waiterQueue *waiter.Queue // 事件驱动机制 waiterQueue *waiter.Queue // 事件驱动机制
// TODO 需要添加 // TODO 需要添加
// rcvListMu can be taken after the endpoint mu below. // rcvListMu can be taken after the endpoint mu below.
rcvListMu sync.Mutex rcvListMu sync.Mutex
rcvList segmentList rcvList segmentList
rcvClosed bool rcvClosed bool
rcvBufSize int rcvBufSize int
rcvBufUsed int rcvBufUsed int
// The following fields are protected by the mutex. // The following fields are protected by the mutex.
mu sync.RWMutex mu sync.RWMutex
id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID
state endpointState // 目前tcp状态机的状态 state endpointState // 目前tcp状态机的状态
isPortReserved bool // 是否已经分配端口 isPortReserved bool // 是否已经分配端口
isRegistered bool // 是否已经注册在网络协议栈 isRegistered bool // 是否已经注册在网络协议栈
boundNICID tcpip.NICID boundNICID tcpip.NICID
route stack.Route // tcp端在网络协议栈中的路由地址 route stack.Route // tcp端在网络协议栈中的路由地址
v6only bool // 是否仅仅支持ipv6 v6only bool // 是否仅仅支持ipv6
isConnectNotified bool isConnectNotified bool
// effectiveNetProtos contains the network protocols actually in use. In // effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6 // most cases it will only contain "netProto", but in cases like IPv6
// endpoints with v6only set to false, this could include multiple // endpoints with v6only set to false, this could include multiple
// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
// address). // address).
effectiveNetProtos []tcpip.NetworkProtocolNumber effectiveNetProtos []tcpip.NetworkProtocolNumber
// workerRunning specifies if a worker goroutine is running. // workerRunning specifies if a worker goroutine is running.
workerRunning bool workerRunning bool
// sendTSOk is used to indicate when the TS Option has been negotiated. // sendTSOk is used to indicate when the TS Option has been negotiated.
// When sendTSOk is true every non-RST segment should carry a TS as per // When sendTSOk is true every non-RST segment should carry a TS as per
// RFC7323#section-1.1 // RFC7323#section-1.1
sendTSOk bool sendTSOk bool
// recentTS is the timestamp that should be sent in the TSEcr field of // recentTS is the timestamp that should be sent in the TSEcr field of
// the timestamp for future segments sent by the endpoint. This field is // the timestamp for future segments sent by the endpoint. This field is
// updated if required when a new segment is received by this endpoint. // updated if required when a new segment is received by this endpoint.
recentTS uint32 recentTS uint32
// sackPermitted is set to true if the peer sends the TCPSACKPermitted // sackPermitted is set to true if the peer sends the TCPSACKPermitted
// option in the SYN/SYN-ACK. // option in the SYN/SYN-ACK.
sackPermitted bool sackPermitted bool
segmentQueue segmentQueue segmentQueue segmentQueue
// When the send side is closed, the protocol goroutine is notified via // When the send side is closed, the protocol goroutine is notified via
// sndCloseWaker, and sndClosed is set to true. // sndCloseWaker, and sndClosed is set to true.
sndBufMu sync.Mutex sndBufMu sync.Mutex
sndBufSize int sndBufSize int
sndBufUsed int sndBufUsed int
sndClosed bool sndClosed bool
sndBufInQueue seqnum.Size sndBufInQueue seqnum.Size
sndQueue segmentList sndQueue segmentList
sndWaker sleep.Waker sndWaker sleep.Waker
sndCloseWaker sleep.Waker sndCloseWaker sleep.Waker
// notificationWaker is used to indicate to the protocol goroutine that // notificationWaker is used to indicate to the protocol goroutine that
// it needs to wake up and check for notifications. // it needs to wake up and check for notifications.
notificationWaker sleep.Waker notificationWaker sleep.Waker
// newSegmentWaker is used to indicate to the protocol goroutine that // newSegmentWaker is used to indicate to the protocol goroutine that
// it needs to wake up and handle new segments queued to it. // it needs to wake up and handle new segments queued to it.
// HandlePacket收到segment后通知处理的事件驱动器 // HandlePacket收到segment后通知处理的事件驱动器
newSegmentWaker sleep.Waker newSegmentWaker sleep.Waker
// acceptedChan is used by a listening endpoint protocol goroutine to // acceptedChan is used by a listening endpoint protocol goroutine to
// send newly accepted connections to the endpoint so that they can be // send newly accepted connections to the endpoint so that they can be
// read by Accept() calls. // read by Accept() calls.
acceptedChan chan *endpoint acceptedChan chan *endpoint
// The following are only used from the protocol goroutine, and // The following are only used from the protocol goroutine, and
// therefore don't need locks to protect them. // therefore don't need locks to protect them.
rcv *receiver rcv *receiver
snd *sender snd *sender
// The following are only used to assist the restore run to re-connect. // The following are only used to assist the restore run to re-connect.
bindAddress tcpip.Address bindAddress tcpip.Address
connectingAddress tcpip.Address connectingAddress tcpip.Address
} }
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
e := &endpoint{ e := &endpoint{
stack: stack, stack: stack,
netProto: netProto, netProto: netProto,
waiterQueue: waiterQueue, waiterQueue: waiterQueue,
rcvBufSize: DefaultBufferSize, rcvBufSize: DefaultBufferSize,
sndBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize,
} }
// TODO 需要添加 // TODO 需要添加
e.segmentQueue.setLimit(2 * e.rcvBufSize) e.segmentQueue.setLimit(2 * e.rcvBufSize)
return e return e
} }
func (e *endpoint) Close() { func (e *endpoint) Close() {
log.Println("TODO 在写了 在写了") log.Println("TODO 在写了 在写了")
} }
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return nil, tcpip.ControlMessages{}, nil return nil, tcpip.ControlMessages{}, nil
} }
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
return 0, nil, nil return 0, nil, nil
} }
func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil return 0, tcpip.ControlMessages{}, nil
} }
func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
netProto := e.netProto netProto := e.netProto
if header.IsV4MappedAddress(addr.Addr) { if header.IsV4MappedAddress(addr.Addr) {
// Fail if using a v4 mapped address on a v6only endpoint. // Fail if using a v4 mapped address on a v6only endpoint.
if e.v6only { if e.v6only {
return 0, tcpip.ErrNoRoute return 0, tcpip.ErrNoRoute
} }
netProto = header.IPv4ProtocolNumber netProto = header.IPv4ProtocolNumber
addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
if addr.Addr == "\x00\x00\x00\x00" { if addr.Addr == "\x00\x00\x00\x00" {
addr.Addr = "" addr.Addr = ""
} }
} }
// Fail if we're bound to an address length different from the one we're // Fail if we're bound to an address length different from the one we're
// checking. // checking.
if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
return 0, tcpip.ErrInvalidEndpointState return 0, tcpip.ErrInvalidEndpointState
} }
return netProto, nil return netProto, nil
} }
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
return nil return nil
} }
func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
return nil return nil
} }
func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
log.Println("监听一个tcp端口") log.Println("监听一个tcp端口")
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
defer func() { defer func() {
if err != nil && err.IgnoreStats() { if err != nil && err.IgnoreStats() {
e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
} }
}() }()
// TODO 需要添加 // TODO 需要添加
// 在调用 Listen 之前,必须先 Bind // 在调用 Listen 之前,必须先 Bind
if e.state != stateBound { if e.state != stateBound {
return tcpip.ErrInvalidEndpointState return tcpip.ErrInvalidEndpointState
} }
// 注册该端点,这样网络层在分发数据包的时候就可以根据 id 来找到这个端点,接着把报文发送给这个端点。 // 注册该端点,这样网络层在分发数据包的时候就可以根据 id 来找到这个端点,接着把报文发送给这个端点。
if err := e.stack.RegisterTransportEndpoint(e.boundNICID, if err := e.stack.RegisterTransportEndpoint(e.boundNICID,
e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
return err return err
} }
e.isRegistered = true e.isRegistered = true
e.state = stateListen e.state = stateListen
if e.acceptedChan == nil { if e.acceptedChan == nil {
e.acceptedChan = make(chan *endpoint, backlog) e.acceptedChan = make(chan *endpoint, backlog)
} }
e.workerRunning = true e.workerRunning = true
e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
// TODO tcp服务端实现的主循环这个函数很重要用一个goroutine来服务 // TODO tcp服务端实现的主循环这个函数很重要用一个goroutine来服务
go e.protocolListenLoop(seqnum.Size(0)) go e.protocolListenLoop(seqnum.Size(0))
return nil return nil
} }
// startAcceptedLoop sets up required state and starts a goroutine with the // startAcceptedLoop sets up required state and starts a goroutine with the
// main loop for accepted connections. // main loop for accepted connections.
func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
e.waiterQueue = waiterQueue e.waiterQueue = waiterQueue
e.workerRunning = true e.workerRunning = true
go e.protocolMainLoop(false) go e.protocolMainLoop(false)
} }
func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
e.mu.RLock() e.mu.RLock()
defer e.mu.RUnlock() defer e.mu.RUnlock()
// Endpoint must be in listen state before it can accept connections. // Endpoint must be in listen state before it can accept connections.
if e.state != stateListen { if e.state != stateListen {
return nil, nil, tcpip.ErrInvalidEndpointState return nil, nil, tcpip.ErrInvalidEndpointState
} }
var n *endpoint var n *endpoint
select { select {
case n = <-e.acceptedChan: case n = <-e.acceptedChan:
default: default:
return nil, nil, tcpip.ErrWouldBlock return nil, nil, tcpip.ErrWouldBlock
} }
wq := &waiter.Queue{} wq := &waiter.Queue{}
n.startAcceptedLoop(wq) n.startAcceptedLoop(wq)
return n, wq, nil return n, wq, nil
} }
// Bind binds the endpoint to a specific local port and optionally address. // Bind binds the endpoint to a specific local port and optionally address.
// 将端点绑定到特定的本地端口和可选的地址。 // 将端点绑定到特定的本地端口和可选的地址。
func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
// 如果端点不是处于初始状态,则不允许绑定。这是因为一旦端点进入连接或监听状态,它就已经绑定了。 // 如果端点不是处于初始状态,则不允许绑定。这是因为一旦端点进入连接或监听状态,它就已经绑定了。
if e.state != stateInitial { if e.state != stateInitial {
return tcpip.ErrAlreadyBound return tcpip.ErrAlreadyBound
} }
// 确定tcp端的绑定ip // 确定tcp端的绑定ip
e.bindAddress = addr.Addr e.bindAddress = addr.Addr
netProto, err := e.checkV4Mapped(&addr) netProto, err := e.checkV4Mapped(&addr)
if err != nil { if err != nil {
return err return err
} }
// 确定tcp支持的网络层协议 // 确定tcp支持的网络层协议
netProtos := []tcpip.NetworkProtocolNumber{netProto} netProtos := []tcpip.NetworkProtocolNumber{netProto}
if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
netProtos = []tcpip.NetworkProtocolNumber{ netProtos = []tcpip.NetworkProtocolNumber{
header.IPv6ProtocolNumber, header.IPv6ProtocolNumber,
header.IPv4ProtocolNumber, header.IPv4ProtocolNumber,
} }
} }
// 绑定端口 // 绑定端口
port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
if err != nil { if err != nil {
return err return err
} }
e.isPortReserved = true e.isPortReserved = true
e.effectiveNetProtos = netProtos e.effectiveNetProtos = netProtos
e.id.LocalPort = port e.id.LocalPort = port
defer func() { defer func() {
// 如果有错,在退出的时候应该解除端口绑定 // 如果有错,在退出的时候应该解除端口绑定
if err != nil { if err != nil {
e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
e.isPortReserved = false e.isPortReserved = false
e.effectiveNetProtos = nil e.effectiveNetProtos = nil
e.id.LocalPort = 0 e.id.LocalPort = 0
e.id.LocalAddress = "" e.id.LocalAddress = ""
e.boundNICID = 0 e.boundNICID = 0
} }
}() }()
// 如果指定了ip地址 需要检查一下这个ip地址本地是否绑定过 // 如果指定了ip地址 需要检查一下这个ip地址本地是否绑定过
if len(addr.Addr) != 0 { if len(addr.Addr) != 0 {
nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
if nic == 0 { if nic == 0 {
return tcpip.ErrBadLocalAddress return tcpip.ErrBadLocalAddress
} }
e.boundNICID = nic e.boundNICID = nic
e.id.LocalAddress = addr.Addr e.id.LocalAddress = addr.Addr
} }
// Check the commit function. // Check the commit function.
if commit != nil { if commit != nil {
if err := commit(); err != nil { if err := commit(); err != nil {
// The defer takes care of unwind. // The defer takes care of unwind.
return err return err
} }
} }
// 标记状态为 stateBound // 标记状态为 stateBound
e.state = stateBound e.state = stateBound
return nil return nil
} }
func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock() e.mu.RLock()
defer e.mu.RUnlock() defer e.mu.RUnlock()
return tcpip.FullAddress{ return tcpip.FullAddress{
Addr: e.id.LocalAddress, Addr: e.id.LocalAddress,
Port: e.id.LocalPort, Port: e.id.LocalPort,
NIC: e.boundNICID, NIC: e.boundNICID,
}, nil }, nil
} }
func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
e.mu.RLock() e.mu.RLock()
defer e.mu.RUnlock() defer e.mu.RUnlock()
if e.state != stateConnected { if e.state != stateConnected {
return tcpip.FullAddress{}, tcpip.ErrNotConnected return tcpip.FullAddress{}, tcpip.ErrNotConnected
} }
return tcpip.FullAddress{ return tcpip.FullAddress{
Addr: e.id.RemoteAddress, Addr: e.id.RemoteAddress,
Port: e.id.RemotePort, Port: e.id.RemotePort,
NIC: e.boundNICID, NIC: e.boundNICID,
}, nil }, nil
} }
func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
return waiter.EventErr return waiter.EventErr
} }
func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil return nil
} }
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil return nil
} }
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
s := newSegment(r, id, vv) s := newSegment(r, id, vv)
// 解析tcp段如果解析失败丢弃该报文 // 解析tcp段如果解析失败丢弃该报文
if !s.parse() { if !s.parse() {
e.stack.Stats().MalformedRcvdPackets.Increment() e.stack.Stats().MalformedRcvdPackets.Increment()
e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
s.decRef() s.decRef()
return return
} }
e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一 e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一
if (s.flags & flagRst) != 0 { // RST报文需要拒绝 if (s.flags & flagRst) != 0 { // RST报文需要拒绝
e.stack.Stats().TCP.ResetsReceived.Increment() e.stack.Stats().TCP.ResetsReceived.Increment()
} }
// Send packet to worker goroutine. // Send packet to worker goroutine.
if e.segmentQueue.enqueue(s) { if e.segmentQueue.enqueue(s) {
log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d", log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d",
flagString(s.flags), fmt.Sprintf("%s:%d", s.id.RemoteAddress, s.id.RemotePort), flagString(s.flags), fmt.Sprintf("%s:%d", s.id.RemoteAddress, s.id.RemotePort),
s.sequenceNumber, s.ackNumber) s.sequenceNumber, s.ackNumber)
e.newSegmentWaker.Assert() e.newSegmentWaker.Assert()
} else { } else {
// The queue is full, so we drop the segment. // The queue is full, so we drop the segment.
e.stack.Stats().DroppedPackets.Increment() e.stack.Stats().DroppedPackets.Increment()
s.decRef() s.decRef()
} }
} }
func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
} }
// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if // maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
// the SYN options indicate that timestamp option was negotiated. It also // the SYN options indicate that timestamp option was negotiated. It also
// initializes the recentTS with the value provided in synOpts.TSval. // initializes the recentTS with the value provided in synOpts.TSval.
func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
if synOpts.TS { if synOpts.TS {
e.sendTSOk = true e.sendTSOk = true
e.recentTS = synOpts.TSVal e.recentTS = synOpts.TSVal
} }
} }
// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint // maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
// if the SYN options indicate that the SACK option was negotiated and the TCP // if the SYN options indicate that the SACK option was negotiated and the TCP
// stack is configured to enable TCP SACK option. // stack is configured to enable TCP SACK option.
func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
var v SACKEnabled var v SACKEnabled
if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
// Stack doesn't support SACK. So just return. // Stack doesn't support SACK. So just return.
return return
} }
if bool(v) && synOpts.SACKPermitted { if bool(v) && synOpts.SACKPermitted {
e.sackPermitted = true e.sackPermitted = true
} }
} }

View File

@@ -1,73 +1,73 @@
package tcp package tcp
import ( import (
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/buffer" "netstack/tcpip/buffer"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"netstack/waiter" "netstack/waiter"
) )
const ( const (
// ProtocolName is the string representation of the tcp protocol name. // ProtocolName is the string representation of the tcp protocol name.
ProtocolName = "tcp" ProtocolName = "tcp"
// ProtocolNumber is the tcp protocol number. // ProtocolNumber is the tcp protocol number.
ProtocolNumber = header.TCPProtocolNumber ProtocolNumber = header.TCPProtocolNumber
// MinBufferSize is the smallest size of a receive or send buffer. // MinBufferSize is the smallest size of a receive or send buffer.
minBufferSize = 4 << 10 // 4096 bytes. minBufferSize = 4 << 10 // 4096 bytes.
// DefaultBufferSize is the default size of the receive and send buffers. // DefaultBufferSize is the default size of the receive and send buffers.
DefaultBufferSize = 1 << 20 // 1MB DefaultBufferSize = 1 << 20 // 1MB
// MaxBufferSize is the largest size a receive and send buffer can grow to. // MaxBufferSize is the largest size a receive and send buffer can grow to.
maxBufferSize = 4 << 20 // 4MB maxBufferSize = 4 << 20 // 4MB
) )
// SACKEnabled is a protocol option that can be used to enable SACK
// (selective acknowledgement) support in the TCP protocol.
// See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool
// protocol is the TCP transport protocol registered with the stack. It
// carries no state.
type protocol struct{}
// Number returns the tcp protocol number. // Number returns the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber { func (*protocol) Number() tcpip.TransportProtocolNumber {
return ProtocolNumber return ProtocolNumber
} }
// NewEndpoint creates a new tcp endpoint. // NewEndpoint creates a new tcp endpoint.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
return newEndpoint(stack, netProto, waiterQueue), nil return newEndpoint(stack, netProto, waiterQueue), nil
} }
// ParsePorts returns the source and destination ports stored in the given tcp // ParsePorts returns the source and destination ports stored in the given tcp
// packet. // packet.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
h := header.TCP(v) h := header.TCP(v)
return h.SourcePort(), h.DestinationPort(), nil return h.SourcePort(), h.DestinationPort(), nil
} }
// MinimumPacketSize returns the minimum valid tcp packet size. // MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int { func (*protocol) MinimumPacketSize() int {
return header.TCPMinimumSize return header.TCPMinimumSize
} }
func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
return false return false
} }
// SetOption implements TransportProtocol.SetOption. // SetOption implements TransportProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error { func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return nil return nil
} }
// Option implements TransportProtocol.Option. // Option implements TransportProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error { func (p *protocol) Option(option interface{}) *tcpip.Error {
return nil return nil
} }
func init() { func init() {
stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
return &protocol{} return &protocol{}
}) })
} }

View File

@@ -1,11 +1,11 @@
package tcp package tcp
import "netstack/tcpip/seqnum" import "netstack/tcpip/seqnum"
// receiver will hold the receive-side state of a TCP connection. It has no
// fields yet.
type receiver struct{}
// 新建并初始化接收器 // 新建并初始化接收器
func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
r := &receiver{} r := &receiver{}
return r return r
} }

View File

@@ -1,135 +1,135 @@
package tcp package tcp
import ( import (
"fmt" "fmt"
"log" "log"
"netstack/tcpip/buffer" "netstack/tcpip/buffer"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/seqnum" "netstack/tcpip/seqnum"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"strings" "strings"
"sync/atomic" "sync/atomic"
) )
// TCP is complex enough to warrant a dedicated segment parser; segments come
// in different kinds, distinguished by the flags below.

// Flags that may be set in a TCP segment. Each flag occupies one bit,
// starting with flagFin at bit 0 and ending with flagUrg at bit 5.
const (
	flagFin = 1 << iota
	flagSyn
	flagRst
	flagPsh
	flagAck
	flagUrg
)
func flagString(flags uint8) string { func flagString(flags uint8) string {
var s []string var s []string
if (flags & flagAck) != 0 { if (flags & flagAck) != 0 {
s = append(s, "ack") s = append(s, "ack")
} }
if (flags & flagFin) != 0 { if (flags & flagFin) != 0 {
s = append(s, "fin") s = append(s, "fin")
} }
if (flags & flagPsh) != 0 { if (flags & flagPsh) != 0 {
s = append(s, "psh") s = append(s, "psh")
} }
if (flags & flagRst) != 0 { if (flags & flagRst) != 0 {
s = append(s, "rst") s = append(s, "rst")
} }
if (flags & flagSyn) != 0 { if (flags & flagSyn) != 0 {
s = append(s, "syn") s = append(s, "syn")
} }
if (flags & flagUrg) != 0 { if (flags & flagUrg) != 0 {
s = append(s, "urg") s = append(s, "urg")
} }
return strings.Join(s, "|") return strings.Join(s, "|")
} }
// segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中 // segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中
type segment struct { type segment struct {
segmentEntry segmentEntry
refCnt int32 // 引用计数 refCnt int32 // 引用计数
id stack.TransportEndpointID id stack.TransportEndpointID
route stack.Route route stack.Route
data buffer.VectorisedView data buffer.VectorisedView
// views is used as buffer for data when its length is large // views is used as buffer for data when its length is large
// enough to store a VectorisedView. // enough to store a VectorisedView.
views [8]buffer.View views [8]buffer.View
// TODO 需要解析 // TODO 需要解析
viewToDeliver int viewToDeliver int
sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置 sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置
ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号 ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号
flags uint8 flags uint8
window seqnum.Size window seqnum.Size
// parsedOptions stores the parsed values from the options in the segment. // parsedOptions stores the parsed values from the options in the segment.
parsedOptions header.TCPOptions parsedOptions header.TCPOptions
options []byte options []byte
} }
func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
s := &segment{refCnt: 1, id: id, route: r.Clone()} s := &segment{refCnt: 1, id: id, route: r.Clone()}
s.data = vv.Clone(s.views[:]) s.data = vv.Clone(s.views[:])
return s return s
} }
func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
s := &segment{ s := &segment{
refCnt: 1, refCnt: 1,
id: id, id: id,
route: r.Clone(), route: r.Clone(),
} }
s.views[0] = v s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1? s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1?
return s return s
} }
func (s *segment) clone() *segment { func (s *segment) clone() *segment {
t := &segment{ t := &segment{
refCnt: 1, refCnt: 1,
id: s.id, id: s.id,
sequenceNumber: s.sequenceNumber, sequenceNumber: s.sequenceNumber,
ackNumber: s.ackNumber, ackNumber: s.ackNumber,
flags: s.flags, flags: s.flags,
window: s.window, window: s.window,
route: s.route.Clone(), route: s.route.Clone(),
viewToDeliver: s.viewToDeliver, viewToDeliver: s.viewToDeliver,
} }
t.data = s.data.Clone(t.views[:]) t.data = s.data.Clone(t.views[:])
return t return t
} }
func (s *segment) flagIsSet(flag uint8) bool { func (s *segment) flagIsSet(flag uint8) bool {
return (s.flags & flag) != 0 return (s.flags & flag) != 0
} }
func (s *segment) decRef() { func (s *segment) decRef() {
if atomic.AddInt32(&s.refCnt, -1) == 0 { if atomic.AddInt32(&s.refCnt, -1) == 0 {
s.route.Release() s.route.Release()
} }
} }
func (s *segment) incRef() { func (s *segment) incRef() {
atomic.AddInt32(&s.refCnt, 1) atomic.AddInt32(&s.refCnt, 1)
} }
func (s *segment) parse() bool { func (s *segment) parse() bool {
h := header.TCP(s.data.First()) h := header.TCP(s.data.First())
offset := int(h.DataOffset()) offset := int(h.DataOffset())
if offset < header.TCPMinimumSize || offset > len(h) { if offset < header.TCPMinimumSize || offset > len(h) {
return false return false
} }
s.options = h.Options() s.options = h.Options()
s.parsedOptions = header.ParseTCPOptions(s.options) s.parsedOptions = header.ParseTCPOptions(s.options)
log.Println(h) log.Println(h)
fmt.Println(s.parsedOptions) fmt.Println(s.parsedOptions)
s.data.TrimFront(offset) s.data.TrimFront(offset)
s.sequenceNumber = seqnum.Value(h.SequenceNumber()) s.sequenceNumber = seqnum.Value(h.SequenceNumber())
s.ackNumber = seqnum.Value(h.AckNumber()) s.ackNumber = seqnum.Value(h.AckNumber())
s.flags = h.Flags() // U|A|P|R|S|F s.flags = h.Flags() // U|A|P|R|S|F
s.window = seqnum.Size(h.WindowSize()) s.window = seqnum.Size(h.WindowSize())
return true return true
} }

View File

@@ -1,50 +1,50 @@
package tcp package tcp
import ( import (
"netstack/tcpip/header" "netstack/tcpip/header"
"sync" "sync"
) )
type segmentQueue struct { type segmentQueue struct {
mu sync.Mutex mu sync.Mutex
list segmentList // 队列实现 list segmentList // 队列实现
limit int // 队列容量 limit int // 队列容量
used int // 队列长度 used int // 队列长度
} }
func (q *segmentQueue) empty() bool { func (q *segmentQueue) empty() bool {
q.mu.Lock() q.mu.Lock()
r := q.used == 0 r := q.used == 0
q.mu.Unlock() q.mu.Unlock()
return r return r
} }
func (q *segmentQueue) enqueue(s *segment) bool { func (q *segmentQueue) enqueue(s *segment) bool {
q.mu.Lock() q.mu.Lock()
r := q.used < q.limit r := q.used < q.limit
if r { if r {
q.list.PushBack(s) q.list.PushBack(s)
q.used += s.data.Size() + header.TCPMinimumSize q.used += s.data.Size() + header.TCPMinimumSize
} }
q.mu.Unlock() q.mu.Unlock()
return r return r
} }
func (q *segmentQueue) dequeue() *segment { func (q *segmentQueue) dequeue() *segment {
q.mu.Lock() q.mu.Lock()
s := q.list.Front() s := q.list.Front()
if s != nil { if s != nil {
q.list.Remove(s) q.list.Remove(s)
q.used -= s.data.Size() + header.TCPMinimumSize q.used -= s.data.Size() + header.TCPMinimumSize
} }
q.mu.Unlock() q.mu.Unlock()
return s return s
} }
func (q *segmentQueue) setLimit(limit int) { func (q *segmentQueue) setLimit(limit int) {
q.mu.Lock() q.mu.Lock()
q.limit = limit q.limit = limit
q.mu.Unlock() q.mu.Unlock()
} }

View File

@@ -1,12 +1,12 @@
package tcp package tcp
import "netstack/tcpip/seqnum" import "netstack/tcpip/seqnum"
// sender is the TCP sender. It currently has no fields.
type sender struct{}
// 新建并初始化发送器 // 新建并初始化发送器
func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
s := &sender{} s := &sender{}
return s return s
} }

View File

@@ -1,173 +1,173 @@
package tcp package tcp
// segmentElementMapper maps list elements to their linker objects. This
// default implementation is the identity mapping.
//
// It can be replaced by a struct that maps elements to linker objects when
// the two are not the same type. A custom mapper is not typically required
// if the linker is left as is, the element is left as is, or linker and
// element are the same type.
type segmentElementMapper struct{}
// linkerFor maps an Element to a Linker. // linkerFor maps an Element to a Linker.
// //
// This default implementation should be inlined. // This default implementation should be inlined.
// //
//go:nosplit //go:nosplit
func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
// List is an intrusive list. Entries can be added to or removed from the list // List is an intrusive list. Entries can be added to or removed from the list
// in O(1) time and with no additional memory allocations. // in O(1) time and with no additional memory allocations.
// //
// The zero value for List is an empty list ready to use. // The zero value for List is an empty list ready to use.
// //
// To iterate over a list (where l is a List): // To iterate over a list (where l is a List):
// for e := l.Front(); e != nil; e = e.Next() { // for e := l.Front(); e != nil; e = e.Next() {
// // do something with e. // // do something with e.
// } // }
// //
// +stateify savable // +stateify savable
type segmentList struct { type segmentList struct {
head *segment head *segment
tail *segment tail *segment
} }
// Reset resets list l to the empty state. // Reset resets list l to the empty state.
func (l *segmentList) Reset() { func (l *segmentList) Reset() {
l.head = nil l.head = nil
l.tail = nil l.tail = nil
} }
// Empty returns true iff the list is empty. // Empty returns true iff the list is empty.
func (l *segmentList) Empty() bool { func (l *segmentList) Empty() bool {
return l.head == nil return l.head == nil
} }
// Front returns the first element of list l or nil. // Front returns the first element of list l or nil.
func (l *segmentList) Front() *segment { func (l *segmentList) Front() *segment {
return l.head return l.head
} }
// Back returns the last element of list l or nil. // Back returns the last element of list l or nil.
func (l *segmentList) Back() *segment { func (l *segmentList) Back() *segment {
return l.tail return l.tail
} }
// PushFront inserts the element e at the front of list l. // PushFront inserts the element e at the front of list l.
func (l *segmentList) PushFront(e *segment) { func (l *segmentList) PushFront(e *segment) {
segmentElementMapper{}.linkerFor(e).SetNext(l.head) segmentElementMapper{}.linkerFor(e).SetNext(l.head)
segmentElementMapper{}.linkerFor(e).SetPrev(nil) segmentElementMapper{}.linkerFor(e).SetPrev(nil)
if l.head != nil { if l.head != nil {
segmentElementMapper{}.linkerFor(l.head).SetPrev(e) segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
} else { } else {
l.tail = e l.tail = e
} }
l.head = e l.head = e
} }
// PushBack inserts the element e at the back of list l. // PushBack inserts the element e at the back of list l.
func (l *segmentList) PushBack(e *segment) { func (l *segmentList) PushBack(e *segment) {
segmentElementMapper{}.linkerFor(e).SetNext(nil) segmentElementMapper{}.linkerFor(e).SetNext(nil)
segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
if l.tail != nil { if l.tail != nil {
segmentElementMapper{}.linkerFor(l.tail).SetNext(e) segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
} else { } else {
l.head = e l.head = e
} }
l.tail = e l.tail = e
} }
// PushBackList inserts list m at the end of list l, emptying m. // PushBackList inserts list m at the end of list l, emptying m.
func (l *segmentList) PushBackList(m *segmentList) { func (l *segmentList) PushBackList(m *segmentList) {
if l.head == nil { if l.head == nil {
l.head = m.head l.head = m.head
l.tail = m.tail l.tail = m.tail
} else if m.head != nil { } else if m.head != nil {
segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
l.tail = m.tail l.tail = m.tail
} }
m.head = nil m.head = nil
m.tail = nil m.tail = nil
} }
// InsertAfter inserts e after b. // InsertAfter inserts e after b.
func (l *segmentList) InsertAfter(b, e *segment) { func (l *segmentList) InsertAfter(b, e *segment) {
a := segmentElementMapper{}.linkerFor(b).Next() a := segmentElementMapper{}.linkerFor(b).Next()
segmentElementMapper{}.linkerFor(e).SetNext(a) segmentElementMapper{}.linkerFor(e).SetNext(a)
segmentElementMapper{}.linkerFor(e).SetPrev(b) segmentElementMapper{}.linkerFor(e).SetPrev(b)
segmentElementMapper{}.linkerFor(b).SetNext(e) segmentElementMapper{}.linkerFor(b).SetNext(e)
if a != nil { if a != nil {
segmentElementMapper{}.linkerFor(a).SetPrev(e) segmentElementMapper{}.linkerFor(a).SetPrev(e)
} else { } else {
l.tail = e l.tail = e
} }
} }
// InsertBefore inserts e before a. // InsertBefore inserts e before a.
func (l *segmentList) InsertBefore(a, e *segment) { func (l *segmentList) InsertBefore(a, e *segment) {
b := segmentElementMapper{}.linkerFor(a).Prev() b := segmentElementMapper{}.linkerFor(a).Prev()
segmentElementMapper{}.linkerFor(e).SetNext(a) segmentElementMapper{}.linkerFor(e).SetNext(a)
segmentElementMapper{}.linkerFor(e).SetPrev(b) segmentElementMapper{}.linkerFor(e).SetPrev(b)
segmentElementMapper{}.linkerFor(a).SetPrev(e) segmentElementMapper{}.linkerFor(a).SetPrev(e)
if b != nil { if b != nil {
segmentElementMapper{}.linkerFor(b).SetNext(e) segmentElementMapper{}.linkerFor(b).SetNext(e)
} else { } else {
l.head = e l.head = e
} }
} }
// Remove removes e from l. // Remove removes e from l.
func (l *segmentList) Remove(e *segment) { func (l *segmentList) Remove(e *segment) {
prev := segmentElementMapper{}.linkerFor(e).Prev() prev := segmentElementMapper{}.linkerFor(e).Prev()
next := segmentElementMapper{}.linkerFor(e).Next() next := segmentElementMapper{}.linkerFor(e).Next()
if prev != nil { if prev != nil {
segmentElementMapper{}.linkerFor(prev).SetNext(next) segmentElementMapper{}.linkerFor(prev).SetNext(next)
} else { } else {
l.head = next l.head = next
} }
if next != nil { if next != nil {
segmentElementMapper{}.linkerFor(next).SetPrev(prev) segmentElementMapper{}.linkerFor(next).SetPrev(prev)
} else { } else {
l.tail = prev l.tail = prev
} }
} }
// Entry is a default implementation of Linker. Users can add anonymous fields // Entry is a default implementation of Linker. Users can add anonymous fields
// of this type to their structs to make them automatically implement the // of this type to their structs to make them automatically implement the
// methods needed by List. // methods needed by List.
// //
// +stateify savable // +stateify savable
type segmentEntry struct { type segmentEntry struct {
next *segment next *segment
prev *segment prev *segment
} }
// Next returns the entry that follows e in the list. // Next returns the entry that follows e in the list.
func (e *segmentEntry) Next() *segment { func (e *segmentEntry) Next() *segment {
return e.next return e.next
} }
// Prev returns the entry that precedes e in the list. // Prev returns the entry that precedes e in the list.
func (e *segmentEntry) Prev() *segment { func (e *segmentEntry) Prev() *segment {
return e.prev return e.prev
} }
// SetNext assigns 'entry' as the entry that follows e in the list. // SetNext assigns 'entry' as the entry that follows e in the list.
func (e *segmentEntry) SetNext(elem *segment) { func (e *segmentEntry) SetNext(elem *segment) {
e.next = elem e.next = elem
} }
// SetPrev assigns 'entry' as the entry that precedes e in the list. // SetPrev assigns 'entry' as the entry that precedes e in the list.
func (e *segmentEntry) SetPrev(elem *segment) { func (e *segmentEntry) SetPrev(elem *segment) {
e.prev = elem e.prev = elem
} }