diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 1cfbfd2..6ad9e4e 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -1,280 +1,280 @@ -package main - -import ( - "flag" - "fmt" - "log" - "net" - "netstack/tcpip" - "netstack/tcpip/link/fdbased" - "netstack/tcpip/link/tuntap" - "netstack/tcpip/network/arp" - "netstack/tcpip/network/ipv4" - "netstack/tcpip/network/ipv6" - "netstack/tcpip/stack" - "netstack/tcpip/transport/tcp" - "netstack/tcpip/transport/udp" - "netstack/waiter" - "os" - "os/signal" - "strconv" - "strings" - "syscall" -) - -var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") - -func main() { - flag.Parse() - if len(flag.Args()) != 4 { - log.Fatal("Usage: ", os.Args[0], " ") - } - - log.SetFlags(log.Lshortfile | log.LstdFlags) - - tapName := flag.Arg(0) - cidrName := flag.Arg(1) - addrName := flag.Arg(2) - portName := flag.Arg(3) - - log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName) - - maddr, err := net.ParseMAC(*mac) - if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) - } - - parsedAddr := net.ParseIP(addrName) - if err != nil { - log.Fatalf("Bad addrress: %v", addrName) - } - - // 解析地址ip地址,ipv4或者ipv6地址都支持 - var addr tcpip.Address - var proto tcpip.NetworkProtocolNumber - if parsedAddr.To4() != nil { - addr = tcpip.Address(parsedAddr.To4()) - proto = ipv4.ProtocolNumber - } else if parsedAddr.To16() != nil { - addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber - } else { - log.Fatalf("Unknown IP type: %v", parsedAddr) - } - - localPort, err := strconv.Atoi(portName) - if err != nil { - log.Fatalf("Unable to convert port %v: %v", portName, err) - } - - // 虚拟网卡配置 - conf := &tuntap.Config{ - Name: tapName, - Mode: tuntap.TAP, - } - - var fd int - // 新建虚拟网卡 - fd, err = tuntap.NewNetDev(conf) - if err != nil { - log.Fatal(err) - } - - // 启动tap网卡 - _ = tuntap.SetLinkUp(tapName) - // 设置路由 - _ = tuntap.SetRoute(tapName, cidrName) - - // 抽象的文件接口 - linkID := fdbased.New(&fdbased.Options{ - FD: fd, - MTU: 1500, - Address: tcpip.LinkAddress(maddr), - ResolutionRequired: true, - }) - - // 新建相关协议的协议栈 - s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, - []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) - - // 新建抽象的网卡 - if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil { - log.Fatal(err) - } - - // 在该协议栈上添加和注册相应的网络层 - if err := s.AddAddress(1, proto, addr); err != nil { - log.Fatal(err) - } - - // 在该协议栈上添加和注册ARP协议 - if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { - log.Fatal(err) - } - - // 添加默认路由 - s.SetRouteTable([]tcpip.Route{ - { - Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), - Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), - Gateway: "", - NIC: 1, - }, - }) - - //go func() { // echo server - // // 监听udp localPort端口 - // conn := udpListen(s, proto, addr, localPort) - - // for { - // buf := make([]byte, 1024) - // n, err := conn.Read(buf) - // if err != nil { - // log.Println(err) - // break - // } - // log.Println("接收到数据", string(buf[:n])) - // conn.Write([]byte("server echo")) - // } - // // 关闭监听服务,此时会释放端口 - // conn.Close() - //}() - - go func() { // echo server - listener := tcpListen(s, proto, addr, localPort) - for { - conn, err := listener.Accept() - if err != nil { - log.Println(err) - continue - } - conn.Read(nil) - } - }() - - c := make(chan os.Signal) - signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) - <-c -} - -type TcpConn struct { - raddr tcpip.FullAddress - ep tcpip.Endpoint - wq *waiter.Queue - we *waiter.Entry - notifyCh chan struct{} -} - -// Accept 封装tcp的accept操作 -func (conn *TcpConn) Accept() (tcpip.Endpoint, error) { - conn.wq.EventRegister(conn.we, waiter.EventIn) - defer conn.wq.EventUnregister(conn.we) - for { - ep, _, err := conn.ep.Accept() - if err != nil { - if err == tcpip.ErrWouldBlock { - <-conn.notifyCh - continue - } - return nil, fmt.Errorf("%s", err.String()) - } - return ep, nil - } -} - -func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn { - var wq waiter.Queue - // 新建一个tcp端 - ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) - if err != nil { - log.Fatal(err) - } - - // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP - // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil { - log.Fatal("Bind failed: ", err) - } - - // 开始监听 - if err := ep.Listen(10); err != nil { - log.Fatal("Listen failed: ", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - return &TcpConn{ - ep: ep, - wq: &wq, - we: &waitEntry, - notifyCh: notifyCh} -} - -type UdpConn struct { - raddr tcpip.FullAddress - ep tcpip.Endpoint - wq *waiter.Queue - we *waiter.Entry - notifyCh chan struct{} -} - -func (conn *UdpConn) Close() { - conn.ep.Close() -} - -func (conn *UdpConn) Read(rcv []byte) (int, error) { - conn.wq.EventRegister(conn.we, waiter.EventIn) - defer conn.wq.EventUnregister(conn.we) - for { - buf, _, err := conn.ep.Read(&conn.raddr) - if err != nil { - if err == tcpip.ErrWouldBlock { - <-conn.notifyCh - continue - } - return 0, fmt.Errorf("%s", err.String()) - } - n := len(buf) - if n > cap(rcv) { - n = cap(rcv) - } - rcv = append(rcv[:0], buf[:n]...) - return n, nil - } -} - -func (conn *UdpConn) Write(snd []byte) error { - for { - _, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr}) - if err != nil { - if err == tcpip.ErrNoLinkAddress { - <-notifyCh - continue - } - return fmt.Errorf("%s", err.String()) - } - return nil - } -} - -func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn { - var wq waiter.Queue - // 新建一个udp端 - ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) - if err != nil { - log.Fatal(err) - } - - // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP - // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 - // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil { - log.Fatal("Bind failed: ", err) - } - - waitEntry, notifyCh := waiter.NewChannelEntry(nil) - return &UdpConn{ - ep: ep, - wq: &wq, - we: &waitEntry, - notifyCh: notifyCh} -} +package main + +import ( + "flag" + "fmt" + "log" + "net" + "netstack/tcpip" + "netstack/tcpip/link/fdbased" + "netstack/tcpip/link/tuntap" + "netstack/tcpip/network/arp" + "netstack/tcpip/network/ipv4" + "netstack/tcpip/network/ipv6" + "netstack/tcpip/stack" + "netstack/tcpip/transport/tcp" + "netstack/tcpip/transport/udp" + "netstack/waiter" + "os" + "os/signal" + "strconv" + "strings" + "syscall" +) + +var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") + +func main() { + flag.Parse() + if len(flag.Args()) != 4 { + log.Fatal("Usage: ", os.Args[0], " ") + } + + log.SetFlags(log.Lshortfile | log.LstdFlags) + + tapName := flag.Arg(0) + cidrName := flag.Arg(1) + addrName := flag.Arg(2) + portName := flag.Arg(3) + + log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName) + + maddr, err := net.ParseMAC(*mac) + if err != nil { + log.Fatalf("Bad MAC address: %v", *mac) + } + + parsedAddr := net.ParseIP(addrName) + if err != nil { + log.Fatalf("Bad addrress: %v", addrName) + } + + // 解析地址ip地址,ipv4或者ipv6地址都支持 + var addr tcpip.Address + var proto tcpip.NetworkProtocolNumber + if parsedAddr.To4() != nil { + addr = tcpip.Address(parsedAddr.To4()) + proto = ipv4.ProtocolNumber + } else if parsedAddr.To16() != nil { + addr = tcpip.Address(parsedAddr.To16()) + proto = ipv6.ProtocolNumber + } else { + log.Fatalf("Unknown IP type: %v", parsedAddr) + } + + localPort, err := strconv.Atoi(portName) + if err != nil { + log.Fatalf("Unable to convert port %v: %v", portName, err) + } + + // 虚拟网卡配置 + conf := &tuntap.Config{ + Name: tapName, + Mode: tuntap.TAP, + } + + var fd int + // 新建虚拟网卡 + fd, err = tuntap.NewNetDev(conf) + if err != nil { + log.Fatal(err) + } + + // 启动tap网卡 + _ = tuntap.SetLinkUp(tapName) + // 设置路由 + _ = tuntap.SetRoute(tapName, cidrName) + + // 抽象的文件接口 + linkID := fdbased.New(&fdbased.Options{ + FD: fd, + MTU: 1500, + Address: tcpip.LinkAddress(maddr), + ResolutionRequired: true, + }) + + // 新建相关协议的协议栈 + s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, + []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) + + // 新建抽象的网卡 + if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil { + log.Fatal(err) + } + + // 在该协议栈上添加和注册相应的网络层 + if err := s.AddAddress(1, proto, addr); err != nil { + log.Fatal(err) + } + + // 在该协议栈上添加和注册ARP协议 + if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + log.Fatal(err) + } + + // 添加默认路由 + s.SetRouteTable([]tcpip.Route{ + { + Destination: tcpip.Address(strings.Repeat("\x00", len(addr))), + Mask: tcpip.AddressMask(strings.Repeat("\x00", len(addr))), + Gateway: "", + NIC: 1, + }, + }) + + //go func() { // echo server + // // 监听udp localPort端口 + // conn := udpListen(s, proto, addr, localPort) + + // for { + // buf := make([]byte, 1024) + // n, err := conn.Read(buf) + // if err != nil { + // log.Println(err) + // break + // } + // log.Println("接收到数据", string(buf[:n])) + // conn.Write([]byte("server echo")) + // } + // // 关闭监听服务,此时会释放端口 + // conn.Close() + //}() + + go func() { // echo server + listener := tcpListen(s, proto, addr, localPort) + for { + conn, err := listener.Accept() + if err != nil { + log.Println(err) + continue + } + conn.Read(nil) + } + }() + + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) + <-c +} + +type TcpConn struct { + raddr tcpip.FullAddress + ep tcpip.Endpoint + wq *waiter.Queue + we *waiter.Entry + notifyCh chan struct{} +} + +// Accept 封装tcp的accept操作 +func (conn *TcpConn) Accept() (tcpip.Endpoint, error) { + conn.wq.EventRegister(conn.we, waiter.EventIn) + defer conn.wq.EventUnregister(conn.we) + for { + ep, _, err := conn.ep.Accept() + if err != nil { + if err == tcpip.ErrWouldBlock { + <-conn.notifyCh + continue + } + return nil, fmt.Errorf("%s", err.String()) + } + return ep, nil + } +} + +func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn { + var wq waiter.Queue + // 新建一个tcp端 + ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) + if err != nil { + log.Fatal(err) + } + + // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP + // 此时就会调用端口管理器 + if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil { + log.Fatal("Bind failed: ", err) + } + + // 开始监听 + if err := ep.Listen(10); err != nil { + log.Fatal("Listen failed: ", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + return &TcpConn{ + ep: ep, + wq: &wq, + we: &waitEntry, + notifyCh: notifyCh} +} + +type UdpConn struct { + raddr tcpip.FullAddress + ep tcpip.Endpoint + wq *waiter.Queue + we *waiter.Entry + notifyCh chan struct{} +} + +func (conn *UdpConn) Close() { + conn.ep.Close() +} + +func (conn *UdpConn) Read(rcv []byte) (int, error) { + conn.wq.EventRegister(conn.we, waiter.EventIn) + defer conn.wq.EventUnregister(conn.we) + for { + buf, _, err := conn.ep.Read(&conn.raddr) + if err != nil { + if err == tcpip.ErrWouldBlock { + <-conn.notifyCh + continue + } + return 0, fmt.Errorf("%s", err.String()) + } + n := len(buf) + if n > cap(rcv) { + n = cap(rcv) + } + rcv = append(rcv[:0], buf[:n]...) + return n, nil + } +} + +func (conn *UdpConn) Write(snd []byte) error { + for { + _, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr}) + if err != nil { + if err == tcpip.ErrNoLinkAddress { + <-notifyCh + continue + } + return fmt.Errorf("%s", err.String()) + } + return nil + } +} + +func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn { + var wq waiter.Queue + // 新建一个udp端 + ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) + if err != nil { + log.Fatal(err) + } + + // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP + // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 + // 此时就会调用端口管理器 + if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil { + log.Fatal("Bind failed: ", err) + } + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + return &UdpConn{ + ep: ep, + wq: &wq, + we: &waitEntry, + notifyCh: notifyCh} +} diff --git a/cmd/netstack/tcp_server.go b/cmd/netstack/tcp_server.go index a8f888d..24f6148 100644 --- a/cmd/netstack/tcp_server.go +++ b/cmd/netstack/tcp_server.go @@ -1,101 +1,101 @@ -package main - -import ( - "encoding/binary" - "fmt" - "io" - "log" - "net" - "netstack/tcpip" - "netstack/tcpip/header" - "netstack/tcpip/stack" - "netstack/tcpip/transport/udp" - "netstack/waiter" - "runtime" - "strings" -) - -type TCPHandler interface { - Handle(net.Conn) -} - -func TCPServer(listener net.Listener, handler TCPHandler) error { - log.Printf("netstack 网络解析地址: %s", listener.Addr()) - - for { - clientConn, err := listener.Accept() - if err != nil { - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { - log.Printf("temporary Accept() failure - %s", err) - runtime.Gosched() - continue - } - // theres no direct way to detect this error because it is not exposed - if !strings.Contains(err.Error(), "use of closed network connection") { - return fmt.Errorf("listener.Accept() error - %s", err) - } - break - } - go handler.Handle(clientConn) - } - - log.Printf("TCP: closing %s", listener.Addr()) - - return nil -} - -var transportPool = make(map[uint64]tcpip.Endpoint) - -type RCV struct { - *stack.Stack - ep tcpip.Endpoint - addr tcpip.FullAddress - rcvBuf []byte -} - -func (r *RCV) Handle(conn net.Conn) { - var err error - r.rcvBuf, err = io.ReadAll(conn) - if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port - panic(err) - } - - switch string(r.rcvBuf[:3]) { - case "udp": - var wq waiter.Queue - // 新建一个udp端 - ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) - if err != nil { - log.Fatal(err) - } - r.ep = ep - r.Bind() - r.Connect() - r.Close() - case "tcp": - default: - return - } -} - -func (r *RCV) Bind() { - if len(r.rcvBuf) < 9 { // udp ip port - log.Println("Error: too few arg") - return - } - port := binary.BigEndian.Uint16(r.rcvBuf[7:9]) - r.addr = tcpip.FullAddress{ - NIC: 1, - Addr: tcpip.Address(r.rcvBuf[3:7]), - Port: port, - } - r.ep.Bind(r.addr, nil) -} - -func (r *RCV) Connect() { - r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888}) -} - -func (r *RCV) Close() { - r.ep.Close() -} +package main + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "net" + "netstack/tcpip" + "netstack/tcpip/header" + "netstack/tcpip/stack" + "netstack/tcpip/transport/udp" + "netstack/waiter" + "runtime" + "strings" +) + +type TCPHandler interface { + Handle(net.Conn) +} + +func TCPServer(listener net.Listener, handler TCPHandler) error { + log.Printf("netstack 网络解析地址: %s", listener.Addr()) + + for { + clientConn, err := listener.Accept() + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + log.Printf("temporary Accept() failure - %s", err) + runtime.Gosched() + continue + } + // theres no direct way to detect this error because it is not exposed + if !strings.Contains(err.Error(), "use of closed network connection") { + return fmt.Errorf("listener.Accept() error - %s", err) + } + break + } + go handler.Handle(clientConn) + } + + log.Printf("TCP: closing %s", listener.Addr()) + + return nil +} + +var transportPool = make(map[uint64]tcpip.Endpoint) + +type RCV struct { + *stack.Stack + ep tcpip.Endpoint + addr tcpip.FullAddress + rcvBuf []byte +} + +func (r *RCV) Handle(conn net.Conn) { + var err error + r.rcvBuf, err = io.ReadAll(conn) + if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port + panic(err) + } + + switch string(r.rcvBuf[:3]) { + case "udp": + var wq waiter.Queue + // 新建一个udp端 + ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq) + if err != nil { + log.Fatal(err) + } + r.ep = ep + r.Bind() + r.Connect() + r.Close() + case "tcp": + default: + return + } +} + +func (r *RCV) Bind() { + if len(r.rcvBuf) < 9 { // udp ip port + log.Println("Error: too few arg") + return + } + port := binary.BigEndian.Uint16(r.rcvBuf[7:9]) + r.addr = tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.Address(r.rcvBuf[3:7]), + Port: port, + } + r.ep.Bind(r.addr, nil) +} + +func (r *RCV) Connect() { + r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888}) +} + +func (r *RCV) Close() { + r.ep.Close() +} diff --git a/cmd/udp_client/main.go b/cmd/udp_client/main.go index 442d05b..02f5a5f 100644 --- a/cmd/udp_client/main.go +++ b/cmd/udp_client/main.go @@ -1,41 +1,41 @@ -package main - -import ( - "flag" - "log" - "net" -) - -func main() { - var ( - addr = flag.String("a", "192.168.1.1:9999", "udp dst address") - ) - - log.SetFlags(log.Lshortfile | log.LstdFlags) - - var err error - udpAddr, err := net.ResolveUDPAddr("udp", *addr) - if err != nil { - panic(err) - } - - // 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接) - conn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - panic(err) - } - - //send := []byte("hello world") - send := make([]byte, 1600) - if _, err := conn.Write(send); err != nil { - panic(err) - } - log.Printf("send: %s", string(send)) - - recv := make([]byte, 32) - rn, _, err := conn.ReadFrom(recv) - if err != nil { - panic(err) - } - log.Printf("recv: %s", string(recv[:rn])) -} +package main + +import ( + "flag" + "log" + "net" +) + +func main() { + var ( + addr = flag.String("a", "192.168.1.1:9999", "udp dst address") + ) + + log.SetFlags(log.Lshortfile | log.LstdFlags) + + var err error + udpAddr, err := net.ResolveUDPAddr("udp", *addr) + if err != nil { + panic(err) + } + + // 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接) + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + panic(err) + } + + //send := []byte("hello world") + send := make([]byte, 1600) + if _, err := conn.Write(send); err != nil { + panic(err) + } + log.Printf("send: %s", string(send)) + + recv := make([]byte, 32) + rn, _, err := conn.ReadFrom(recv) + if err != nil { + panic(err) + } + log.Printf("recv: %s", string(recv[:rn])) +} diff --git a/tcpip/header/tcp.go b/tcpip/header/tcp.go index 73569fa..25c5377 100644 --- a/tcpip/header/tcp.go +++ b/tcpip/header/tcp.go @@ -1,609 +1,609 @@ -package header - -import ( - "encoding/binary" - "fmt" - "netstack/tcpip" - "netstack/tcpip/seqnum" -) - -/* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Source Port | Destination Port | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Sequence Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Acknowledgment Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Data | |U|A|P|R|S|F| | -| Offset| Reserved |R|C|S|S|Y|I| Window | -| | |G|K|H|T|N|N| | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Checksum | Urgent Pointer | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Options | Padding | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| data | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -*/ - -// TCPFields contains the fields of a TCP packet. It is used to describe the -// fields of a packet that needs to be encoded. -// tcp首部字段 -type TCPFields struct { - // SrcPort is the "source port" field of a TCP packet. - SrcPort uint16 - - // DstPort is the "destination port" field of a TCP packet. - DstPort uint16 - - // SeqNum is the "sequence number" field of a TCP packet. - // TCP的初始序列号ISN是随机生成的 - // 如果TCP每次连接都使用固定ISN,黑客可以很方便模拟任何IP与server建立连接 - // 如果ISN是固定的,那很可能在新连接建立后,上次连接通信的报文才到达, - // 这种情况有概率发生老报文的seq号正好是server希望收到的新连接的报文seq。这就全乱了。 - SeqNum uint32 - - // AckNum is the "acknowledgement number" field of a TCP packet. - AckNum uint32 - - // DataOffset is the "data offset" field of a TCP packet. - DataOffset uint8 - - // Flags is the "flags" field of a TCP packet. - Flags uint8 - - // WindowSize is the "window size" field of a TCP packet. - WindowSize uint16 - - // Checksum is the "checksum" field of a TCP packet. - Checksum uint16 - - // UrgentPointer is the "urgent pointer" field of a TCP packet. - UrgentPointer uint16 -} - -// TCPSynOptions is used to return the parsed TCP Options in a syn -// segment. -// syn 报文的选项 -type TCPSynOptions struct { - // MSS is the maximum segment size provided by the peer in the SYN. - MSS uint16 - - // WS is the window scale option provided by the peer in the SYN. - // - // Set to -1 if no window scale option was provided. - WS int - - // TS is true if the timestamp option was provided in the syn/syn-ack. - TS bool - - // TSVal is the value of the TSVal field in the timestamp option. - TSVal uint32 - - // TSEcr is the value of the TSEcr field in the timestamp option. - TSEcr uint32 - - // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK. - SACKPermitted bool -} - -const ( - srcPort = 0 - dstPort = 2 - seqNum = 4 - ackNum = 8 - dataOffset = 12 - tcpFlags = 13 - winSize = 14 - tcpChecksum = 16 - urgentPtr = 18 -) - -// Options that may be present in a TCP segment. -const ( - // 选项表结束选项 - TCPOptionEOL = 0 - // 空操作(nop)选项 - TCPOptionNOP = 1 - // 最大报文段长度选项 - TCPOptionMSS = 2 - // 窗口扩大因子选项 - TCPOptionWS = 3 - // 时间戳选项 - TCPOptionTS = 8 - // 选择性确认(Selective Acknowledgment,SACK)选项 - TCPOptionSACKPermitted = 4 - // SACK 实际工作的选项 - TCPOptionSACK = 5 -) - -const ( - // MaxWndScale is maximum allowed window scaling, as described in - // RFC 1323, section 2.3, page 11. - MaxWndScale = 14 - - // TCPMaxSACKBlocks is the maximum number of SACK blocks that can - // be encoded in a TCP option field. - TCPMaxSACKBlocks = 4 -) - -// SACKBlock 表示 sack 块的结构体 -type SACKBlock struct { - // Start indicates the lowest sequence number in the block. - Start seqnum.Value - - // End indicates the sequence number immediately following the last - // sequence number of this block. - End seqnum.Value -} - -/* - 1byte 1byte nbytes -+--------+--------+------------------+ -| Kind | Length | Info | -+--------+--------+------------------+ -*/ - -// TCPOptions tcp选项结构,这个结构不表示 syn/syn-ack 报文 -type TCPOptions struct { - // TS is true if the TimeStamp option is enabled. - TS bool - - // TSVal is the value in the TSVal field of the segment. - TSVal uint32 - - // TSEcr is the value in the TSEcr field of the segment. - TSEcr uint32 - - // SACKBlocks are the SACK blocks specified in the segment. - SACKBlocks []SACKBlock - - // 以下仅供测试之用 不对外暴露 - - // MSS is the maximum segment size provided by the peer in the SYN. - mss uint16 - - // WS is the window scale option provided by the peer in the SYN. - // - // Set to -1 if no window scale option was provided. - ws int - - // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK. - sackPermitted bool -} - -// TCP represents a TCP header stored in a byte array. -type TCP []byte - -const ( - // TCPMinimumSize is the minimum size of a valid TCP packet. - TCPMinimumSize = 20 - - // TCPProtocolNumber is TCP's transport protocol number. - TCPProtocolNumber tcpip.TransportProtocolNumber = 6 -) - -// SourcePort returns the "source port" field of the tcp header. -func (b TCP) SourcePort() uint16 { - return binary.BigEndian.Uint16(b[srcPort:]) -} - -// DestinationPort returns the "destination port" field of the tcp header. -func (b TCP) DestinationPort() uint16 { - return binary.BigEndian.Uint16(b[dstPort:]) -} - -// SequenceNumber returns the "sequence number" field of the tcp header. -func (b TCP) SequenceNumber() uint32 { - return binary.BigEndian.Uint32(b[seqNum:]) -} - -// AckNumber returns the "ack number" field of the tcp header. -func (b TCP) AckNumber() uint32 { - return binary.BigEndian.Uint32(b[ackNum:]) -} - -// DataOffset returns the "data offset" field of the tcp header. -func (b TCP) DataOffset() uint8 { - return (b[dataOffset] >> 4) * 4 // 以32bits为单位 最小为5 20bytes -} - -// Payload returns the data in the tcp packet. -func (b TCP) Payload() []byte { - return b[b.DataOffset():] -} - -// TCPViewSize TCP报文概览长度 -const TCPViewSize = IPViewSize - TCPMinimumSize - -func (b TCP) viewPayload() []byte { - if len(b.Payload())-int(b.DataOffset()) < TCPViewSize { - return b.Payload() - } - return b[b.DataOffset():][:TCPViewSize] -} - -// Flags returns the flags field of the tcp header. -func (b TCP) Flags() uint8 { - return b[tcpFlags] -} - -// WindowSize returns the "window size" field of the tcp header. -func (b TCP) WindowSize() uint16 { - return binary.BigEndian.Uint16(b[winSize:]) -} - -// Checksum returns the "checksum" field of the tcp header. -func (b TCP) Checksum() uint16 { - return binary.BigEndian.Uint16(b[tcpChecksum:]) -} - -// UrgentPtr returns the "urgentptr" field of the tcp header. -func (b TCP) UrgentPtr() uint16 { - return binary.BigEndian.Uint16(b[urgentPtr:]) -} - -// SetSourcePort sets the "source port" field of the tcp header. -func (b TCP) SetSourcePort(port uint16) { - binary.BigEndian.PutUint16(b[srcPort:], port) -} - -// SetDestinationPort sets the "destination port" field of the tcp header. -func (b TCP) SetDestinationPort(port uint16) { - binary.BigEndian.PutUint16(b[dstPort:], port) -} - -// SetChecksum sets the checksum field of the tcp header. -func (b TCP) SetChecksum(checksum uint16) { - binary.BigEndian.PutUint16(b[tcpChecksum:], checksum) -} - -// CalculateChecksum calculates the checksum of the tcp segment given -// the totalLen and partialChecksum(descriptions below) -// totalLen is the total length of the segment -// partialChecksum is the checksum of the network-layer pseudo-header -// (excluding the total length) and the checksum of the segment data. -func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { - // Add the length portion of the checksum to the pseudo-checksum. - tmp := make([]byte, 2) - binary.BigEndian.PutUint16(tmp, totalLen) - checksum := Checksum(tmp, partialChecksum) - - // Calculate the rest of the checksum. - return Checksum(b[:b.DataOffset()], checksum) -} - -// Options returns a slice that holds the unparsed TCP options in the segment. -func (b TCP) Options() []byte { - return b[TCPMinimumSize:b.DataOffset()] -} - -func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { - binary.BigEndian.PutUint32(b[seqNum:], seq) - binary.BigEndian.PutUint32(b[ackNum:], ack) - b[tcpFlags] = flags - binary.BigEndian.PutUint16(b[winSize:], rcvwnd) -} - -// Encode encodes all the fields of the tcp header. -func (b TCP) Encode(t *TCPFields) { - b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) - binary.BigEndian.PutUint16(b[srcPort:], t.SrcPort) - binary.BigEndian.PutUint16(b[dstPort:], t.DstPort) - b[dataOffset] = (t.DataOffset / 4) << 4 - binary.BigEndian.PutUint16(b[tcpChecksum:], t.Checksum) - binary.BigEndian.PutUint16(b[urgentPtr:], t.UrgentPointer) -} - -// ParseTCPOptions extracts and stores all known options in the provided byte -// slice in a TCPOptions structure. -func ParseTCPOptions(b []byte) TCPOptions { - opts := TCPOptions{} - limit := len(b) - for i := 0; i < limit; { - switch b[i] { - case TCPOptionEOL: // 末尾 - i = limit - case TCPOptionNOP: // 空值 - i++ - case TCPOptionMSS: - if i+4 > limit || b[i+1] != 4 { - return opts - } - mss := uint16(b[i+2])<<8 | uint16(b[i+3]) - if mss == 0 { - return opts - } - opts.mss = mss - i += 4 - case TCPOptionWS: - if i+3 > limit || b[i+1] != 3 { - return opts - } - ws := int(b[i+2]) - if ws > MaxWndScale { - ws = MaxWndScale - } - opts.ws = ws - i += 3 - case TCPOptionTS: // 计时 - if i+10 > limit || (b[i+1] != 10) { - return opts - } - opts.TS = true - opts.TSVal = binary.BigEndian.Uint32(b[i+2:]) - opts.TSEcr = binary.BigEndian.Uint32(b[i+6:]) - i += 10 - case TCPOptionSACKPermitted: - if i+2 > limit || b[i+1] != 2 { - return opts - } - opts.sackPermitted = true - i += 2 - case TCPOptionSACK: - if i+2 > limit { - // Malformed SACK block, just return and stop parsing. - return opts - } - sackOptionLen := int(b[i+1]) - numBlocks := (sackOptionLen - 2) / 8 // 去头 每个block长为8 - opts.SACKBlocks = []SACKBlock{} - for j := 0; j < numBlocks; j++ { - start := binary.BigEndian.Uint32(b[i+2+j*8:]) - end := binary.BigEndian.Uint32(b[i+2+j*8+4:]) - opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{ - Start: seqnum.Value(start), - End: seqnum.Value(end), - }) - } - - i += sackOptionLen - default: - // 这里不做进一步解析 留到后面进行 - if i+2 > limit { - return opts - } - l := int(b[i+1]) - // If the length is incorrect or if l+i overflows the - // total options length then return false. - if l < 2 || i+l > limit { - return opts - } - i += l - } - } - - return opts -} - -func (opts TCPOptions) String() string { - return fmt.Sprintf("|MSS|% 29d|\n|WS |% 29d|\n|TS |% 29v|\n|TSV|% 29d|\n|TSE|% 29d|\n|SP |% 29v|\n|SBS|%v|", - opts.mss, opts.ws, opts.TS, opts.TSVal, opts.TSEcr, opts.sackPermitted, opts.SACKBlocks) -} - -// ParseSynOptions parses the options received in a SYN segment and returns the -// relevant ones. opts should point to the option part of the TCP Header. -func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { - synOpts := TCPSynOptions{ - // Per RFC 1122, page 85: "If an MSS option is not received at - // connection setup, TCP MUST assume a default send MSS of 536." - MSS: 536, - // If no window scale option is specified, WS in options is - // returned as -1; this is because the absence of the option - // indicates that the we cannot use window scaling on the - // receive end either. - WS: -1, - } - - limit := len(opts) - for i := 0; i < limit; { - switch opts[i] { - case TCPOptionEOL: - i = limit - case TCPOptionNOP: - i++ - case TCPOptionMSS: - if i+4 > limit || opts[i+1] != 4 { - return synOpts - } - mss := uint16(opts[i+2])<<8 | uint16(opts[i+3]) - if mss == 0 { - return synOpts - } - synOpts.MSS = mss - i += 4 - case TCPOptionWS: - if i+3 > limit || opts[i+1] != 3 { - return synOpts - } - ws := int(opts[i+2]) - if ws > MaxWndScale { - ws = MaxWndScale - } - synOpts.WS = ws - i += 3 - case TCPOptionTS: - if i+10 > limit || opts[i+1] != 10 { - return synOpts - } - synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:]) - if isAck { // ACK报文需要记录时间间隔 - // If the segment is a SYN-ACK then store the Timestamp Echo Reply - // in the segment. - synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:]) - } - synOpts.TS = true - i += 10 - case TCPOptionSACKPermitted: - if i+2 > limit || opts[i+1] != 2 { - return synOpts - } - synOpts.SACKPermitted = true - i += 2 - default: - // We don't recognize this option, just skip over it. - if i+2 > limit { - return synOpts - } - l := int(opts[i+1]) - // If the length is incorrect or if l+i overflows the - // total options length then return false. - if l < 2 || i+l > limit { - return synOpts - } - i += l - } - } - return synOpts -} - -func (opts TCPSynOptions) String() string { - return fmt.Sprintf("|%d|%d|%v|%d|%d|%v|", opts.MSS, opts.WS, opts.TS, opts.TSVal, opts.TSEcr, opts.SACKPermitted) -} - -// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in -// the supplied buffer. If the provided buffer is not large enough then it just -// returns without encoding anything. It returns the number of bytes written to -// the provided buffer. -func EncodeMSSOption(mss uint32, b []byte) int { - // mssOptionSize is the number of bytes in a valid MSS option. - const mssOptionSize = 4 - - if len(b) < mssOptionSize { - return 0 - } - b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) - return mssOptionSize -} - -// EncodeWSOption encodes the WS TCP option with the WS value in the -// provided buffer. If the provided buffer is not large enough then it just -// returns without encoding anything. It returns the number of bytes written to -// the provided buffer. -func EncodeWSOption(ws int, b []byte) int { - if len(b) < 3 { - return 0 - } - b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) - return int(b[1]) -} - -// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp -// option into the provided buffer. If the buffer is smaller than expected it -// just returns without encoding anything. It returns the number of bytes -// written to the provided buffer. -func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { - if len(b) < 10 { - return 0 - } - b[0], b[1] = TCPOptionTS, 10 - binary.BigEndian.PutUint32(b[2:], tsVal) - binary.BigEndian.PutUint32(b[6:], tsEcr) - return int(b[1]) -} - -// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided -// buffer. If the buffer is smaller than required it just returns without -// encoding anything. It returns the number of bytes written to the provided -// buffer. -func EncodeSACKPermittedOption(b []byte) int { - if len(b) < 2 { - return 0 - } - - b[0], b[1] = TCPOptionSACKPermitted, 2 - return int(b[1]) -} - -// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block -// in the provided slice. It tries to fit in as many blocks as possible based on -// number of bytes available in the provided buffer. It returns the number of -// bytes written to the provided buffer. -func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int { - if len(sackBlocks) == 0 { - return 0 - } - l := len(sackBlocks) - if l > TCPMaxSACKBlocks { - l = TCPMaxSACKBlocks - } - if ll := (len(b) - 2) / 8; ll < l { - l = ll - } - if l == 0 { - // There is not enough space in the provided buffer to add - // any SACK blocks. - return 0 - } - b[0] = TCPOptionSACK - b[1] = byte(l*8 + 2) - for i := 0; i < l; i++ { - binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start)) - binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End)) - } - return int(b[1]) -} - -// EncodeNOP adds an explicit NOP to the option list. -func EncodeNOP(b []byte) int { - if len(b) == 0 { - return 0 - } - b[0] = TCPOptionNOP - return 1 -} - -// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align -// the option buffer. It adds padding bytes after the offset specified and -// returns the number of padding bytes added. The passed in options slice -// must have space for the padding bytes. -func AddTCPOptionPadding(options []byte, offset int) int { - paddingToAdd := -offset & 3 - // Now add any padding bytes that might be required to quad align the - // options. - for i := offset; i < offset+paddingToAdd; i++ { - options[i] = TCPOptionNOP - } - return paddingToAdd -} - -/* - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Source Port | Destination Port | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Sequence Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Acknowledgment Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Data | |U|A|P|R|S|F| | -| Offset| Reserved |R|C|S|S|Y|I| Window | -| | |G|K|H|T|N|N| | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Checksum | Urgent Pointer | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Options | Padding | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| data | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -*/ - -var tcpFmt string = ` -|% 16s|% 16s| -|% 32s | -|% 32s | -|% 4s|% 4s|%06b|% 16s| -|% 16s|% 16s| - ---------------------------------` - -func (b TCP) String() string { - return fmt.Sprintf(tcpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), - atoi(b.SequenceNumber()), - atoi(b.AckNumber()), - atoi(b.DataOffset()), "0", b.Flags(), atoi(b.WindowSize()), - atoi(b.Checksum()), atoi(b.UrgentPtr())) -} +package header + +import ( + "encoding/binary" + "fmt" + "netstack/tcpip" + "netstack/tcpip/seqnum" +) + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Source Port | Destination Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgment Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Data | |U|A|P|R|S|F| | +| Offset| Reserved |R|C|S|S|Y|I| Window | +| | |G|K|H|T|N|N| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Checksum | Urgent Pointer | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Options | Padding | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| data | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +// TCPFields contains the fields of a TCP packet. It is used to describe the +// fields of a packet that needs to be encoded. +// tcp首部字段 +type TCPFields struct { + // SrcPort is the "source port" field of a TCP packet. + SrcPort uint16 + + // DstPort is the "destination port" field of a TCP packet. + DstPort uint16 + + // SeqNum is the "sequence number" field of a TCP packet. + // TCP的初始序列号ISN是随机生成的 + // 如果TCP每次连接都使用固定ISN,黑客可以很方便模拟任何IP与server建立连接 + // 如果ISN是固定的,那很可能在新连接建立后,上次连接通信的报文才到达, + // 这种情况有概率发生老报文的seq号正好是server希望收到的新连接的报文seq。这就全乱了。 + SeqNum uint32 + + // AckNum is the "acknowledgement number" field of a TCP packet. + AckNum uint32 + + // DataOffset is the "data offset" field of a TCP packet. + DataOffset uint8 + + // Flags is the "flags" field of a TCP packet. + Flags uint8 + + // WindowSize is the "window size" field of a TCP packet. + WindowSize uint16 + + // Checksum is the "checksum" field of a TCP packet. + Checksum uint16 + + // UrgentPointer is the "urgent pointer" field of a TCP packet. + UrgentPointer uint16 +} + +// TCPSynOptions is used to return the parsed TCP Options in a syn +// segment. +// syn 报文的选项 +type TCPSynOptions struct { + // MSS is the maximum segment size provided by the peer in the SYN. + MSS uint16 + + // WS is the window scale option provided by the peer in the SYN. + // + // Set to -1 if no window scale option was provided. + WS int + + // TS is true if the timestamp option was provided in the syn/syn-ack. + TS bool + + // TSVal is the value of the TSVal field in the timestamp option. + TSVal uint32 + + // TSEcr is the value of the TSEcr field in the timestamp option. + TSEcr uint32 + + // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK. + SACKPermitted bool +} + +const ( + srcPort = 0 + dstPort = 2 + seqNum = 4 + ackNum = 8 + dataOffset = 12 + tcpFlags = 13 + winSize = 14 + tcpChecksum = 16 + urgentPtr = 18 +) + +// Options that may be present in a TCP segment. +const ( + // 选项表结束选项 + TCPOptionEOL = 0 + // 空操作(nop)选项 + TCPOptionNOP = 1 + // 最大报文段长度选项 + TCPOptionMSS = 2 + // 窗口扩大因子选项 + TCPOptionWS = 3 + // 时间戳选项 + TCPOptionTS = 8 + // 选择性确认(Selective Acknowledgment,SACK)选项 + TCPOptionSACKPermitted = 4 + // SACK 实际工作的选项 + TCPOptionSACK = 5 +) + +const ( + // MaxWndScale is maximum allowed window scaling, as described in + // RFC 1323, section 2.3, page 11. + MaxWndScale = 14 + + // TCPMaxSACKBlocks is the maximum number of SACK blocks that can + // be encoded in a TCP option field. + TCPMaxSACKBlocks = 4 +) + +// SACKBlock 表示 sack 块的结构体 +type SACKBlock struct { + // Start indicates the lowest sequence number in the block. + Start seqnum.Value + + // End indicates the sequence number immediately following the last + // sequence number of this block. + End seqnum.Value +} + +/* + 1byte 1byte nbytes ++--------+--------+------------------+ +| Kind | Length | Info | ++--------+--------+------------------+ +*/ + +// TCPOptions tcp选项结构,这个结构不表示 syn/syn-ack 报文 +type TCPOptions struct { + // TS is true if the TimeStamp option is enabled. + TS bool + + // TSVal is the value in the TSVal field of the segment. + TSVal uint32 + + // TSEcr is the value in the TSEcr field of the segment. + TSEcr uint32 + + // SACKBlocks are the SACK blocks specified in the segment. + SACKBlocks []SACKBlock + + // 以下仅供测试之用 不对外暴露 + + // MSS is the maximum segment size provided by the peer in the SYN. + mss uint16 + + // WS is the window scale option provided by the peer in the SYN. + // + // Set to -1 if no window scale option was provided. + ws int + + // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK. + sackPermitted bool +} + +// TCP represents a TCP header stored in a byte array. +type TCP []byte + +const ( + // TCPMinimumSize is the minimum size of a valid TCP packet. + TCPMinimumSize = 20 + + // TCPProtocolNumber is TCP's transport protocol number. + TCPProtocolNumber tcpip.TransportProtocolNumber = 6 +) + +// SourcePort returns the "source port" field of the tcp header. +func (b TCP) SourcePort() uint16 { + return binary.BigEndian.Uint16(b[srcPort:]) +} + +// DestinationPort returns the "destination port" field of the tcp header. +func (b TCP) DestinationPort() uint16 { + return binary.BigEndian.Uint16(b[dstPort:]) +} + +// SequenceNumber returns the "sequence number" field of the tcp header. +func (b TCP) SequenceNumber() uint32 { + return binary.BigEndian.Uint32(b[seqNum:]) +} + +// AckNumber returns the "ack number" field of the tcp header. +func (b TCP) AckNumber() uint32 { + return binary.BigEndian.Uint32(b[ackNum:]) +} + +// DataOffset returns the "data offset" field of the tcp header. +func (b TCP) DataOffset() uint8 { + return (b[dataOffset] >> 4) * 4 // 以32bits为单位 最小为5 20bytes +} + +// Payload returns the data in the tcp packet. +func (b TCP) Payload() []byte { + return b[b.DataOffset():] +} + +// TCPViewSize TCP报文概览长度 +const TCPViewSize = IPViewSize - TCPMinimumSize + +func (b TCP) viewPayload() []byte { + if len(b.Payload())-int(b.DataOffset()) < TCPViewSize { + return b.Payload() + } + return b[b.DataOffset():][:TCPViewSize] +} + +// Flags returns the flags field of the tcp header. +func (b TCP) Flags() uint8 { + return b[tcpFlags] +} + +// WindowSize returns the "window size" field of the tcp header. +func (b TCP) WindowSize() uint16 { + return binary.BigEndian.Uint16(b[winSize:]) +} + +// Checksum returns the "checksum" field of the tcp header. +func (b TCP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[tcpChecksum:]) +} + +// UrgentPtr returns the "urgentptr" field of the tcp header. +func (b TCP) UrgentPtr() uint16 { + return binary.BigEndian.Uint16(b[urgentPtr:]) +} + +// SetSourcePort sets the "source port" field of the tcp header. +func (b TCP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[srcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the tcp header. +func (b TCP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[dstPort:], port) +} + +// SetChecksum sets the checksum field of the tcp header. +func (b TCP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[tcpChecksum:], checksum) +} + +// CalculateChecksum calculates the checksum of the tcp segment given +// the totalLen and partialChecksum(descriptions below) +// totalLen is the total length of the segment +// partialChecksum is the checksum of the network-layer pseudo-header +// (excluding the total length) and the checksum of the segment data. +func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { + // Add the length portion of the checksum to the pseudo-checksum. + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + checksum := Checksum(tmp, partialChecksum) + + // Calculate the rest of the checksum. + return Checksum(b[:b.DataOffset()], checksum) +} + +// Options returns a slice that holds the unparsed TCP options in the segment. +func (b TCP) Options() []byte { + return b[TCPMinimumSize:b.DataOffset()] +} + +func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) { + binary.BigEndian.PutUint32(b[seqNum:], seq) + binary.BigEndian.PutUint32(b[ackNum:], ack) + b[tcpFlags] = flags + binary.BigEndian.PutUint16(b[winSize:], rcvwnd) +} + +// Encode encodes all the fields of the tcp header. +func (b TCP) Encode(t *TCPFields) { + b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize) + binary.BigEndian.PutUint16(b[srcPort:], t.SrcPort) + binary.BigEndian.PutUint16(b[dstPort:], t.DstPort) + b[dataOffset] = (t.DataOffset / 4) << 4 + binary.BigEndian.PutUint16(b[tcpChecksum:], t.Checksum) + binary.BigEndian.PutUint16(b[urgentPtr:], t.UrgentPointer) +} + +// ParseTCPOptions extracts and stores all known options in the provided byte +// slice in a TCPOptions structure. +func ParseTCPOptions(b []byte) TCPOptions { + opts := TCPOptions{} + limit := len(b) + for i := 0; i < limit; { + switch b[i] { + case TCPOptionEOL: // 末尾 + i = limit + case TCPOptionNOP: // 空值 + i++ + case TCPOptionMSS: + if i+4 > limit || b[i+1] != 4 { + return opts + } + mss := uint16(b[i+2])<<8 | uint16(b[i+3]) + if mss == 0 { + return opts + } + opts.mss = mss + i += 4 + case TCPOptionWS: + if i+3 > limit || b[i+1] != 3 { + return opts + } + ws := int(b[i+2]) + if ws > MaxWndScale { + ws = MaxWndScale + } + opts.ws = ws + i += 3 + case TCPOptionTS: // 计时 + if i+10 > limit || (b[i+1] != 10) { + return opts + } + opts.TS = true + opts.TSVal = binary.BigEndian.Uint32(b[i+2:]) + opts.TSEcr = binary.BigEndian.Uint32(b[i+6:]) + i += 10 + case TCPOptionSACKPermitted: + if i+2 > limit || b[i+1] != 2 { + return opts + } + opts.sackPermitted = true + i += 2 + case TCPOptionSACK: + if i+2 > limit { + // Malformed SACK block, just return and stop parsing. + return opts + } + sackOptionLen := int(b[i+1]) + numBlocks := (sackOptionLen - 2) / 8 // 去头 每个block长为8 + opts.SACKBlocks = []SACKBlock{} + for j := 0; j < numBlocks; j++ { + start := binary.BigEndian.Uint32(b[i+2+j*8:]) + end := binary.BigEndian.Uint32(b[i+2+j*8+4:]) + opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{ + Start: seqnum.Value(start), + End: seqnum.Value(end), + }) + } + + i += sackOptionLen + default: + // 这里不做进一步解析 留到后面进行 + if i+2 > limit { + return opts + } + l := int(b[i+1]) + // If the length is incorrect or if l+i overflows the + // total options length then return false. + if l < 2 || i+l > limit { + return opts + } + i += l + } + } + + return opts +} + +func (opts TCPOptions) String() string { + return fmt.Sprintf("|MSS|% 29d|\n|WS |% 29d|\n|TS |% 29v|\n|TSV|% 29d|\n|TSE|% 29d|\n|SP |% 29v|\n|SBS|%v|", + opts.mss, opts.ws, opts.TS, opts.TSVal, opts.TSEcr, opts.sackPermitted, opts.SACKBlocks) +} + +// ParseSynOptions parses the options received in a SYN segment and returns the +// relevant ones. opts should point to the option part of the TCP Header. +func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { + synOpts := TCPSynOptions{ + // Per RFC 1122, page 85: "If an MSS option is not received at + // connection setup, TCP MUST assume a default send MSS of 536." + MSS: 536, + // If no window scale option is specified, WS in options is + // returned as -1; this is because the absence of the option + // indicates that the we cannot use window scaling on the + // receive end either. + WS: -1, + } + + limit := len(opts) + for i := 0; i < limit; { + switch opts[i] { + case TCPOptionEOL: + i = limit + case TCPOptionNOP: + i++ + case TCPOptionMSS: + if i+4 > limit || opts[i+1] != 4 { + return synOpts + } + mss := uint16(opts[i+2])<<8 | uint16(opts[i+3]) + if mss == 0 { + return synOpts + } + synOpts.MSS = mss + i += 4 + case TCPOptionWS: + if i+3 > limit || opts[i+1] != 3 { + return synOpts + } + ws := int(opts[i+2]) + if ws > MaxWndScale { + ws = MaxWndScale + } + synOpts.WS = ws + i += 3 + case TCPOptionTS: + if i+10 > limit || opts[i+1] != 10 { + return synOpts + } + synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:]) + if isAck { // ACK报文需要记录时间间隔 + // If the segment is a SYN-ACK then store the Timestamp Echo Reply + // in the segment. + synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:]) + } + synOpts.TS = true + i += 10 + case TCPOptionSACKPermitted: + if i+2 > limit || opts[i+1] != 2 { + return synOpts + } + synOpts.SACKPermitted = true + i += 2 + default: + // We don't recognize this option, just skip over it. + if i+2 > limit { + return synOpts + } + l := int(opts[i+1]) + // If the length is incorrect or if l+i overflows the + // total options length then return false. + if l < 2 || i+l > limit { + return synOpts + } + i += l + } + } + return synOpts +} + +func (opts TCPSynOptions) String() string { + return fmt.Sprintf("|%d|%d|%v|%d|%d|%v|", opts.MSS, opts.WS, opts.TS, opts.TSVal, opts.TSEcr, opts.SACKPermitted) +} + +// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in +// the supplied buffer. If the provided buffer is not large enough then it just +// returns without encoding anything. It returns the number of bytes written to +// the provided buffer. +func EncodeMSSOption(mss uint32, b []byte) int { + // mssOptionSize is the number of bytes in a valid MSS option. + const mssOptionSize = 4 + + if len(b) < mssOptionSize { + return 0 + } + b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss) + return mssOptionSize +} + +// EncodeWSOption encodes the WS TCP option with the WS value in the +// provided buffer. If the provided buffer is not large enough then it just +// returns without encoding anything. It returns the number of bytes written to +// the provided buffer. +func EncodeWSOption(ws int, b []byte) int { + if len(b) < 3 { + return 0 + } + b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws) + return int(b[1]) +} + +// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp +// option into the provided buffer. If the buffer is smaller than expected it +// just returns without encoding anything. It returns the number of bytes +// written to the provided buffer. +func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int { + if len(b) < 10 { + return 0 + } + b[0], b[1] = TCPOptionTS, 10 + binary.BigEndian.PutUint32(b[2:], tsVal) + binary.BigEndian.PutUint32(b[6:], tsEcr) + return int(b[1]) +} + +// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided +// buffer. If the buffer is smaller than required it just returns without +// encoding anything. It returns the number of bytes written to the provided +// buffer. +func EncodeSACKPermittedOption(b []byte) int { + if len(b) < 2 { + return 0 + } + + b[0], b[1] = TCPOptionSACKPermitted, 2 + return int(b[1]) +} + +// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block +// in the provided slice. It tries to fit in as many blocks as possible based on +// number of bytes available in the provided buffer. It returns the number of +// bytes written to the provided buffer. +func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int { + if len(sackBlocks) == 0 { + return 0 + } + l := len(sackBlocks) + if l > TCPMaxSACKBlocks { + l = TCPMaxSACKBlocks + } + if ll := (len(b) - 2) / 8; ll < l { + l = ll + } + if l == 0 { + // There is not enough space in the provided buffer to add + // any SACK blocks. + return 0 + } + b[0] = TCPOptionSACK + b[1] = byte(l*8 + 2) + for i := 0; i < l; i++ { + binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start)) + binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End)) + } + return int(b[1]) +} + +// EncodeNOP adds an explicit NOP to the option list. +func EncodeNOP(b []byte) int { + if len(b) == 0 { + return 0 + } + b[0] = TCPOptionNOP + return 1 +} + +// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align +// the option buffer. It adds padding bytes after the offset specified and +// returns the number of padding bytes added. The passed in options slice +// must have space for the padding bytes. +func AddTCPOptionPadding(options []byte, offset int) int { + paddingToAdd := -offset & 3 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + options[i] = TCPOptionNOP + } + return paddingToAdd +} + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Source Port | Destination Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgment Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Data | |U|A|P|R|S|F| | +| Offset| Reserved |R|C|S|S|Y|I| Window | +| | |G|K|H|T|N|N| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Checksum | Urgent Pointer | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Options | Padding | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| data | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +var tcpFmt string = ` +|% 16s|% 16s| +|% 32s | +|% 32s | +|% 4s|% 4s|%06b|% 16s| +|% 16s|% 16s| + ---------------------------------` + +func (b TCP) String() string { + return fmt.Sprintf(tcpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), + atoi(b.SequenceNumber()), + atoi(b.AckNumber()), + atoi(b.DataOffset()), "0", b.Flags(), atoi(b.WindowSize()), + atoi(b.Checksum()), atoi(b.UrgentPtr())) +} diff --git a/tcpip/seqnum/seqnum.go b/tcpip/seqnum/seqnum.go index 1e968e8..7aee5f9 100644 --- a/tcpip/seqnum/seqnum.go +++ b/tcpip/seqnum/seqnum.go @@ -1,50 +1,50 @@ -package seqnum - -// Value represents the value of a sequence number. -type Value uint32 - -// Size represents the size (length) of a sequence number window -type Size uint32 - -// LessThan v < w -func (v Value) LessThan(w Value) bool { - return int32(v-w) < 0 -} - -// LessThanEq returns true if v==w or v is before i.e., v < w. -func (v Value) LessThanEq(w Value) bool { - if v == w { - return true - } - return v.LessThan(w) -} - -// InRange v ∈ [a, b) -func (v Value) InRange(a, b Value) bool { - return a <= v && v < b -} - -// InWindows check v in [first, first+size) -func (v Value) InWindows(first Value, size Size) bool { - return v.InRange(first, first.Add(size)) -} - -// Add return v + s -func (v Value) Add(s Size) Value { - return v + Value(s) -} - -// Size return the size of [v, w) -func (v Value) Size(w Value) Size { - return Size(w - v) -} - -// UpdateForward update the value to v+s -func (v *Value) UpdateForward(s Size) { - *v += Value(s) -} - -// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). -func Overlap(a Value, b Size, x Value, y Size) bool { - return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) -} +package seqnum + +// Value represents the value of a sequence number. +type Value uint32 + +// Size represents the size (length) of a sequence number window +type Size uint32 + +// LessThan v < w +func (v Value) LessThan(w Value) bool { + return int32(v-w) < 0 +} + +// LessThanEq returns true if v==w or v is before i.e., v < w. +func (v Value) LessThanEq(w Value) bool { + if v == w { + return true + } + return v.LessThan(w) +} + +// InRange v ∈ [a, b) +func (v Value) InRange(a, b Value) bool { + return a <= v && v < b +} + +// InWindows check v in [first, first+size) +func (v Value) InWindows(first Value, size Size) bool { + return v.InRange(first, first.Add(size)) +} + +// Add return v + s +func (v Value) Add(s Size) Value { + return v + Value(s) +} + +// Size return the size of [v, w) +func (v Value) Size(w Value) Size { + return Size(w - v) +} + +// UpdateForward update the value to v+s +func (v *Value) UpdateForward(s Size) { + *v += Value(s) +} + +// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). +func Overlap(a Value, b Size, x Value, y Size) bool { + return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) +} diff --git a/tcpip/transport/tcp/README.md b/tcpip/transport/tcp/README.md index 25fc936..a7a5751 100644 --- a/tcpip/transport/tcp/README.md +++ b/tcpip/transport/tcp/README.md @@ -1,77 +1,77 @@ -# TCP 协议 - -## tcp特点 -1. tcp 是面向连接的传输协议。 -2. tcp 的连接是端到端的。 -3. tcp 提供可靠的传输。 -4. tcp 的传输以字节流的方式。 -5. tcp 提供全双工的通信。 -6. tcp 有拥塞控制。 - -![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png) - -``` sh - 0 1 2 3 - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Source Port | Destination Port | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Sequence Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Acknowledgment Number | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Data | |U|A|P|R|S|F| | -| Offset| Reserved |R|C|S|S|Y|I| Window | -| | |G|K|H|T|N|N| | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Checksum | Urgent Pointer | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| Options | Padding | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -| data | -+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -``` - -1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。 -2. 序号 占 4 字节,序号范围是[0,2^32 - 1],共 2^32(即 4294967296)个序号。序号增加到 2^32-1 后,下一个序号就又回到 0。TCP 是面向字节流的,在一个 TCP 连接中传送的字节流中的每一个字节都按顺序编号。整个要传送的字节流的起始序号(ISN)必须在连接建立时设置。首部中的序号字段值则是指的是本报文段所发送的数据的第一个字节的序号。例如,一报文段的序号是 301,而接待的数据共有 100 字节。这就表明:本报文段的数据的第一个字节的序号是 301,最后一个字节的序号是 400。显然,下一个报文段(如果还有的话)的数据序号应当从 401 开始,即下一个报文段的序号字段值应为 401。 -3. 确认号 占 4 字节,是期望收到对方下一个报文段的第一个数据字节的序号。例如,B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501,而数据长度是 200 字节(序号 501~700),这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此,B 期望收到 A 的下一个数据序号是 701,于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意,现在确认号不是 501,也不是 700,而是 701。 总之:若确认号为 N,则表明:到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。 -4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15,因此数据偏移的最大值是 60 字节。 -5. 保留 占 6 位,保留为今后使用,但目前应置为 0。 -6. 控制报文标志 - - **紧急URG(URGent)** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面,而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针(Urgent Pointer)字段配合使用。 - - **确认ACK(ACKnowledgment)** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。 - - **推送 PSH(PuSH)** 当两个应用进程进行交互式的通信时,有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下,TCP 就可以使用推送(push)操作。这时,发送方 TCP 把 PSH 置为 1,并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。 - - **复位RST(ReSeT)** 当 RST=1 时,表名 TCP 连接中出现了严重错误(如由于主机崩溃或其他原因),必须释放连接,然后再重新建立传输连接。RST 置为 1 用来拒绝一个非法的报文段或拒绝打开一个连接。 - - **同步SYN(SYNchronization)** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1,因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。 - - **终止FIN(FINis,意思是“完”“终”)** 用来释放一个连接。当 FIN=1 时,表明此报文段的发送发的数据已发送完毕,并要求释放运输连接。 -7. 窗口 占 2 字节,窗口值是[0,2^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。 -8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6(TCP 的协议号是 6);把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6,则相应的伪首部也要改变。 -9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此,在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时,TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。 -10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时,TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。 - -## tcp选项 - -TCP 最初只规定了一种选项,即最大报文段长度 MSS(Maximum Segment Szie)。后来又增加了几个选项如窗口扩大选项、时间戳选项等,下面说明常用的选项。 - -1. kind=0 是选项表结束选项。 - -2. kind=1 是空操作(nop)选项 - - 没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。 - -3. kind=2 是最大报文段长度选项 - TCP 连接初始化时,通信双方使用该选项来协商最大报文段长度(Max Segment Size,MSS)。TCP 模块通常将 MSS 设置为(MTU-40)字节(减掉的这 40 字节包括 20 字节的 TCP 头部和 20 字节的 IP 头部)。这样携带 TCP 报文段的 IP 数据报的长度就不会超过 MTU(假设 TCP 头部和 IP 头部都不包含选项字段,并且这也是一般情况),从而避免本机发生 IP 分片。对以太网而言,MSS 值是 1460(1500-40)字节。 - -4. kind=3 是窗口扩大因子选项 - TCP 连接初始化时,通信双方使用该选项来协商接收通告窗口的扩大因子。在 TCP 的头部中,接收通告窗口大小是用 16 位表示的,故最大为 65535 字节,但实际上 TCP 模块允许的接收通告窗口大小远不止这个数(为了提高 TCP 通信的吞吐量)。窗口扩大因子解决了这个问题。假设 TCP 头部中的接收通告窗口大小是 N,窗口扩大因子(移位数)是 M,那么 TCP 报文段的实际接收通告窗口大小是 N 乘 2M,或者说 N 左移 M 位。注意,M 的取值范围是 0 ~ 14。 - - 和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。 - -5. kind=4 是选择性确认(Selective Acknowledgment,SACK)选项 - TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。 - -6. kind=5 是 SACK 实际工作的选项 - 该选项的参数告诉发送方本端已经收到并缓存的不连续的数据块,从而让发送端可以据此检查并重发丢失的数据块。每个块边沿(edge of block)参数包含一个 4 字节的序号。其中块左边沿表示不连续块的第一个数据的序号,而块右边沿则表示不连续块的最后一个数据的序号的下一个序号。这样一对参数(块左边沿和块右边沿)之间的数据是没有收到的。因为一个块信息占用 8 字节,所以 TCP 头部选项中实际上最多可以包含 4 个这样的不连续数据块(考虑选项类型和长度占用的 2 字节)。 - -7. kind=8 是时间戳选项 - 该选项提供了较为准确的计算通信双方之间的回路时间(Round Trip Time,RTT)的方法,从而为 TCP 流量控制提供重要信息。 +# TCP 协议 + +## tcp特点 +1. tcp 是面向连接的传输协议。 +2. tcp 的连接是端到端的。 +3. tcp 提供可靠的传输。 +4. tcp 的传输以字节流的方式。 +5. tcp 提供全双工的通信。 +6. tcp 有拥塞控制。 + +![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555573949562.png) + +``` sh + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Source Port | Destination Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgment Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Data | |U|A|P|R|S|F| | +| Offset| Reserved |R|C|S|S|Y|I| Window | +| | |G|K|H|T|N|N| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Checksum | Urgent Pointer | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Options | Padding | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| data | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +``` + +1. 源端口和目的端口 各占 2 个字节,分别 tcp 连接的源端口和目的端口。关于端口的概念之前已经介绍过了。 +2. 序号 占 4 字节,序号范围是[0,2^32 - 1],共 2^32(即 4294967296)个序号。序号增加到 2^32-1 后,下一个序号就又回到 0。TCP 是面向字节流的,在一个 TCP 连接中传送的字节流中的每一个字节都按顺序编号。整个要传送的字节流的起始序号(ISN)必须在连接建立时设置。首部中的序号字段值则是指的是本报文段所发送的数据的第一个字节的序号。例如,一报文段的序号是 301,而接待的数据共有 100 字节。这就表明:本报文段的数据的第一个字节的序号是 301,最后一个字节的序号是 400。显然,下一个报文段(如果还有的话)的数据序号应当从 401 开始,即下一个报文段的序号字段值应为 401。 +3. 确认号 占 4 字节,是期望收到对方下一个报文段的第一个数据字节的序号。例如,B 正确收到了 A 发送过来的一个报文段,其序号字段值是 501,而数据长度是 200 字节(序号 501~700),这表明 B 正确收到了 A 发送的到序号 700 为止的数据。因此,B 期望收到 A 的下一个数据序号是 701,于是 B 在发送给 A 的确认报文段中把确认号置为 701。注意,现在确认号不是 501,也不是 700,而是 701。 总之:若确认号为 N,则表明:到序号 N-1 为止的所有数据都已正确收到。TCP 除了第一个 SYN 报文之外,所有 TCP 报文都需要携带 ACK 状态位。 +4. 数据偏移 占 4 位,它指出 TCP 报文段的数据起始处距离 TCP 报文段的起始处有多远。这个字段实际上是指出 TCP 报文段的首部长度。由于首部中还有长度不确定的选项字段,因此数据偏移字段是必要的,但应注意,“数据偏移”的单位是 4 个字节,由于 4 位二进制数能表示的最大十进制数字是 15,因此数据偏移的最大值是 60 字节。 +5. 保留 占 6 位,保留为今后使用,但目前应置为 0。 +6. 控制报文标志 + - **紧急URG(URGent)** 当 URG=1 时,表明紧急指针字段有效。它告诉系统此报文段中有紧急数据,应尽快发送(相当于高优先级的数据),而不要按原来的排队顺序来传送。例如,已经发送了很长的一个程序要在远地的主机上运行。但后来发现了一些问题,需要取消该程序的运行,因此用户从键盘发出中断命令。如果不使用紧急数据,那么这两个字符将存储在接收 TCP 的缓存末尾。只有在所有的数据被处理完毕后这两个字符才被交付接收方的应用进程。这样做就浪费了很多时间。 当 URG 置为 1 时,发送应用进程就告诉发送方的 TCP 有紧急数据要传送。于是发送方 TCP 就把紧急数据插入到本报文段数据的最前面,而在紧急数据后面的数据仍然是普通数据。这时要与首部中紧急指针(Urgent Pointer)字段配合使用。 + - **确认ACK(ACKnowledgment)** 仅当 ACK=1 时确认号字段才有效,当 ACK=0 时确认号无效。TCP 规定,在连接建立后所有的传送的报文段都必须把 ACK 置为 1。 + - **推送 PSH(PuSH)** 当两个应用进程进行交互式的通信时,有时在一端的应用进程希望在键入一个命令后立即就能收到对方的响应。在这种情况下,TCP 就可以使用推送(push)操作。这时,发送方 TCP 把 PSH 置为 1,并立即创建一个报文段发送出去。接收方 TCP 收到 PSH=1 的报文段,就尽快地交付接收应用进程。 + - **复位RST(ReSeT)** 当 RST=1 时,表名 TCP 连接中出现了严重错误(如由于主机崩溃或其他原因),必须释放连接,然后再重新建立传输连接。RST 置为 1 用来拒绝一个非法的报文段或拒绝打开一个连接。 + - **同步SYN(SYNchronization)** 在连接建立时用来同步序号。当 SYN=1 而 ACK=0 时,表明这是一个连接请求报文段。对方若同意建立连接,则应在响应的报文段中使 SYN=1 和 ACK=1,因此 SYN 置为 1 就表示这是一个连接请求或连接接受报文。 + - **终止FIN(FINis,意思是“完”“终”)** 用来释放一个连接。当 FIN=1 时,表明此报文段的发送发的数据已发送完毕,并要求释放运输连接。 +7. 窗口 占 2 字节,窗口值是[0,2^16-1]之间的整数。窗口指的是发送本报文段的一方的接受窗口(而不是自己的发送窗口)。窗口值告诉对方:从本报文段首部中的确认号算起,接收方目前允许对方发送的数据量(以字节为单位)。之所以要有这个限制,是因为接收方的数据缓存空间是有限的。总之,窗口值作为接收方让发送方设置其发送窗口的依据,作为流量控制的依据,后面会详细介绍。 总之:窗口字段明确指出了现在允许对方发送的数据量。窗口值经常在动态变化。 +8. 检验和 占 2 字节,检验和字段检验的范围包括首部和数据这两部分。和 UDP 用户数据报一样,在计算检验和时,要在 TCP 报文段的前面加上 12 字节的伪首部。伪首部的格式和 UDP 用户数据报的伪首部一样。但应把伪首部第 4 个字段中的 17 改为 6(TCP 的协议号是 6);把第 5 字段中的 UDP 中的长度改为 TCP 长度。接收方收到此报文段后,仍要加上这个伪首部来计算检验和。若使用 IPv6,则相应的伪首部也要改变。 +9. 紧急指针 占 2 字节,紧急指针仅在 URG=1 时才有意义,它指出本报文段中的紧急数据的字节数(紧急数据结束后就是普通数据) 。因此,在紧急指针指出了紧急数据的末尾在报文段中的位置。当所有紧急数据都处理完时,TCP 就告诉应用程序恢复到正常操作。值得注意的是,即使窗口为 0 时也可以发送紧急数据。 +10. 选项 选项长度可变,最长可达 40 字节。当没有使用“选项”时,TCP 的首部长度是 20 字节。TCP 首部总长度由 TCP 头中的“数据偏移”字段决定,前面说了,最长偏移为 60 字节。那么“tcp 选项”的长度最大为 60-20=40 字节。 + +## tcp选项 + +TCP 最初只规定了一种选项,即最大报文段长度 MSS(Maximum Segment Szie)。后来又增加了几个选项如窗口扩大选项、时间戳选项等,下面说明常用的选项。 + +1. kind=0 是选项表结束选项。 + +2. kind=1 是空操作(nop)选项 + + 没有特殊含义,一般用于将 TCP 选项的总长度填充为 4 字节的整数倍,为啥需要 4 字节整数倍?因为前面讲了数据偏移字段的单位是 4 个字节。 + +3. kind=2 是最大报文段长度选项 + TCP 连接初始化时,通信双方使用该选项来协商最大报文段长度(Max Segment Size,MSS)。TCP 模块通常将 MSS 设置为(MTU-40)字节(减掉的这 40 字节包括 20 字节的 TCP 头部和 20 字节的 IP 头部)。这样携带 TCP 报文段的 IP 数据报的长度就不会超过 MTU(假设 TCP 头部和 IP 头部都不包含选项字段,并且这也是一般情况),从而避免本机发生 IP 分片。对以太网而言,MSS 值是 1460(1500-40)字节。 + +4. kind=3 是窗口扩大因子选项 + TCP 连接初始化时,通信双方使用该选项来协商接收通告窗口的扩大因子。在 TCP 的头部中,接收通告窗口大小是用 16 位表示的,故最大为 65535 字节,但实际上 TCP 模块允许的接收通告窗口大小远不止这个数(为了提高 TCP 通信的吞吐量)。窗口扩大因子解决了这个问题。假设 TCP 头部中的接收通告窗口大小是 N,窗口扩大因子(移位数)是 M,那么 TCP 报文段的实际接收通告窗口大小是 N 乘 2M,或者说 N 左移 M 位。注意,M 的取值范围是 0 ~ 14。 + + 和 MSS 选项一样,窗口扩大因子选项只能出现在同步报文段中,否则将被忽略。但同步报文段本身不执行窗口扩大操作,即同步报文段头部的接收通告窗口大小就是该 TCP 报文段的实际接收通告窗口大小。当连接建立好之后,每个数据传输方向的窗口扩大因子就固定不变了。关于窗口扩大因子选项的细节,可参考标准文档 RFC 1323。 + +5. kind=4 是选择性确认(Selective Acknowledgment,SACK)选项 + TCP 通信时,如果某个 TCP 报文段丢失,则 TCP 模块会重传最后被确认的 TCP 报文段后续的所有报文段,这样原先已经正确传输的 TCP 报文段也可能重复发送,从而降低了 TCP 性能。SACK 技术正是为改善这种情况而产生的,它使 TCP 模块只重新发送丢失的 TCP 报文段,不用发送所有未被确认的 TCP 报文段。选择性确认选项用在连接初始化时,表示是否支持 SACK 技术。 + +6. kind=5 是 SACK 实际工作的选项 + 该选项的参数告诉发送方本端已经收到并缓存的不连续的数据块,从而让发送端可以据此检查并重发丢失的数据块。每个块边沿(edge of block)参数包含一个 4 字节的序号。其中块左边沿表示不连续块的第一个数据的序号,而块右边沿则表示不连续块的最后一个数据的序号的下一个序号。这样一对参数(块左边沿和块右边沿)之间的数据是没有收到的。因为一个块信息占用 8 字节,所以 TCP 头部选项中实际上最多可以包含 4 个这样的不连续数据块(考虑选项类型和长度占用的 2 字节)。 + +7. kind=8 是时间戳选项 + 该选项提供了较为准确的计算通信双方之间的回路时间(Round Trip Time,RTT)的方法,从而为 TCP 流量控制提供重要信息。 diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go index 4de6e2d..3640967 100644 --- a/tcpip/transport/tcp/accept.go +++ b/tcpip/transport/tcp/accept.go @@ -1,330 +1,330 @@ -package tcp - -import ( - "crypto/rand" - "crypto/sha1" - "encoding/binary" - "hash" - "io" - "log" - "netstack/sleep" - "netstack/tcpip" - "netstack/tcpip/header" - "netstack/tcpip/seqnum" - "netstack/tcpip/stack" - "sync" - "time" -) - -const ( - // tsLen is the length, in bits, of the timestamp in the SYN cookie. - tsLen = 8 - - // tsMask is a mask for timestamp values (i.e., tsLen bits). - tsMask = (1 << tsLen) - 1 - - // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. - tsOffset = 24 - - // hashMask is the mask for hash values (i.e., tsOffset bits). - hashMask = (1 << tsOffset) - 1 - - // maxTSDiff is the maximum allowed difference between a received cookie - // timestamp and the current timestamp. If the difference is greater - // than maxTSDiff, the cookie is expired. - maxTSDiff = 2 -) - -var ( - // SynRcvdCountThreshold is the global maximum number of connections - // that are allowed to be in SYN-RCVD state before TCP starts using SYN - // cookies to accept connections. - // - // It is an exported variable only for testing, and should not otherwise - // be used by importers of this package. - SynRcvdCountThreshold uint64 = 1000 - - // mssTable is a slice containing the possible MSS values that we - // encode in the SYN cookie with two bits. - mssTable = []uint16{536, 1300, 1440, 1460} -) - -func encodeMSS(mss uint16) uint32 { - for i := len(mssTable) - 1; i > 0; i-- { - if mss >= mssTable[i] { - return uint32(i) - } - } - return 0 -} - -// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is -// protected by a mutex so that we can increment only when it's guaranteed not -// to go above a threshold. -var synRcvdCount struct { - sync.Mutex - value uint64 - pending sync.WaitGroup -} - -// listenContext is used by a listening endpoint to store state used while -// listening for connections. This struct is allocated by the listen goroutine -// and must not be accessed or have its methods called concurrently as they -// may mutate the stored objects. -type listenContext struct { - stack *stack.Stack - rcvWnd seqnum.Size - nonce [2][sha1.BlockSize]byte // nonce 随机数 - - hasherMu sync.Mutex - hasher hash.Hash // 散列实现 - v6only bool - netProto tcpip.NetworkProtocolNumber -} - -// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. -func timeStamp() uint32 { - return uint32(time.Now().Unix()>>6) & tsMask // 00 00 00 FF -} - -// 增加一个任务 最多1000个 -func incSynRcvdCount() bool { - synRcvdCount.Mutex.Lock() - defer synRcvdCount.Unlock() - - if synRcvdCount.value >= SynRcvdCountThreshold { - return false - } - - synRcvdCount.pending.Add(1) - synRcvdCount.value++ - return true -} - -// 结束一个任务 -func decSynRcvdCount() { - synRcvdCount.Mutex.Lock() - defer synRcvdCount.Unlock() - synRcvdCount.value-- - synRcvdCount.pending.Done() -} - -// newListenContext creates a new listen context. -func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { - l := &listenContext{ - stack: stack, - rcvWnd: rcvWnd, - hasher: sha1.New(), - v6only: v6only, - netProto: netProto, - } - - rand.Read(l.nonce[0][:]) - rand.Read(l.nonce[1][:]) - - return l -} - -// cookieHash calculates the cookieHash for the given id, timestamp and nonce -// index. The hash is used to create and validate cookies. -func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { - - // Initialize block with fixed-size data: local ports and v. - var payload [8]byte - binary.BigEndian.PutUint16(payload[0:], id.LocalPort) - binary.BigEndian.PutUint16(payload[2:], id.RemotePort) - binary.BigEndian.PutUint32(payload[4:], ts) - - // Feed everything to the hasher. - l.hasherMu.Lock() - l.hasher.Reset() - l.hasher.Write(payload[:]) - l.hasher.Write(l.nonce[nonceIndex][:]) - io.WriteString(l.hasher, string(id.LocalAddress)) - io.WriteString(l.hasher, string(id.RemoteAddress)) - - // Finalize the calculation of the hash and return the first 4 bytes. - h := make([]byte, 0, sha1.Size) - h = l.hasher.Sum(h) - l.hasherMu.Unlock() - - return binary.BigEndian.Uint32(h[:]) -} - -// createCookie creates a SYN cookie for the given id and incoming sequence -// number. -func (l *listenContext) createCookie(id stack.TransportEndpointID, - seq seqnum.Value, data uint32) seqnum.Value { - ts := timeStamp() - v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) - v += (l.cookieHash(id, ts, 1) + data) & hashMask - return seqnum.Value(v) -} - -// isCookieValid checks if the supplied cookie is valid for the given id and -// sequence number. If it is, it also returns the data originally encoded in the -// cookie when createCookie was called. -func (l *listenContext) isCookieValid(id stack.TransportEndpointID, - cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { - ts := timeStamp() - v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) - cookieTS := v >> tsOffset - if ((ts - cookieTS) & tsMask) > maxTSDiff { - return 0, false - } - - return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true -} - -// 新建一个tcp端 这个tcp端与segment同属一个tcp连接 但属于不同阶段 用于写回远端 -func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, - irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { - // Create a new endpoint. - netProto := l.netProto - if netProto == 0 { - netProto = s.route.NetProto - } - n := newEndpoint(l.stack, netProto, nil) - n.v6only = l.v6only - n.id = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} - n.rcvBufSize = int(l.rcvWnd) - - n.maybeEnableTimestamp(rcvdSynOpts) - n.maybeEnableSACKPermitted(rcvdSynOpts) - - // Register new endpoint so that packets are routed to it. - // 在网络协议栈中去注册这个tcp端 - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, - n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { - n.Close() - return nil, err - } - - n.isRegistered = true - n.state = stateConnected - - // Create sender and receiver. - // The receiver at least temporarily has a zero receive window scale, - // but the caller may change it (before starting the protocol loop). - n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) - n.rcv = newReceiver(n, irs, l.rcvWnd, 0) - - return n, nil -} - -func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { - // create new endpoint - irs := s.sequenceNumber - cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) - log.Println("收到一个远端握手申请", irs, "标记cookie", cookie) - ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) - if err != nil { - return nil, err - } - - // 以下执行三次握手 - - // 构建handshake管理器 - h, err := newHandshake(ep, l.rcvWnd) - if err != nil { - ep.Close() - return nil, err - } - - // 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack - h.resetToSynRcvd(cookie, irs, opts) - if err := h.execute(); err != nil { - ep.Close() - return nil, err - } - - // 更新接收窗口扩张因子 - - return ep, nil -} - -// 一旦侦听端点收到SYN段,handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。 -// 在TCP开始使用SYN cookie接受连接之前,允许使用有限数量的这些goroutine。 -func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { - defer decSynRcvdCount() - defer s.decRef() - - _, err := ctx.createEndpointAndPerformHandshake(s, opts) - if err != nil { - return - } - // 到这里,三次握手已经完成,那么分发一个新的连接 - //e.deliverAccepted(n) -} - -// handleListenSegment is called when a listening endpoint receives a segment -// and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { - switch s.flags { - case flagSyn: // syn报文处理 - // 分析tcp选项 - opts := parseSynSegmentOptions(s) - if incSynRcvdCount() { - s.incRef() - go e.handleSynSegment(ctx, s, &opts) - } else { - log.Println("暂时不处理") - } - // 返回一个syn+ack报文 - case flagFin: // fin报文处理 - // 三次握手最后一次 ack 报文 - } -} - -func parseSynSegmentOptions(s *segment) header.TCPSynOptions { - synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck)) - if synOpts.TS { - s.parsedOptions.TSVal = synOpts.TSVal - s.parsedOptions.TSEcr = synOpts.TSEcr - } - return synOpts -} - -// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行,负责处理连接请求 -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { - defer func() { - // TODO 后置处理 - }() - - e.mu.Lock() - v6only := e.v6only - e.mu.Unlock() - ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) - // 初始化事件触发器 并添加事件 - s := sleep.Sleeper{} - s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) - s.AddWaker(&e.notificationWaker, wakerForNotification) - - for { - switch index, _ := s.Fetch(true); index { // Fetch(true) 阻塞获取 - case wakerForNewSegment: - mayRequeue := true - // 接收和处理tcp报文 - for i := 0; i < maxSegmentsPerWake; i++ { - s := e.segmentQueue.dequeue() - if s == nil { - mayRequeue = false - break - } - e.handleListenSegment(ctx, s) - s.decRef() - } - // If the queue is not empty, make sure we'll wake up - // in the next iteration. - if mayRequeue && !e.segmentQueue.empty() { // 主协程又添加了新数据 - e.newSegmentWaker.Assert() // 重新尝试获取数据 - } - case wakerForNotification: - // TODO 触发其他事件 - log.Println("其他事件?") - } - } -} +package tcp + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/binary" + "hash" + "io" + "log" + "netstack/sleep" + "netstack/tcpip" + "netstack/tcpip/header" + "netstack/tcpip/seqnum" + "netstack/tcpip/stack" + "sync" + "time" +) + +const ( + // tsLen is the length, in bits, of the timestamp in the SYN cookie. + tsLen = 8 + + // tsMask is a mask for timestamp values (i.e., tsLen bits). + tsMask = (1 << tsLen) - 1 + + // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. + tsOffset = 24 + + // hashMask is the mask for hash values (i.e., tsOffset bits). + hashMask = (1 << tsOffset) - 1 + + // maxTSDiff is the maximum allowed difference between a received cookie + // timestamp and the current timestamp. If the difference is greater + // than maxTSDiff, the cookie is expired. + maxTSDiff = 2 +) + +var ( + // SynRcvdCountThreshold is the global maximum number of connections + // that are allowed to be in SYN-RCVD state before TCP starts using SYN + // cookies to accept connections. + // + // It is an exported variable only for testing, and should not otherwise + // be used by importers of this package. + SynRcvdCountThreshold uint64 = 1000 + + // mssTable is a slice containing the possible MSS values that we + // encode in the SYN cookie with two bits. + mssTable = []uint16{536, 1300, 1440, 1460} +) + +func encodeMSS(mss uint16) uint32 { + for i := len(mssTable) - 1; i > 0; i-- { + if mss >= mssTable[i] { + return uint32(i) + } + } + return 0 +} + +// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is +// protected by a mutex so that we can increment only when it's guaranteed not +// to go above a threshold. +var synRcvdCount struct { + sync.Mutex + value uint64 + pending sync.WaitGroup +} + +// listenContext is used by a listening endpoint to store state used while +// listening for connections. This struct is allocated by the listen goroutine +// and must not be accessed or have its methods called concurrently as they +// may mutate the stored objects. +type listenContext struct { + stack *stack.Stack + rcvWnd seqnum.Size + nonce [2][sha1.BlockSize]byte // nonce 随机数 + + hasherMu sync.Mutex + hasher hash.Hash // 散列实现 + v6only bool + netProto tcpip.NetworkProtocolNumber +} + +// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. +func timeStamp() uint32 { + return uint32(time.Now().Unix()>>6) & tsMask // 00 00 00 FF +} + +// 增加一个任务 最多1000个 +func incSynRcvdCount() bool { + synRcvdCount.Mutex.Lock() + defer synRcvdCount.Unlock() + + if synRcvdCount.value >= SynRcvdCountThreshold { + return false + } + + synRcvdCount.pending.Add(1) + synRcvdCount.value++ + return true +} + +// 结束一个任务 +func decSynRcvdCount() { + synRcvdCount.Mutex.Lock() + defer synRcvdCount.Unlock() + synRcvdCount.value-- + synRcvdCount.pending.Done() +} + +// newListenContext creates a new listen context. +func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { + l := &listenContext{ + stack: stack, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + } + + rand.Read(l.nonce[0][:]) + rand.Read(l.nonce[1][:]) + + return l +} + +// cookieHash calculates the cookieHash for the given id, timestamp and nonce +// index. The hash is used to create and validate cookies. +func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { + + // Initialize block with fixed-size data: local ports and v. + var payload [8]byte + binary.BigEndian.PutUint16(payload[0:], id.LocalPort) + binary.BigEndian.PutUint16(payload[2:], id.RemotePort) + binary.BigEndian.PutUint32(payload[4:], ts) + + // Feed everything to the hasher. + l.hasherMu.Lock() + l.hasher.Reset() + l.hasher.Write(payload[:]) + l.hasher.Write(l.nonce[nonceIndex][:]) + io.WriteString(l.hasher, string(id.LocalAddress)) + io.WriteString(l.hasher, string(id.RemoteAddress)) + + // Finalize the calculation of the hash and return the first 4 bytes. + h := make([]byte, 0, sha1.Size) + h = l.hasher.Sum(h) + l.hasherMu.Unlock() + + return binary.BigEndian.Uint32(h[:]) +} + +// createCookie creates a SYN cookie for the given id and incoming sequence +// number. +func (l *listenContext) createCookie(id stack.TransportEndpointID, + seq seqnum.Value, data uint32) seqnum.Value { + ts := timeStamp() + v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) + v += (l.cookieHash(id, ts, 1) + data) & hashMask + return seqnum.Value(v) +} + +// isCookieValid checks if the supplied cookie is valid for the given id and +// sequence number. If it is, it also returns the data originally encoded in the +// cookie when createCookie was called. +func (l *listenContext) isCookieValid(id stack.TransportEndpointID, + cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { + ts := timeStamp() + v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) + cookieTS := v >> tsOffset + if ((ts - cookieTS) & tsMask) > maxTSDiff { + return 0, false + } + + return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true +} + +// 新建一个tcp端 这个tcp端与segment同属一个tcp连接 但属于不同阶段 用于写回远端 +func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, + irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create a new endpoint. + netProto := l.netProto + if netProto == 0 { + netProto = s.route.NetProto + } + n := newEndpoint(l.stack, netProto, nil) + n.v6only = l.v6only + n.id = s.id + n.boundNICID = s.route.NICID() + n.route = s.route.Clone() + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.rcvBufSize = int(l.rcvWnd) + + n.maybeEnableTimestamp(rcvdSynOpts) + n.maybeEnableSACKPermitted(rcvdSynOpts) + + // Register new endpoint so that packets are routed to it. + // 在网络协议栈中去注册这个tcp端 + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, + n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { + n.Close() + return nil, err + } + + n.isRegistered = true + n.state = stateConnected + + // Create sender and receiver. + // The receiver at least temporarily has a zero receive window scale, + // but the caller may change it (before starting the protocol loop). + n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) + n.rcv = newReceiver(n, irs, l.rcvWnd, 0) + + return n, nil +} + +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // create new endpoint + irs := s.sequenceNumber + cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) + log.Println("收到一个远端握手申请", irs, "标记cookie", cookie) + ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) + if err != nil { + return nil, err + } + + // 以下执行三次握手 + + // 构建handshake管理器 + h, err := newHandshake(ep, l.rcvWnd) + if err != nil { + ep.Close() + return nil, err + } + + // 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack + h.resetToSynRcvd(cookie, irs, opts) + if err := h.execute(); err != nil { + ep.Close() + return nil, err + } + + // 更新接收窗口扩张因子 + + return ep, nil +} + +// 一旦侦听端点收到SYN段,handleSynSegment就会在其自己的goroutine中调用。它负责完成握手并将新端点排队以进行接受。 +// 在TCP开始使用SYN cookie接受连接之前,允许使用有限数量的这些goroutine。 +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer decSynRcvdCount() + defer s.decRef() + + _, err := ctx.createEndpointAndPerformHandshake(s, opts) + if err != nil { + return + } + // 到这里,三次握手已经完成,那么分发一个新的连接 + //e.deliverAccepted(n) +} + +// handleListenSegment is called when a listening endpoint receives a segment +// and needs to handle it. +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + switch s.flags { + case flagSyn: // syn报文处理 + // 分析tcp选项 + opts := parseSynSegmentOptions(s) + if incSynRcvdCount() { + s.incRef() + go e.handleSynSegment(ctx, s, &opts) + } else { + log.Println("暂时不处理") + } + // 返回一个syn+ack报文 + case flagFin: // fin报文处理 + // 三次握手最后一次 ack 报文 + } +} + +func parseSynSegmentOptions(s *segment) header.TCPSynOptions { + synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck)) + if synOpts.TS { + s.parsedOptions.TSVal = synOpts.TSVal + s.parsedOptions.TSEcr = synOpts.TSEcr + } + return synOpts +} + +// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行,负责处理连接请求 +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + defer func() { + // TODO 后置处理 + }() + + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) + // 初始化事件触发器 并添加事件 + s := sleep.Sleeper{} + s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) + s.AddWaker(&e.notificationWaker, wakerForNotification) + + for { + switch index, _ := s.Fetch(true); index { // Fetch(true) 阻塞获取 + case wakerForNewSegment: + mayRequeue := true + // 接收和处理tcp报文 + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + mayRequeue = false + break + } + e.handleListenSegment(ctx, s) + s.decRef() + } + // If the queue is not empty, make sure we'll wake up + // in the next iteration. + if mayRequeue && !e.segmentQueue.empty() { // 主协程又添加了新数据 + e.newSegmentWaker.Assert() // 重新尝试获取数据 + } + case wakerForNotification: + // TODO 触发其他事件 + log.Println("其他事件?") + } + } +} diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go index 81933cd..d0618ed 100644 --- a/tcpip/transport/tcp/connect.go +++ b/tcpip/transport/tcp/connect.go @@ -1,364 +1,364 @@ -package tcp - -import ( - "crypto/rand" - "fmt" - "log" - "netstack/sleep" - "netstack/tcpip" - "netstack/tcpip/buffer" - "netstack/tcpip/header" - "netstack/tcpip/seqnum" - "netstack/tcpip/stack" - "sync" - "time" -) - -const maxSegmentsPerWake = 100 - -type handshakeState int - -const ( - handshakeSynSent handshakeState = iota - handshakeSynRcvd - handshakeCompleted -) - -// The following are used to set up sleepers. -const ( - wakerForNotification = iota - wakerForNewSegment - wakerForResend - wakerForResolution -) - -// handshake holds the state used during a TCP 3-way handshake. -// tcp三次握手时候使用的对象 -type handshake struct { - ep *endpoint - // 握手的状态 - state handshakeState - active bool - flags uint8 - ackNum seqnum.Value - - // iss is the initial send sequence number, as defined in RFC 793. - // 初始序列号 - iss seqnum.Value - - // rcvWnd is the receive window, as defined in RFC 793. - // 接收窗口 - rcvWnd seqnum.Size - - // sndWnd is the send window, as defined in RFC 793. - // 发送窗口 - sndWnd seqnum.Size - - // mss is the maximum segment size received from the peer. - // 最大报文段大小 - mss uint16 - - // sndWndScale is the send window scale, as defined in RFC 1323. A - // negative value means no scaling is supported by the peer. - // 发送窗口扩展因子 - sndWndScale int - - // rcvWndScale is the receive window scale, as defined in RFC 1323. - // 接收窗口扩展因子 - rcvWndScale int -} - -const ( - // Maximum space available for options. - // tcp选项的最大长度 - maxOptionSize = 40 -) - -func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { - h := handshake{ - ep: ep, - active: true, // 激活这个管理器 - rcvWnd: rcvWnd, // 初始接收窗口 - // TODO - } - if err := h.resetState(); err != nil { - return handshake{}, err - } - return h, nil -} - -func (h *handshake) resetState() *tcpip.Error { - // 随机一个iss(对方将收到的序号) 防止黑客搞事 - b := make([]byte, 4) - if _, err := rand.Read(b); err != nil { - panic(err) - } - // 初始化状态为 SynSent - h.state = handshakeSynSent - log.Println("收到 syn 同步报文 设置tcp状态为 [sent]") - h.flags = flagSyn - h.ackNum = 0 - h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) - - return nil -} - -// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD -// state. -func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { - h.active = false - h.state = handshakeSynRcvd - log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]") - h.flags = flagSyn | flagAck - h.iss = iss - h.ackNum = irs + 1 // NOTE ACK = synNum + 1 - h.mss = opts.MSS - h.sndWndScale = opts.WS -} - -func (h *handshake) resolveRoute() *tcpip.Error { - log.Printf("tcp resolveRoute") - // Set up the wakers. - s := sleep.Sleeper{} - resolutionWaker := &sleep.Waker{} - s.AddWaker(resolutionWaker, wakerForResolution) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - defer s.Done() - - // Initial action is to resolve route. - index := wakerForResolution - for { - log.Println(index) - switch index { - case wakerForResolution: - if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { - // Either success (err == nil) or failure. - return err - } - // Resolution not completed. Keep trying... - - case wakerForNotification: - // TODO - //n := h.ep.fetchNotifications() - //if n¬ifyClose != 0 { - // h.ep.route.RemoveWaker(resolutionWaker) - // return tcpip.ErrAborted - //} - //if n¬ifyDrain != 0 { - // close(h.ep.drainDone) - // <-h.ep.undrain - //} - } - - // Wait for notification. - index, _ = s.Fetch(true) - } -} - -// execute executes the TCP 3-way handshake. -// 执行tcp 3次握手,客户端和服务端都是调用该函数来实现三次握手 -/* - c flag s - | | - sync_sent|------sync---->|sync_rcvd - | | - | | - established|<--sync|ack----| - | | - | | - |------ack----->|established -*/ -func (h *handshake) execute() *tcpip.Error { - // 是否需要拿到下一条地址 - if h.ep.route.IsResolutionRequired() { - if err := h.resolveRoute(); err != nil { - return err - } - } - // Initialize the resend timer. - // 初始化重传定时器 - resendWaker := sleep.Waker{} - // 设置1s超时 - timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, func() { - resendWaker.Assert() - }) - defer rt.Stop() - - // Set up the wakers. - s := sleep.Sleeper{} - s.AddWaker(&resendWaker, wakerForResend) - s.AddWaker(&h.ep.notificationWaker, wakerForNotification) - s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) - defer s.Done() - - // sync报文的选项参数 - synOpts := header.TCPSynOptions{} - // 如果是客户端发送 syn 报文,如果是服务端发送 syn+ack 报文 - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - - for h.state != handshakeCompleted { - // 获取事件id - switch index, _ := s.Fetch(true); index { - case wakerForResend: // NOTE tcp超时重传机制 - // 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传 - // 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传 - // 超时时间变为上次的2倍 - timeOut *= 2 - if timeOut > 60*time.Second { - return tcpip.ErrTimeout - } - rt.Reset(timeOut) - // 重新发送syn报文 - sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - - case wakerForNotification: - - case wakerForNewSegment: - // 处理握手报文 - } - } - return nil -} - -var optionPool = sync.Pool{ - New: func() interface{} { - return make([]byte, maxOptionSize) - }, -} - -// 减少资源浪费 -func getOptions() []byte { - return optionPool.Get().([]byte) -} - -func putOptions(options []byte) { - // Reslice to full capacity. - optionPool.Put(options[0:cap(options)]) -} - -// tcp选项的编码 将一个TCPSyncOptions编码到 []byte 中 -func makeSynOptions(opts header.TCPSynOptions) []byte { - // Emulate linux option order. This is as follows: - // - // if md5: NOP NOP MD5SIG 18 md5sig(16) - // if mss: MSS 4 mss(2) - // if ts and sack_advertise: - // SACK 2 TIMESTAMP 2 timestamp(8) - // elif ts: NOP NOP TIMESTAMP 10 timestamp(8) - // elif sack: NOP NOP SACK 2 - // if wscale: NOP WINDOW 3 ws(1) - // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) - // [for each block] start_seq(4) end_seq(4) - // if fastopen_cookie: - // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) - // else: FASTOPEN (2 + len(cookie)) - // cookie(variable) [padding to four bytes] - // - options := getOptions() - - // Always encode the mss. - offset := header.EncodeMSSOption(uint32(opts.MSS), options) - - // Special ordering is required here. If both TS and SACK are enabled, - // then the SACK option precedes TS, with no padding. If they are - // enabled individually, then we see padding before the option. - if opts.TS && opts.SACKPermitted { - offset += header.EncodeSACKPermittedOption(options[offset:]) - offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) - } else if opts.TS { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) - } else if opts.SACKPermitted { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeSACKPermittedOption(options[offset:]) - } - - // Initialize the WS option. - if opts.WS >= 0 { - offset += header.EncodeNOP(options[offset:]) - offset += header.EncodeWSOption(opts.WS, options[offset:]) - } - - // Padding to the end; note that this never apply unless we add a - // fastopen option, we always expect the offset to remain the same. - if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { - panic("unexpected option encoding") - } - - return options[:offset] -} - -// 封装 sendTCP ,发送 syn 报文 -func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, - seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { - if opts.MSS == 0 { - opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) - } - options := makeSynOptions(opts) - err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) - return err -} - -// sendTCP sends a TCP segment with the provided options via the provided -// network endpoint and under the provided identity. -// 发送一个tcp段数据,封装 tcp 首部,并写入网路层 -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, - seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { - log.Println("进行一个报文的发送") - optLen := len(opts) - // Allocate a buffer for the TCP header. - hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) - - if rcvWnd > 0xffff { - rcvWnd = 0xffff - } - - // Initialize the header. - tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) - tcp.Encode(&header.TCPFields{ - SrcPort: id.LocalPort, - DstPort: id.RemotePort, - SeqNum: uint32(seq), - AckNum: uint32(ack), - DataOffset: uint8(header.TCPMinimumSize + optLen), - Flags: flags, - WindowSize: uint16(rcvWnd), - }) - copy(tcp[header.TCPMinimumSize:], opts) - - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { - length := uint16(hdr.UsedLength() + data.Size()) - // tcp伪首部校验和的计算 - xsum := r.PseudoHeaderChecksum(ProtocolNumber) - for _, v := range data.Views() { - xsum = header.Checksum(v, xsum) - } - - // tcp的可靠性:校验和的计算,用于检测损伤的报文段 - tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) - } - - r.Stats().TCP.SegmentsSent.Increment() - if (flags & flagRst) != 0 { - r.Stats().TCP.ResetsSent.Increment() - } - - log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d", - flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), - seq, ack, rcvWnd) - - return r.WritePacket(hdr, data, ProtocolNumber, ttl) -} - -// protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段 -func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { - for { - log.Println("三次握手机制在这里实现") - select {} - } -} +package tcp + +import ( + "crypto/rand" + "fmt" + "log" + "netstack/sleep" + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/seqnum" + "netstack/tcpip/stack" + "sync" + "time" +) + +const maxSegmentsPerWake = 100 + +type handshakeState int + +const ( + handshakeSynSent handshakeState = iota + handshakeSynRcvd + handshakeCompleted +) + +// The following are used to set up sleepers. +const ( + wakerForNotification = iota + wakerForNewSegment + wakerForResend + wakerForResolution +) + +// handshake holds the state used during a TCP 3-way handshake. +// tcp三次握手时候使用的对象 +type handshake struct { + ep *endpoint + // 握手的状态 + state handshakeState + active bool + flags uint8 + ackNum seqnum.Value + + // iss is the initial send sequence number, as defined in RFC 793. + // 初始序列号 + iss seqnum.Value + + // rcvWnd is the receive window, as defined in RFC 793. + // 接收窗口 + rcvWnd seqnum.Size + + // sndWnd is the send window, as defined in RFC 793. + // 发送窗口 + sndWnd seqnum.Size + + // mss is the maximum segment size received from the peer. + // 最大报文段大小 + mss uint16 + + // sndWndScale is the send window scale, as defined in RFC 1323. A + // negative value means no scaling is supported by the peer. + // 发送窗口扩展因子 + sndWndScale int + + // rcvWndScale is the receive window scale, as defined in RFC 1323. + // 接收窗口扩展因子 + rcvWndScale int +} + +const ( + // Maximum space available for options. + // tcp选项的最大长度 + maxOptionSize = 40 +) + +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { + h := handshake{ + ep: ep, + active: true, // 激活这个管理器 + rcvWnd: rcvWnd, // 初始接收窗口 + // TODO + } + if err := h.resetState(); err != nil { + return handshake{}, err + } + return h, nil +} + +func (h *handshake) resetState() *tcpip.Error { + // 随机一个iss(对方将收到的序号) 防止黑客搞事 + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + // 初始化状态为 SynSent + h.state = handshakeSynSent + log.Println("收到 syn 同步报文 设置tcp状态为 [sent]") + h.flags = flagSyn + h.ackNum = 0 + h.mss = 0 + h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + + return nil +} + +// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD +// state. +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { + h.active = false + h.state = handshakeSynRcvd + log.Println("发送 syn|ack 确认报文 设置tcp状态为 [rcvd]") + h.flags = flagSyn | flagAck + h.iss = iss + h.ackNum = irs + 1 // NOTE ACK = synNum + 1 + h.mss = opts.MSS + h.sndWndScale = opts.WS +} + +func (h *handshake) resolveRoute() *tcpip.Error { + log.Printf("tcp resolveRoute") + // Set up the wakers. + s := sleep.Sleeper{} + resolutionWaker := &sleep.Waker{} + s.AddWaker(resolutionWaker, wakerForResolution) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + defer s.Done() + + // Initial action is to resolve route. + index := wakerForResolution + for { + log.Println(index) + switch index { + case wakerForResolution: + if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + // Either success (err == nil) or failure. + return err + } + // Resolution not completed. Keep trying... + + case wakerForNotification: + // TODO + //n := h.ep.fetchNotifications() + //if n¬ifyClose != 0 { + // h.ep.route.RemoveWaker(resolutionWaker) + // return tcpip.ErrAborted + //} + //if n¬ifyDrain != 0 { + // close(h.ep.drainDone) + // <-h.ep.undrain + //} + } + + // Wait for notification. + index, _ = s.Fetch(true) + } +} + +// execute executes the TCP 3-way handshake. +// 执行tcp 3次握手,客户端和服务端都是调用该函数来实现三次握手 +/* + c flag s + | | + sync_sent|------sync---->|sync_rcvd + | | + | | + established|<--sync|ack----| + | | + | | + |------ack----->|established +*/ +func (h *handshake) execute() *tcpip.Error { + // 是否需要拿到下一条地址 + if h.ep.route.IsResolutionRequired() { + if err := h.resolveRoute(); err != nil { + return err + } + } + // Initialize the resend timer. + // 初始化重传定时器 + resendWaker := sleep.Waker{} + // 设置1s超时 + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, func() { + resendWaker.Assert() + }) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + // sync报文的选项参数 + synOpts := header.TCPSynOptions{} + // 如果是客户端发送 syn 报文,如果是服务端发送 syn+ack 报文 + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + for h.state != handshakeCompleted { + // 获取事件id + switch index, _ := s.Fetch(true); index { + case wakerForResend: // NOTE tcp超时重传机制 + // 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传 + // 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传 + // 超时时间变为上次的2倍 + timeOut *= 2 + if timeOut > 60*time.Second { + return tcpip.ErrTimeout + } + rt.Reset(timeOut) + // 重新发送syn报文 + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + case wakerForNotification: + + case wakerForNewSegment: + // 处理握手报文 + } + } + return nil +} + +var optionPool = sync.Pool{ + New: func() interface{} { + return make([]byte, maxOptionSize) + }, +} + +// 减少资源浪费 +func getOptions() []byte { + return optionPool.Get().([]byte) +} + +func putOptions(options []byte) { + // Reslice to full capacity. + optionPool.Put(options[0:cap(options)]) +} + +// tcp选项的编码 将一个TCPSyncOptions编码到 []byte 中 +func makeSynOptions(opts header.TCPSynOptions) []byte { + // Emulate linux option order. This is as follows: + // + // if md5: NOP NOP MD5SIG 18 md5sig(16) + // if mss: MSS 4 mss(2) + // if ts and sack_advertise: + // SACK 2 TIMESTAMP 2 timestamp(8) + // elif ts: NOP NOP TIMESTAMP 10 timestamp(8) + // elif sack: NOP NOP SACK 2 + // if wscale: NOP WINDOW 3 ws(1) + // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) + // [for each block] start_seq(4) end_seq(4) + // if fastopen_cookie: + // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) + // else: FASTOPEN (2 + len(cookie)) + // cookie(variable) [padding to four bytes] + // + options := getOptions() + + // Always encode the mss. + offset := header.EncodeMSSOption(uint32(opts.MSS), options) + + // Special ordering is required here. If both TS and SACK are enabled, + // then the SACK option precedes TS, with no padding. If they are + // enabled individually, then we see padding before the option. + if opts.TS && opts.SACKPermitted { + offset += header.EncodeSACKPermittedOption(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.TS { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.SACKPermitted { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKPermittedOption(options[offset:]) + } + + // Initialize the WS option. + if opts.WS >= 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeWSOption(opts.WS, options[offset:]) + } + + // Padding to the end; note that this never apply unless we add a + // fastopen option, we always expect the offset to remain the same. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +// 封装 sendTCP ,发送 syn 报文 +func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, + seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { + if opts.MSS == 0 { + opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) + } + options := makeSynOptions(opts) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +// 发送一个tcp段数据,封装 tcp 首部,并写入网路层 +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, + seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { + log.Println("进行一个报文的发送") + optLen := len(opts) + // Allocate a buffer for the TCP header. + hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) + + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + // Initialize the header. + tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + tcp.Encode(&header.TCPFields{ + SrcPort: id.LocalPort, + DstPort: id.RemotePort, + SeqNum: uint32(seq), + AckNum: uint32(ack), + DataOffset: uint8(header.TCPMinimumSize + optLen), + Flags: flags, + WindowSize: uint16(rcvWnd), + }) + copy(tcp[header.TCPMinimumSize:], opts) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + length := uint16(hdr.UsedLength() + data.Size()) + // tcp伪首部校验和的计算 + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + + // tcp的可靠性:校验和的计算,用于检测损伤的报文段 + tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) + } + + r.Stats().TCP.SegmentsSent.Increment() + if (flags & flagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + + log.Printf("send tcp %s segment to %s, seq: %d, ack: %d, rcvWnd: %d", + flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), + seq, ack, rcvWnd) + + return r.WritePacket(hdr, data, ProtocolNumber, ttl) +} + +// protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段 +func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + for { + log.Println("三次握手机制在这里实现") + select {} + } +} diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index 4fd6ab6..366019b 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -1,403 +1,403 @@ -package tcp - -import ( - "fmt" - "log" - "netstack/sleep" - "netstack/tcpip" - "netstack/tcpip/buffer" - "netstack/tcpip/header" - "netstack/tcpip/seqnum" - "netstack/tcpip/stack" - "netstack/waiter" - "sync" -) - -// tcp状态机的状态 -type endpointState int - -// tcp 状态机的各种状态 -const ( - stateInitial endpointState = iota - stateBound - stateListen - stateConnecting - stateConnected - stateClosed - stateError -) - -// endpoint 表示TCP端点。该结构用作端点用户和协议实现之间的接口;让并发goroutine调用端点是合法的, -// 它们是正确同步的。然而,协议实现在单个goroutine中运行。 -type endpoint struct { - stack *stack.Stack // 网络协议栈 - netProto tcpip.NetworkProtocolNumber // 网络协议号 ipv4 ipv6 - waiterQueue *waiter.Queue // 事件驱动机制 - - // TODO 需要添加 - - // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex - rcvList segmentList - rcvClosed bool - rcvBufSize int - rcvBufUsed int - - // The following fields are protected by the mutex. - mu sync.RWMutex - id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID - state endpointState // 目前tcp状态机的状态 - isPortReserved bool // 是否已经分配端口 - isRegistered bool // 是否已经注册在网络协议栈 - boundNICID tcpip.NICID - route stack.Route // tcp端在网络协议栈中的路由地址 - v6only bool // 是否仅仅支持ipv6 - isConnectNotified bool - - // effectiveNetProtos contains the network protocols actually in use. In - // most cases it will only contain "netProto", but in cases like IPv6 - // endpoints with v6only set to false, this could include multiple - // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., - // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped - // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber - - // workerRunning specifies if a worker goroutine is running. - workerRunning bool - - // sendTSOk is used to indicate when the TS Option has been negotiated. - // When sendTSOk is true every non-RST segment should carry a TS as per - // RFC7323#section-1.1 - sendTSOk bool - - // recentTS is the timestamp that should be sent in the TSEcr field of - // the timestamp for future segments sent by the endpoint. This field is - // updated if required when a new segment is received by this endpoint. - recentTS uint32 - - // sackPermitted is set to true if the peer sends the TCPSACKPermitted - // option in the SYN/SYN-ACK. - sackPermitted bool - - segmentQueue segmentQueue - - // When the send side is closed, the protocol goroutine is notified via - // sndCloseWaker, and sndClosed is set to true. - sndBufMu sync.Mutex - sndBufSize int - sndBufUsed int - sndClosed bool - sndBufInQueue seqnum.Size - sndQueue segmentList - sndWaker sleep.Waker - sndCloseWaker sleep.Waker - - // notificationWaker is used to indicate to the protocol goroutine that - // it needs to wake up and check for notifications. - notificationWaker sleep.Waker - - // newSegmentWaker is used to indicate to the protocol goroutine that - // it needs to wake up and handle new segments queued to it. - // HandlePacket收到segment后通知处理的事件驱动器 - newSegmentWaker sleep.Waker - - // acceptedChan is used by a listening endpoint protocol goroutine to - // send newly accepted connections to the endpoint so that they can be - // read by Accept() calls. - acceptedChan chan *endpoint - - // The following are only used from the protocol goroutine, and - // therefore don't need locks to protect them. - rcv *receiver - snd *sender - - // The following are only used to assist the restore run to re-connect. - bindAddress tcpip.Address - connectingAddress tcpip.Address -} - -func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - e := &endpoint{ - stack: stack, - netProto: netProto, - waiterQueue: waiterQueue, - rcvBufSize: DefaultBufferSize, - sndBufSize: DefaultBufferSize, - } - // TODO 需要添加 - e.segmentQueue.setLimit(2 * e.rcvBufSize) - return e -} - -func (e *endpoint) Close() { - log.Println("TODO 在写了 在写了") -} - -func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return nil, tcpip.ControlMessages{}, nil -} - -func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - return 0, nil, nil -} - -func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { - return 0, tcpip.ControlMessages{}, nil -} - -func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { - netProto := e.netProto - if header.IsV4MappedAddress(addr.Addr) { - // Fail if using a v4 mapped address on a v6only endpoint. - if e.v6only { - return 0, tcpip.ErrNoRoute - } - - netProto = header.IPv4ProtocolNumber - addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] - if addr.Addr == "\x00\x00\x00\x00" { - addr.Addr = "" - } - } - - // Fail if we're bound to an address length different from the one we're - // checking. - if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { - return 0, tcpip.ErrInvalidEndpointState - } - - return netProto, nil -} - -func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { - return nil -} - -func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { - return nil -} - -func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { - log.Println("监听一个tcp端口") - e.mu.Lock() - defer e.mu.Unlock() - defer func() { - if err != nil && err.IgnoreStats() { - e.stack.Stats().TCP.FailedConnectionAttempts.Increment() - } - }() - - // TODO 需要添加 - - // 在调用 Listen 之前,必须先 Bind - if e.state != stateBound { - return tcpip.ErrInvalidEndpointState - } - // 注册该端点,这样网络层在分发数据包的时候就可以根据 id 来找到这个端点,接着把报文发送给这个端点。 - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, - e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { - return err - } - - e.isRegistered = true - e.state = stateListen - if e.acceptedChan == nil { - e.acceptedChan = make(chan *endpoint, backlog) - } - e.workerRunning = true - - e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() - // TODO tcp服务端实现的主循环,这个函数很重要,用一个goroutine来服务 - go e.protocolListenLoop(seqnum.Size(0)) - - return nil -} - -// startAcceptedLoop sets up required state and starts a goroutine with the -// main loop for accepted connections. -func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { - e.waiterQueue = waiterQueue - e.workerRunning = true - go e.protocolMainLoop(false) -} - -func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - // Endpoint must be in listen state before it can accept connections. - if e.state != stateListen { - return nil, nil, tcpip.ErrInvalidEndpointState - } - - var n *endpoint - select { - case n = <-e.acceptedChan: - default: - return nil, nil, tcpip.ErrWouldBlock - } - wq := &waiter.Queue{} - n.startAcceptedLoop(wq) - return n, wq, nil -} - -// Bind binds the endpoint to a specific local port and optionally address. -// 将端点绑定到特定的本地端口和可选的地址。 -func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { - e.mu.Lock() - defer e.mu.Unlock() - - // 如果端点不是处于初始状态,则不允许绑定。这是因为一旦端点进入连接或监听状态,它就已经绑定了。 - if e.state != stateInitial { - return tcpip.ErrAlreadyBound - } - // 确定tcp端的绑定ip - e.bindAddress = addr.Addr - netProto, err := e.checkV4Mapped(&addr) - if err != nil { - return err - } - // 确定tcp支持的网络层协议 - netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, - } - } - // 绑定端口 - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) - if err != nil { - return err - } - e.isPortReserved = true - e.effectiveNetProtos = netProtos - e.id.LocalPort = port - - defer func() { - // 如果有错,在退出的时候应该解除端口绑定 - if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) - e.isPortReserved = false - e.effectiveNetProtos = nil - e.id.LocalPort = 0 - e.id.LocalAddress = "" - e.boundNICID = 0 - } - }() - // 如果指定了ip地址 需要检查一下这个ip地址本地是否绑定过 - if len(addr.Addr) != 0 { - nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) - if nic == 0 { - return tcpip.ErrBadLocalAddress - } - - e.boundNICID = nic - e.id.LocalAddress = addr.Addr - } - - // Check the commit function. - if commit != nil { - if err := commit(); err != nil { - // The defer takes care of unwind. - return err - } - } - // 标记状态为 stateBound - e.state = stateBound - - return nil -} - -func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - return tcpip.FullAddress{ - Addr: e.id.LocalAddress, - Port: e.id.LocalPort, - NIC: e.boundNICID, - }, nil -} - -func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - e.mu.RLock() - defer e.mu.RUnlock() - - if e.state != stateConnected { - return tcpip.FullAddress{}, tcpip.ErrNotConnected - } - - return tcpip.FullAddress{ - Addr: e.id.RemoteAddress, - Port: e.id.RemotePort, - NIC: e.boundNICID, - }, nil -} - -func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return waiter.EventErr -} - -func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return nil -} - -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - return nil -} - -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { - s := newSegment(r, id, vv) - // 解析tcp段,如果解析失败,丢弃该报文 - if !s.parse() { - e.stack.Stats().MalformedRcvdPackets.Increment() - e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() - s.decRef() - return - } - - e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一 - if (s.flags & flagRst) != 0 { // RST报文需要拒绝 - e.stack.Stats().TCP.ResetsReceived.Increment() - } - // Send packet to worker goroutine. - if e.segmentQueue.enqueue(s) { - log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d", - flagString(s.flags), fmt.Sprintf("%s:%d", s.id.RemoteAddress, s.id.RemotePort), - s.sequenceNumber, s.ackNumber) - e.newSegmentWaker.Assert() - } else { - // The queue is full, so we drop the segment. - e.stack.Stats().DroppedPackets.Increment() - s.decRef() - } -} - -func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { - -} - -// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if -// the SYN options indicate that timestamp option was negotiated. It also -// initializes the recentTS with the value provided in synOpts.TSval. -func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { - if synOpts.TS { - e.sendTSOk = true - e.recentTS = synOpts.TSVal - } -} - -// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint -// if the SYN options indicate that the SACK option was negotiated and the TCP -// stack is configured to enable TCP SACK option. -func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v SACKEnabled - if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { - // Stack doesn't support SACK. So just return. - return - } - if bool(v) && synOpts.SACKPermitted { - e.sackPermitted = true - } -} +package tcp + +import ( + "fmt" + "log" + "netstack/sleep" + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/seqnum" + "netstack/tcpip/stack" + "netstack/waiter" + "sync" +) + +// tcp状态机的状态 +type endpointState int + +// tcp 状态机的各种状态 +const ( + stateInitial endpointState = iota + stateBound + stateListen + stateConnecting + stateConnected + stateClosed + stateError +) + +// endpoint 表示TCP端点。该结构用作端点用户和协议实现之间的接口;让并发goroutine调用端点是合法的, +// 它们是正确同步的。然而,协议实现在单个goroutine中运行。 +type endpoint struct { + stack *stack.Stack // 网络协议栈 + netProto tcpip.NetworkProtocolNumber // 网络协议号 ipv4 ipv6 + waiterQueue *waiter.Queue // 事件驱动机制 + + // TODO 需要添加 + + // rcvListMu can be taken after the endpoint mu below. + rcvListMu sync.Mutex + rcvList segmentList + rcvClosed bool + rcvBufSize int + rcvBufUsed int + + // The following fields are protected by the mutex. + mu sync.RWMutex + id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID + state endpointState // 目前tcp状态机的状态 + isPortReserved bool // 是否已经分配端口 + isRegistered bool // 是否已经注册在网络协议栈 + boundNICID tcpip.NICID + route stack.Route // tcp端在网络协议栈中的路由地址 + v6only bool // 是否仅仅支持ipv6 + isConnectNotified bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // sendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1 + sendTSOk bool + + // recentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + recentTS uint32 + + // sackPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + sackPermitted bool + + segmentQueue segmentQueue + + // When the send side is closed, the protocol goroutine is notified via + // sndCloseWaker, and sndClosed is set to true. + sndBufMu sync.Mutex + sndBufSize int + sndBufUsed int + sndClosed bool + sndBufInQueue seqnum.Size + sndQueue segmentList + sndWaker sleep.Waker + sndCloseWaker sleep.Waker + + // notificationWaker is used to indicate to the protocol goroutine that + // it needs to wake up and check for notifications. + notificationWaker sleep.Waker + + // newSegmentWaker is used to indicate to the protocol goroutine that + // it needs to wake up and handle new segments queued to it. + // HandlePacket收到segment后通知处理的事件驱动器 + newSegmentWaker sleep.Waker + + // acceptedChan is used by a listening endpoint protocol goroutine to + // send newly accepted connections to the endpoint so that they can be + // read by Accept() calls. + acceptedChan chan *endpoint + + // The following are only used from the protocol goroutine, and + // therefore don't need locks to protect them. + rcv *receiver + snd *sender + + // The following are only used to assist the restore run to re-connect. + bindAddress tcpip.Address + connectingAddress tcpip.Address +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSize: DefaultBufferSize, + sndBufSize: DefaultBufferSize, + } + // TODO 需要添加 + e.segmentQueue.setLimit(2 * e.rcvBufSize) + return e +} + +func (e *endpoint) Close() { + log.Println("TODO 在写了 在写了") +} + +func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return nil, tcpip.ControlMessages{}, nil +} + +func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + return 0, nil, nil +} + +func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { + return nil +} + +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return nil +} + +func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { + log.Println("监听一个tcp端口") + e.mu.Lock() + defer e.mu.Unlock() + defer func() { + if err != nil && err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() + + // TODO 需要添加 + + // 在调用 Listen 之前,必须先 Bind + if e.state != stateBound { + return tcpip.ErrInvalidEndpointState + } + // 注册该端点,这样网络层在分发数据包的时候就可以根据 id 来找到这个端点,接着把报文发送给这个端点。 + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, + e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { + return err + } + + e.isRegistered = true + e.state = stateListen + if e.acceptedChan == nil { + e.acceptedChan = make(chan *endpoint, backlog) + } + e.workerRunning = true + + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + // TODO tcp服务端实现的主循环,这个函数很重要,用一个goroutine来服务 + go e.protocolListenLoop(seqnum.Size(0)) + + return nil +} + +// startAcceptedLoop sets up required state and starts a goroutine with the +// main loop for accepted connections. +func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.waiterQueue = waiterQueue + e.workerRunning = true + go e.protocolMainLoop(false) +} + +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // Endpoint must be in listen state before it can accept connections. + if e.state != stateListen { + return nil, nil, tcpip.ErrInvalidEndpointState + } + + var n *endpoint + select { + case n = <-e.acceptedChan: + default: + return nil, nil, tcpip.ErrWouldBlock + } + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + return n, wq, nil +} + +// Bind binds the endpoint to a specific local port and optionally address. +// 将端点绑定到特定的本地端口和可选的地址。 +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // 如果端点不是处于初始状态,则不允许绑定。这是因为一旦端点进入连接或监听状态,它就已经绑定了。 + if e.state != stateInitial { + return tcpip.ErrAlreadyBound + } + // 确定tcp端的绑定ip + e.bindAddress = addr.Addr + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + // 确定tcp支持的网络层协议 + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + // 绑定端口 + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) + if err != nil { + return err + } + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.id.LocalPort = port + + defer func() { + // 如果有错,在退出的时候应该解除端口绑定 + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.isPortReserved = false + e.effectiveNetProtos = nil + e.id.LocalPort = 0 + e.id.LocalAddress = "" + e.boundNICID = 0 + } + }() + // 如果指定了ip地址 需要检查一下这个ip地址本地是否绑定过 + if len(addr.Addr) != 0 { + nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + + e.boundNICID = nic + e.id.LocalAddress = addr.Addr + } + + // Check the commit function. + if commit != nil { + if err := commit(); err != nil { + // The defer takes care of unwind. + return err + } + } + // 标记状态为 stateBound + e.state = stateBound + + return nil +} + +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + NIC: e.boundNICID, + }, nil +} + +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + NIC: e.boundNICID, + }, nil +} + +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + return waiter.EventErr +} + +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + s := newSegment(r, id, vv) + // 解析tcp段,如果解析失败,丢弃该报文 + if !s.parse() { + e.stack.Stats().MalformedRcvdPackets.Increment() + e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + s.decRef() + return + } + + e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一 + if (s.flags & flagRst) != 0 { // RST报文需要拒绝 + e.stack.Stats().TCP.ResetsReceived.Increment() + } + // Send packet to worker goroutine. + if e.segmentQueue.enqueue(s) { + log.Printf("收到 tcp [%s] 报文片段 from %s, seq: %d, ack: %d", + flagString(s.flags), fmt.Sprintf("%s:%d", s.id.RemoteAddress, s.id.RemotePort), + s.sequenceNumber, s.ackNumber) + e.newSegmentWaker.Assert() + } else { + // The queue is full, so we drop the segment. + e.stack.Stats().DroppedPackets.Increment() + s.decRef() + } +} + +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + +} + +// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if +// the SYN options indicate that timestamp option was negotiated. It also +// initializes the recentTS with the value provided in synOpts.TSval. +func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { + if synOpts.TS { + e.sendTSOk = true + e.recentTS = synOpts.TSVal + } +} + +// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint +// if the SYN options indicate that the SACK option was negotiated and the TCP +// stack is configured to enable TCP SACK option. +func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { + var v SACKEnabled + if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + // Stack doesn't support SACK. So just return. + return + } + if bool(v) && synOpts.SACKPermitted { + e.sackPermitted = true + } +} diff --git a/tcpip/transport/tcp/protocol.go b/tcpip/transport/tcp/protocol.go index 8982ab6..09bae1b 100644 --- a/tcpip/transport/tcp/protocol.go +++ b/tcpip/transport/tcp/protocol.go @@ -1,73 +1,73 @@ -package tcp - -import ( - "netstack/tcpip" - "netstack/tcpip/buffer" - "netstack/tcpip/header" - "netstack/tcpip/stack" - "netstack/waiter" -) - -const ( - // ProtocolName is the string representation of the tcp protocol name. - ProtocolName = "tcp" - - // ProtocolNumber is the tcp protocol number. - ProtocolNumber = header.TCPProtocolNumber - // MinBufferSize is the smallest size of a receive or send buffer. - minBufferSize = 4 << 10 // 4096 bytes. - - // DefaultBufferSize is the default size of the receive and send buffers. - DefaultBufferSize = 1 << 20 // 1MB - - // MaxBufferSize is the largest size a receive and send buffer can grow to. - maxBufferSize = 4 << 20 // 4MB -) - -// SACKEnabled option can be used to enable SACK support in the TCP -// protocol. See: https://tools.ietf.org/html/rfc2018. -type SACKEnabled bool - -type protocol struct{} - -// Number returns the tcp protocol number. -func (*protocol) Number() tcpip.TransportProtocolNumber { - return ProtocolNumber -} - -// NewEndpoint creates a new tcp endpoint. -func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - return newEndpoint(stack, netProto, waiterQueue), nil -} - -// ParsePorts returns the source and destination ports stored in the given tcp -// packet. -func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { - h := header.TCP(v) - return h.SourcePort(), h.DestinationPort(), nil -} - -// MinimumPacketSize returns the minimum valid tcp packet size. -func (*protocol) MinimumPacketSize() int { - return header.TCPMinimumSize -} - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { - return false -} - -// SetOption implements TransportProtocol.SetOption. -func (p *protocol) SetOption(option interface{}) *tcpip.Error { - return nil -} - -// Option implements TransportProtocol.Option. -func (p *protocol) Option(option interface{}) *tcpip.Error { - return nil -} - -func init() { - stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { - return &protocol{} - }) -} +package tcp + +import ( + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" + "netstack/waiter" +) + +const ( + // ProtocolName is the string representation of the tcp protocol name. + ProtocolName = "tcp" + + // ProtocolNumber is the tcp protocol number. + ProtocolNumber = header.TCPProtocolNumber + // MinBufferSize is the smallest size of a receive or send buffer. + minBufferSize = 4 << 10 // 4096 bytes. + + // DefaultBufferSize is the default size of the receive and send buffers. + DefaultBufferSize = 1 << 20 // 1MB + + // MaxBufferSize is the largest size a receive and send buffer can grow to. + maxBufferSize = 4 << 20 // 4MB +) + +// SACKEnabled option can be used to enable SACK support in the TCP +// protocol. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +type protocol struct{} + +// Number returns the tcp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new tcp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// ParsePorts returns the source and destination ports stored in the given tcp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.TCP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// MinimumPacketSize returns the minimum valid tcp packet size. +func (*protocol) MinimumPacketSize() int { + return header.TCPMinimumSize +} + +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { + return false +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return nil +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return nil +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{} + }) +} diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go index ece6469..b151d52 100644 --- a/tcpip/transport/tcp/rcv.go +++ b/tcpip/transport/tcp/rcv.go @@ -1,11 +1,11 @@ -package tcp - -import "netstack/tcpip/seqnum" - -type receiver struct{} - -// 新建并初始化接收器 -func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { - r := &receiver{} - return r -} +package tcp + +import "netstack/tcpip/seqnum" + +type receiver struct{} + +// 新建并初始化接收器 +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { + r := &receiver{} + return r +} diff --git a/tcpip/transport/tcp/segment.go b/tcpip/transport/tcp/segment.go index 804c428..eab7a3b 100644 --- a/tcpip/transport/tcp/segment.go +++ b/tcpip/transport/tcp/segment.go @@ -1,135 +1,135 @@ -package tcp - -import ( - "fmt" - "log" - "netstack/tcpip/buffer" - "netstack/tcpip/header" - "netstack/tcpip/seqnum" - "netstack/tcpip/stack" - "strings" - "sync/atomic" -) - -// tcp 太复杂了 专门写一个协议解析器 segment 是有种类之分的 - -// Flags that may be set in a TCP segment. -const ( - flagFin = 1 << iota - flagSyn - flagRst - flagPsh - flagAck - flagUrg -) - -func flagString(flags uint8) string { - var s []string - if (flags & flagAck) != 0 { - s = append(s, "ack") - } - if (flags & flagFin) != 0 { - s = append(s, "fin") - } - if (flags & flagPsh) != 0 { - s = append(s, "psh") - } - if (flags & flagRst) != 0 { - s = append(s, "rst") - } - if (flags & flagSyn) != 0 { - s = append(s, "syn") - } - if (flags & flagUrg) != 0 { - s = append(s, "urg") - } - return strings.Join(s, "|") -} - -// segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中 -type segment struct { - segmentEntry - refCnt int32 // 引用计数 - id stack.TransportEndpointID - route stack.Route - data buffer.VectorisedView - // views is used as buffer for data when its length is large - // enough to store a VectorisedView. - views [8]buffer.View - // TODO 需要解析 - viewToDeliver int - sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置 - ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号 - flags uint8 - window seqnum.Size - // parsedOptions stores the parsed values from the options in the segment. - parsedOptions header.TCPOptions - options []byte -} - -func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { - s := &segment{refCnt: 1, id: id, route: r.Clone()} - s.data = vv.Clone(s.views[:]) - return s -} - -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { - s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), - } - s.views[0] = v - s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1? - return s -} - -func (s *segment) clone() *segment { - t := &segment{ - refCnt: 1, - id: s.id, - sequenceNumber: s.sequenceNumber, - ackNumber: s.ackNumber, - flags: s.flags, - window: s.window, - route: s.route.Clone(), - viewToDeliver: s.viewToDeliver, - } - t.data = s.data.Clone(t.views[:]) - return t -} - -func (s *segment) flagIsSet(flag uint8) bool { - return (s.flags & flag) != 0 -} - -func (s *segment) decRef() { - if atomic.AddInt32(&s.refCnt, -1) == 0 { - s.route.Release() - } -} - -func (s *segment) incRef() { - atomic.AddInt32(&s.refCnt, 1) -} - -func (s *segment) parse() bool { - h := header.TCP(s.data.First()) - offset := int(h.DataOffset()) - if offset < header.TCPMinimumSize || offset > len(h) { - return false - } - s.options = h.Options() - s.parsedOptions = header.ParseTCPOptions(s.options) - - log.Println(h) - fmt.Println(s.parsedOptions) - - s.data.TrimFront(offset) - - s.sequenceNumber = seqnum.Value(h.SequenceNumber()) - s.ackNumber = seqnum.Value(h.AckNumber()) - s.flags = h.Flags() // U|A|P|R|S|F - s.window = seqnum.Size(h.WindowSize()) - return true -} +package tcp + +import ( + "fmt" + "log" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/seqnum" + "netstack/tcpip/stack" + "strings" + "sync/atomic" +) + +// tcp 太复杂了 专门写一个协议解析器 segment 是有种类之分的 + +// Flags that may be set in a TCP segment. +const ( + flagFin = 1 << iota + flagSyn + flagRst + flagPsh + flagAck + flagUrg +) + +func flagString(flags uint8) string { + var s []string + if (flags & flagAck) != 0 { + s = append(s, "ack") + } + if (flags & flagFin) != 0 { + s = append(s, "fin") + } + if (flags & flagPsh) != 0 { + s = append(s, "psh") + } + if (flags & flagRst) != 0 { + s = append(s, "rst") + } + if (flags & flagSyn) != 0 { + s = append(s, "syn") + } + if (flags & flagUrg) != 0 { + s = append(s, "urg") + } + return strings.Join(s, "|") +} + +// segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中 +type segment struct { + segmentEntry + refCnt int32 // 引用计数 + id stack.TransportEndpointID + route stack.Route + data buffer.VectorisedView + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View + // TODO 需要解析 + viewToDeliver int + sequenceNumber seqnum.Value // tcp序号 第一个字节在整个报文的位置 + ackNumber seqnum.Value // 确认号 希望继续获取的下一个字节序号 + flags uint8 + window seqnum.Size + // parsedOptions stores the parsed values from the options in the segment. + parsedOptions header.TCPOptions + options []byte +} + +func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { + s := &segment{refCnt: 1, id: id, route: r.Clone()} + s.data = vv.Clone(s.views[:]) + return s +} + +func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) // TODO 为什么只复制1? + return s +} + +func (s *segment) clone() *segment { + t := &segment{ + refCnt: 1, + id: s.id, + sequenceNumber: s.sequenceNumber, + ackNumber: s.ackNumber, + flags: s.flags, + window: s.window, + route: s.route.Clone(), + viewToDeliver: s.viewToDeliver, + } + t.data = s.data.Clone(t.views[:]) + return t +} + +func (s *segment) flagIsSet(flag uint8) bool { + return (s.flags & flag) != 0 +} + +func (s *segment) decRef() { + if atomic.AddInt32(&s.refCnt, -1) == 0 { + s.route.Release() + } +} + +func (s *segment) incRef() { + atomic.AddInt32(&s.refCnt, 1) +} + +func (s *segment) parse() bool { + h := header.TCP(s.data.First()) + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { + return false + } + s.options = h.Options() + s.parsedOptions = header.ParseTCPOptions(s.options) + + log.Println(h) + fmt.Println(s.parsedOptions) + + s.data.TrimFront(offset) + + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() // U|A|P|R|S|F + s.window = seqnum.Size(h.WindowSize()) + return true +} diff --git a/tcpip/transport/tcp/segment_queue.go b/tcpip/transport/tcp/segment_queue.go index 1c9735d..e97e472 100644 --- a/tcpip/transport/tcp/segment_queue.go +++ b/tcpip/transport/tcp/segment_queue.go @@ -1,50 +1,50 @@ -package tcp - -import ( - "netstack/tcpip/header" - "sync" -) - -type segmentQueue struct { - mu sync.Mutex - list segmentList // 队列实现 - limit int // 队列容量 - used int // 队列长度 -} - -func (q *segmentQueue) empty() bool { - q.mu.Lock() - r := q.used == 0 - q.mu.Unlock() - return r -} - -func (q *segmentQueue) enqueue(s *segment) bool { - q.mu.Lock() - r := q.used < q.limit - if r { - q.list.PushBack(s) - q.used += s.data.Size() + header.TCPMinimumSize - } - q.mu.Unlock() - - return r -} - -func (q *segmentQueue) dequeue() *segment { - q.mu.Lock() - s := q.list.Front() - if s != nil { - q.list.Remove(s) - q.used -= s.data.Size() + header.TCPMinimumSize - } - q.mu.Unlock() - - return s -} - -func (q *segmentQueue) setLimit(limit int) { - q.mu.Lock() - q.limit = limit - q.mu.Unlock() -} +package tcp + +import ( + "netstack/tcpip/header" + "sync" +) + +type segmentQueue struct { + mu sync.Mutex + list segmentList // 队列实现 + limit int // 队列容量 + used int // 队列长度 +} + +func (q *segmentQueue) empty() bool { + q.mu.Lock() + r := q.used == 0 + q.mu.Unlock() + return r +} + +func (q *segmentQueue) enqueue(s *segment) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used += s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return r +} + +func (q *segmentQueue) dequeue() *segment { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used -= s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return s +} + +func (q *segmentQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go index 3585c15..3ba6026 100644 --- a/tcpip/transport/tcp/snd.go +++ b/tcpip/transport/tcp/snd.go @@ -1,12 +1,12 @@ -package tcp - -import "netstack/tcpip/seqnum" - -type sender struct { -} - -// 新建并初始化发送器 -func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { - s := &sender{} - return s -} +package tcp + +import "netstack/tcpip/seqnum" + +type sender struct { +} + +// 新建并初始化发送器 +func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { + s := &sender{} + return s +} diff --git a/tcpip/transport/tcp/tcp_segment_list.go b/tcpip/transport/tcp/tcp_segment_list.go index 029f98a..48b4640 100644 --- a/tcpip/transport/tcp/tcp_segment_list.go +++ b/tcpip/transport/tcp/tcp_segment_list.go @@ -1,173 +1,173 @@ -package tcp - -// ElementMapper provides an identity mapping by default. -// -// This can be replaced to provide a struct that maps elements to linker -// objects, if they are not the same. An ElementMapper is not typically -// required if: Linker is left as is, Element is left as is, or Linker and -// Element are the same type. -type segmentElementMapper struct{} - -// linkerFor maps an Element to a Linker. -// -// This default implementation should be inlined. -// -//go:nosplit -func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } - -// List is an intrusive list. Entries can be added to or removed from the list -// in O(1) time and with no additional memory allocations. -// -// The zero value for List is an empty list ready to use. -// -// To iterate over a list (where l is a List): -// for e := l.Front(); e != nil; e = e.Next() { -// // do something with e. -// } -// -// +stateify savable -type segmentList struct { - head *segment - tail *segment -} - -// Reset resets list l to the empty state. -func (l *segmentList) Reset() { - l.head = nil - l.tail = nil -} - -// Empty returns true iff the list is empty. -func (l *segmentList) Empty() bool { - return l.head == nil -} - -// Front returns the first element of list l or nil. -func (l *segmentList) Front() *segment { - return l.head -} - -// Back returns the last element of list l or nil. -func (l *segmentList) Back() *segment { - return l.tail -} - -// PushFront inserts the element e at the front of list l. -func (l *segmentList) PushFront(e *segment) { - segmentElementMapper{}.linkerFor(e).SetNext(l.head) - segmentElementMapper{}.linkerFor(e).SetPrev(nil) - - if l.head != nil { - segmentElementMapper{}.linkerFor(l.head).SetPrev(e) - } else { - l.tail = e - } - - l.head = e -} - -// PushBack inserts the element e at the back of list l. -func (l *segmentList) PushBack(e *segment) { - segmentElementMapper{}.linkerFor(e).SetNext(nil) - segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) - - if l.tail != nil { - segmentElementMapper{}.linkerFor(l.tail).SetNext(e) - } else { - l.head = e - } - - l.tail = e -} - -// PushBackList inserts list m at the end of list l, emptying m. -func (l *segmentList) PushBackList(m *segmentList) { - if l.head == nil { - l.head = m.head - l.tail = m.tail - } else if m.head != nil { - segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) - segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) - - l.tail = m.tail - } - - m.head = nil - m.tail = nil -} - -// InsertAfter inserts e after b. -func (l *segmentList) InsertAfter(b, e *segment) { - a := segmentElementMapper{}.linkerFor(b).Next() - segmentElementMapper{}.linkerFor(e).SetNext(a) - segmentElementMapper{}.linkerFor(e).SetPrev(b) - segmentElementMapper{}.linkerFor(b).SetNext(e) - - if a != nil { - segmentElementMapper{}.linkerFor(a).SetPrev(e) - } else { - l.tail = e - } -} - -// InsertBefore inserts e before a. -func (l *segmentList) InsertBefore(a, e *segment) { - b := segmentElementMapper{}.linkerFor(a).Prev() - segmentElementMapper{}.linkerFor(e).SetNext(a) - segmentElementMapper{}.linkerFor(e).SetPrev(b) - segmentElementMapper{}.linkerFor(a).SetPrev(e) - - if b != nil { - segmentElementMapper{}.linkerFor(b).SetNext(e) - } else { - l.head = e - } -} - -// Remove removes e from l. -func (l *segmentList) Remove(e *segment) { - prev := segmentElementMapper{}.linkerFor(e).Prev() - next := segmentElementMapper{}.linkerFor(e).Next() - - if prev != nil { - segmentElementMapper{}.linkerFor(prev).SetNext(next) - } else { - l.head = next - } - - if next != nil { - segmentElementMapper{}.linkerFor(next).SetPrev(prev) - } else { - l.tail = prev - } -} - -// Entry is a default implementation of Linker. Users can add anonymous fields -// of this type to their structs to make them automatically implement the -// methods needed by List. -// -// +stateify savable -type segmentEntry struct { - next *segment - prev *segment -} - -// Next returns the entry that follows e in the list. -func (e *segmentEntry) Next() *segment { - return e.next -} - -// Prev returns the entry that precedes e in the list. -func (e *segmentEntry) Prev() *segment { - return e.prev -} - -// SetNext assigns 'entry' as the entry that follows e in the list. -func (e *segmentEntry) SetNext(elem *segment) { - e.next = elem -} - -// SetPrev assigns 'entry' as the entry that precedes e in the list. -func (e *segmentEntry) SetPrev(elem *segment) { - e.prev = elem -} +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +}