From c7fe592b0f1d0affd6ae3f9fe6af8632b2814737 Mon Sep 17 00:00:00 2001 From: impact-eintr Date: Mon, 5 Dec 2022 18:20:21 +0800 Subject: [PATCH] =?UTF-8?q?tcp=E6=8A=A5=E6=96=87=E5=A4=B4=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E5=8F=AF=E8=A7=86=E5=8C=96=E8=A7=A3=E6=9E=90;?= =?UTF-8?q?=E4=BC=98=E5=8C=96IP=20udp=E6=8A=A5=E6=96=87=E7=9A=84=E5=8F=AF?= =?UTF-8?q?=E8=A7=86=E5=8C=96=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/netstack/main.go | 60 +++++-- tcpip/header/ipv4.go | 22 +-- tcpip/header/tcp.go | 112 +++++++++++++ tcpip/header/udp.go | 19 ++- tcpip/network/ipv4/ipv4.go | 2 +- tcpip/seqnum/seqnum.go | 50 ++++++ tcpip/transport/tcp/accept.go | 11 ++ tcpip/transport/tcp/connect.go | 14 ++ tcpip/transport/tcp/endpoint.go | 201 ++++++++++++++++++++++-- tcpip/transport/tcp/segment.go | 45 ++++++ tcpip/transport/tcp/tcp_segment_list.go | 173 ++++++++++++++++++++ tcpip/transport/udp/endpoint.go | 2 +- 12 files changed, 665 insertions(+), 46 deletions(-) create mode 100644 tcpip/seqnum/seqnum.go create mode 100644 tcpip/transport/tcp/accept.go create mode 100644 tcpip/transport/tcp/connect.go create mode 100644 tcpip/transport/tcp/segment.go create mode 100644 tcpip/transport/tcp/tcp_segment_list.go diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 300730c..1cfbfd2 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -124,7 +124,7 @@ func main() { //go func() { // echo server // // 监听udp localPort端口 - // conn := udpListen(s, proto, localPort) + // conn := udpListen(s, proto, addr, localPort) // for { // buf := make([]byte, 1024) @@ -141,23 +141,48 @@ func main() { //}() go func() { // echo server - conn := tcpListen(s, proto, localPort) - conn.Read(nil) + listener := tcpListen(s, proto, addr, localPort) + for { + conn, err := listener.Accept() + if err != nil { + log.Println(err) + continue + } + conn.Read(nil) + } }() c := make(chan os.Signal) signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) <-c - - //conn, _ := net.Listen("tcp", "0.0.0.0:9999") - //rcv := &RCV{ - // Stack: s, - // addr: tcpip.FullAddress{}, - //} - //TCPServer(conn, rcv) } -func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { +type TcpConn struct { + raddr tcpip.FullAddress + ep tcpip.Endpoint + wq *waiter.Queue + we *waiter.Entry + notifyCh chan struct{} +} + +// Accept 封装tcp的accept操作 +func (conn *TcpConn) Accept() (tcpip.Endpoint, error) { + conn.wq.EventRegister(conn.we, waiter.EventIn) + defer conn.wq.EventUnregister(conn.we) + for { + ep, _, err := conn.ep.Accept() + if err != nil { + if err == tcpip.ErrWouldBlock { + <-conn.notifyCh + continue + } + return nil, fmt.Errorf("%s", err.String()) + } + return ep, nil + } +} + +func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn { var wq waiter.Queue // 新建一个tcp端 ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) @@ -167,7 +192,7 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}, nil); err != nil { + if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil { log.Fatal("Bind failed: ", err) } @@ -176,7 +201,12 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) log.Fatal("Listen failed: ", err) } - return ep + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + return &TcpConn{ + ep: ep, + wq: &wq, + we: &waitEntry, + notifyCh: notifyCh} } type UdpConn struct { @@ -226,7 +256,7 @@ func (conn *UdpConn) Write(snd []byte) error { } } -func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) *UdpConn { +func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn { var wq waiter.Queue // 新建一个udp端 ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) @@ -237,7 +267,7 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}, nil); err != nil { + if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}, nil); err != nil { log.Fatal("Bind failed: ", err) } diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go index 4dc249e..0d01b9d 100644 --- a/tcpip/header/ipv4.go +++ b/tcpip/header/ipv4.go @@ -180,6 +180,16 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } +// IPViewSize IP报文内容概览 长度 +const IPViewSize = 128 + +func (b IPv4) viewPayload() []byte { + if b.PayloadLength() < IPViewSize { + return b[b.HeaderLength():][:b.PayloadLength()] + } + return b[b.HeaderLength():][:IPViewSize] +} + // PayloadLength returns the length of the payload portion of the ipv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) @@ -294,20 +304,10 @@ func atoi[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32] } func (b IPv4) String() string { - for i := range b.Payload() { - if i != int(b.PayloadLength()-1) && b.Payload()[i]^b.Payload()[i+1] != 0 { - return fmt.Sprintf(ipv4Fmt, atoi(IPVersion(b)), atoi(b.HeaderLength()), atoi(0), atoi(b.TotalLength()), - atoi(b.ID()), atoi(b.Flags()>>2), atoi((b.Flags()&2)>>1), atoi(b.Flags()&1), atoi(b.FragmentOffset()), - atoi(b.TTL()), atoi(b.Protocol()), atoi(b.Checksum()), - b.SourceAddress().String(), - b.DestinationAddress().String(), - b.Payload()) - } - } return fmt.Sprintf(ipv4Fmt, atoi(IPVersion(b)), atoi(b.HeaderLength()), atoi(0), atoi(b.TotalLength()), atoi(b.ID()), atoi(b.Flags()>>2), atoi((b.Flags()&2)>>1), atoi(b.Flags()&1), atoi(b.FragmentOffset()), atoi(b.TTL()), atoi(b.Protocol()), atoi(b.Checksum()), b.SourceAddress().String(), b.DestinationAddress().String(), - fmt.Sprintf("%v x %d", b.Payload()[0], b.PayloadLength())) + b.viewPayload()) } diff --git a/tcpip/header/tcp.go b/tcpip/header/tcp.go index 0b7e319..a2d4cfc 100644 --- a/tcpip/header/tcp.go +++ b/tcpip/header/tcp.go @@ -2,6 +2,7 @@ package header import ( "encoding/binary" + "fmt" "netstack/tcpip" ) @@ -91,3 +92,114 @@ func (b TCP) SourcePort() uint16 { func (b TCP) DestinationPort() uint16 { return binary.BigEndian.Uint16(b[dstPort:]) } + +// SequenceNumber returns the "sequence number" field of the tcp header. +func (b TCP) SequenceNumber() uint32 { + return binary.BigEndian.Uint32(b[seqNum:]) +} + +// AckNumber returns the "ack number" field of the tcp header. +func (b TCP) AckNumber() uint32 { + return binary.BigEndian.Uint32(b[ackNum:]) +} + +// DataOffset returns the "data offset" field of the tcp header. +func (b TCP) DataOffset() uint8 { + return (b[dataOffset] >> 4) * 4 // 以32bits为单位 最小为5 20bytes +} + +// Payload returns the data in the tcp packet. +func (b TCP) Payload() []byte { + return b[b.DataOffset():] +} + +// TCPViewSize TCP报文概览长度 +const TCPViewSize = IPViewSize - TCPMinimumSize + +func (b TCP) viewPayload() []byte { + if len(b.Payload())-int(b.DataOffset()) < TCPViewSize { + return b.Payload() + } + return b[b.DataOffset():][:TCPViewSize] +} + +// Flags returns the flags field of the tcp header. +func (b TCP) Flags() uint8 { + return b[tcpFlags] +} + +// WindowSize returns the "window size" field of the tcp header. +func (b TCP) WindowSize() uint16 { + return binary.BigEndian.Uint16(b[winSize:]) +} + +// Checksum returns the "checksum" field of the tcp header. +func (b TCP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[tcpChecksum:]) +} + +// UrgentPtr returns the "urgentptr" field of the tcp header. +func (b TCP) UrgentPtr() uint16 { + return binary.BigEndian.Uint16(b[urgentPtr:]) +} + +// SetSourcePort sets the "source port" field of the tcp header. +func (b TCP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[srcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the tcp header. +func (b TCP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[dstPort:], port) +} + +// SetChecksum sets the checksum field of the tcp header. +func (b TCP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[tcpChecksum:], checksum) +} + +// Options returns a slice that holds the unparsed TCP options in the segment. +func (b TCP) Options() []byte { + return b[TCPMinimumSize:b.DataOffset()] +} + +/* + 0 1 2 3 + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Source Port | Destination Port | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Sequence Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Acknowledgment Number | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Data | |U|A|P|R|S|F| | +| Offset| Reserved |R|C|S|S|Y|I| Window | +| | |G|K|H|T|N|N| | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Checksum | Urgent Pointer | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| Options | Padding | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +| data | ++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +*/ + +var tcpFmt string = ` +|% 16s|% 16s| +|% 32s | +|% 32s | +|% 4s|% 4s|%06b|% 16s| +|% 16s|% 16s| +| Options | Padding | +%v +` + +func (b TCP) String() string { + return fmt.Sprintf(tcpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), + atoi(b.SequenceNumber()), + atoi(b.AckNumber()), + atoi(b.DataOffset()), "0", b.Flags(), atoi(b.WindowSize()), + atoi(b.Checksum()), atoi(b.UrgentPtr()), + b.viewPayload()) +} diff --git a/tcpip/header/udp.go b/tcpip/header/udp.go index 62eb194..1042a65 100644 --- a/tcpip/header/udp.go +++ b/tcpip/header/udp.go @@ -77,6 +77,16 @@ func (b UDP) Payload() []byte { return b[UDPMinimumSize:] } +// UDPViewSize UDP报文内容概览 长度 +const UDPViewSize = IPViewSize - UDPMinimumSize + +func (b UDP) viewPayload() []byte { + if b.Length()-UDPMinimumSize < UDPViewSize { + return b[UDPMinimumSize:] + } + return b[UDPMinimumSize:][:UDPViewSize] +} + // Checksum returns the "checksum" field of the udp header. func (b UDP) Checksum() uint16 { return binary.BigEndian.Uint16(b[udpChecksum:]) @@ -125,14 +135,7 @@ var udpFmt string = ` ` func (b UDP) String() string { - for i := range b.Payload() { - if i != int(b.Length()-8-1) && b.Payload()[i]^b.Payload()[i+1] != 0 { - return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), - atoi(b.Length()), atoi(b.Checksum()), - b.Payload()) - } - } return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), atoi(b.Length()), atoi(b.Checksum()), - fmt.Sprintf("%v x %d", b.Payload()[0], b.Length()-8)) + b.viewPayload()) } diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go index efc42fd..cac91e0 100644 --- a/tcpip/network/ipv4/ipv4.go +++ b/tcpip/network/ipv4/ipv4.go @@ -121,7 +121,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // 收到ip包的处理 func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // 得到ip报文 - h := header.IPv4(vv.ToView()) + h := header.IPv4(vv.First()) // 检查报文是否有效 if !h.IsValid(vv.Size()) { return diff --git a/tcpip/seqnum/seqnum.go b/tcpip/seqnum/seqnum.go new file mode 100644 index 0000000..1e968e8 --- /dev/null +++ b/tcpip/seqnum/seqnum.go @@ -0,0 +1,50 @@ +package seqnum + +// Value represents the value of a sequence number. +type Value uint32 + +// Size represents the size (length) of a sequence number window +type Size uint32 + +// LessThan v < w +func (v Value) LessThan(w Value) bool { + return int32(v-w) < 0 +} + +// LessThanEq returns true if v==w or v is before i.e., v < w. +func (v Value) LessThanEq(w Value) bool { + if v == w { + return true + } + return v.LessThan(w) +} + +// InRange v ∈ [a, b) +func (v Value) InRange(a, b Value) bool { + return a <= v && v < b +} + +// InWindows check v in [first, first+size) +func (v Value) InWindows(first Value, size Size) bool { + return v.InRange(first, first.Add(size)) +} + +// Add return v + s +func (v Value) Add(s Size) Value { + return v + Value(s) +} + +// Size return the size of [v, w) +func (v Value) Size(w Value) Size { + return Size(w - v) +} + +// UpdateForward update the value to v+s +func (v *Value) UpdateForward(s Size) { + *v += Value(s) +} + +// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y). +func Overlap(a Value, b Size, x Value, y Size) bool { + return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b)) +} diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go new file mode 100644 index 0000000..c6e0af3 --- /dev/null +++ b/tcpip/transport/tcp/accept.go @@ -0,0 +1,11 @@ +package tcp + +import ( + "netstack/tcpip" + "netstack/tcpip/seqnum" +) + +// protocolListenLoop 是侦听TCP端点的主循环。它在自己的goroutine中运行,负责处理连接请求 +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + select {} +} diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go new file mode 100644 index 0000000..0f8f703 --- /dev/null +++ b/tcpip/transport/tcp/connect.go @@ -0,0 +1,14 @@ +package tcp + +import ( + "log" + "netstack/tcpip" +) + +// protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段 +func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + for { + log.Println("三次握手机制在这里实现") + select {} + } +} diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index f0be48a..664a915 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -4,6 +4,8 @@ import ( "log" "netstack/tcpip" "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/seqnum" "netstack/tcpip/stack" "netstack/waiter" "sync" @@ -34,14 +36,34 @@ type endpoint struct { // The following fields are protected by the mutex. mu sync.RWMutex - id stack.TransportEndpointID - state endpointState - isPortReserved bool - isRegistered bool + id stack.TransportEndpointID // tcp端在网络协议栈的唯一ID + state endpointState // 目前tcp状态机的状态 + isPortReserved bool // 是否已经分配端口 + isRegistered bool // 是否已经注册在网络协议栈 boundNICID tcpip.NICID - route stack.Route - v6only bool + route stack.Route // tcp端在网络协议栈中的路由地址 + v6only bool // 是否仅仅支持ipv6 isConnectNotified bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // acceptedChan is used by a listening endpoint protocol goroutine to + // send newly accepted connections to the endpoint so that they can be + // read by Accept() calls. + acceptedChan chan *endpoint + + // The following are only used to assist the restore run to re-connect. + bindAddress tcpip.Address + connectingAddress tcpip.Address } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { @@ -71,6 +93,30 @@ func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) return 0, tcpip.ControlMessages{}, nil } +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { return nil } @@ -79,18 +125,134 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { return nil } -func (e *endpoint) Listen(backlog int) *tcpip.Error { +func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { + log.Println("监听一个tcp端口") + e.mu.Lock() + defer e.mu.Unlock() + defer func() { + if err != nil && err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() + + // TODO 需要添加 + + // 在调用 Listen 之前,必须先 Bind + if e.state != stateBound { + return tcpip.ErrInvalidEndpointState + } + // 注册该端点,这样网络层在分发数据包的时候就可以根据 id 来找到这个端点,接着把报文发送给这个端点。 + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, + e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { + return err + } + + e.isRegistered = true + e.state = stateListen + if e.acceptedChan == nil { + e.acceptedChan = make(chan *endpoint, backlog) + } + e.workerRunning = true + + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + // TODO tcp服务端实现的主循环,这个函数很重要,用一个goroutine来服务 + go e.protocolListenLoop(seqnum.Size(0)) + return nil } +// startAcceptedLoop sets up required state and starts a goroutine with the +// main loop for accepted connections. +func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.waiterQueue = waiterQueue + e.workerRunning = true + go e.protocolMainLoop(false) +} + func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, nil + e.mu.RLock() + defer e.mu.RUnlock() + + // Endpoint must be in listen state before it can accept connections. + if e.state != stateListen { + return nil, nil, tcpip.ErrInvalidEndpointState + } + + var n *endpoint + select { + case n = <-e.acceptedChan: + default: + return nil, nil, tcpip.ErrWouldBlock + } + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + return n, wq, nil } // Bind binds the endpoint to a specific local port and optionally address. // 将端点绑定到特定的本地端口和可选的地址。 -func (e *endpoint) Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { - log.Println("绑定一个tcp端口") +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // 如果端点不是处于初始状态,则不允许绑定。这是因为一旦端点进入连接或监听状态,它就已经绑定了。 + if e.state != stateInitial { + return tcpip.ErrAlreadyBound + } + // 确定tcp端的绑定ip + e.bindAddress = addr.Addr + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + // 确定tcp支持的网络层协议 + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + // 绑定端口 + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) + if err != nil { + return err + } + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.id.LocalPort = port + + defer func() { + // 如果有错,在退出的时候应该解除端口绑定 + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.isPortReserved = false + e.effectiveNetProtos = nil + e.id.LocalPort = 0 + e.id.LocalAddress = "" + e.boundNICID = 0 + } + }() + // 如果指定了ip地址 需要检查一下这个ip地址本地是否绑定过 + if len(addr.Addr) != 0 { + nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + + e.boundNICID = nic + e.id.LocalAddress = addr.Addr + } + + // Check the commit function. + if commit != nil { + if err := commit(); err != nil { + // The defer takes care of unwind. + return err + } + } + // 标记状态为 stateBound + e.state = stateBound return nil } @@ -132,3 +294,22 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil } + +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + log.Println("接收到数据") + s := newSegment(r, id, vv) + // 解析tcp段,如果解析失败,丢弃该报文 + if !s.parse() { + e.stack.Stats().MalformedRcvdPackets.Increment() + e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + s.decRef() + return + } + + e.stack.Stats().TCP.ValidSegmentsReceived.Increment() // 有效报文喜加一 + log.Println(s) +} + +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + +} diff --git a/tcpip/transport/tcp/segment.go b/tcpip/transport/tcp/segment.go new file mode 100644 index 0000000..611a19f --- /dev/null +++ b/tcpip/transport/tcp/segment.go @@ -0,0 +1,45 @@ +package tcp + +import ( + "log" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" + "sync/atomic" +) + +// tcp 太复杂了 专门写一个协议解析器 + +// segment 表示一个 TCP 段。它保存有效负载和解析的 TCP 段信息,并且可以添加到侵入列表中 +type segment struct { + segmentEntry + refCnt int32 + id stack.TransportEndpointID + route stack.Route + data buffer.VectorisedView + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View + // TODO 需要添加 +} + +func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { + s := &segment{refCnt: 1, id: id, route: r.Clone()} + s.data = vv.Clone(s.views[:]) + return s +} + +func (s *segment) decRef() { + if atomic.AddInt32(&s.refCnt, -1) == 0 { + s.route.Release() + } +} + +func (s *segment) incRef() { + atomic.AddInt32(&s.refCnt, 1) +} + +func (s *segment) parse() bool { + log.Println(header.TCP(s.data.First())) + return false +} diff --git a/tcpip/transport/tcp/tcp_segment_list.go b/tcpip/transport/tcp/tcp_segment_list.go new file mode 100644 index 0000000..029f98a --- /dev/null +++ b/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,173 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go index 3ff2481..b1707bb 100644 --- a/tcpip/transport/udp/endpoint.go +++ b/tcpip/transport/udp/endpoint.go @@ -834,7 +834,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // 但是队列存数据量是有限制的,这个限制叫接收缓存大小,当接收队列中的数据总和超过这个缓存,那么接下来的这些报文将会被直接丢包。 func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { // Get the header then trim it from the view. - hdr := header.UDP(vv.ToView()) + hdr := header.UDP(vv.First()) if int(hdr.Length()) > vv.Size() { // Malformed packet. // 错误报文