diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 6ad9e4e..4d87cd8 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -142,14 +142,21 @@ func main() { go func() { // echo server listener := tcpListen(s, proto, addr, localPort) - for { - conn, err := listener.Accept() - if err != nil { - log.Println(err) - continue - } - conn.Read(nil) + conn, err := listener.Accept() + if err != nil { + log.Println(err) } + buf := make([]byte, 1024) + if _, err := conn.Read(buf); err != nil { + log.Fatal(err) + } + fmt.Println(string(buf)) + if string(buf) != "" { + conn.Write([]byte("Server echo")) + } + os.Exit(1) + + select {} }() c := make(chan os.Signal) @@ -165,24 +172,72 @@ type TcpConn struct { notifyCh chan struct{} } -// Accept 封装tcp的accept操作 -func (conn *TcpConn) Accept() (tcpip.Endpoint, error) { +func (conn *TcpConn) Read(rcv []byte) (int, error) { conn.wq.EventRegister(conn.we, waiter.EventIn) defer conn.wq.EventUnregister(conn.we) for { - ep, _, err := conn.ep.Accept() + buf, _, err := conn.ep.Read(&conn.raddr) if err != nil { if err == tcpip.ErrWouldBlock { <-conn.notifyCh continue } - return nil, fmt.Errorf("%s", err.String()) + return 0, fmt.Errorf("%s", err.String()) } - return ep, nil + n := len(buf) + if n > cap(rcv) { + n = cap(rcv) + } + rcv = append(rcv[:0], buf[:n]...) + return n, nil } } -func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *TcpConn { +func (conn *TcpConn) Write(snd []byte) error { + for { + _, notifyCh, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr}) + if err != nil { + if err == tcpip.ErrNoLinkAddress { + <-notifyCh + continue + } + return fmt.Errorf("%s", err.String()) + } + return nil + } +} + +// Listener tcp连接监听器 +type Listener struct { + raddr tcpip.FullAddress + ep tcpip.Endpoint + wq *waiter.Queue + we *waiter.Entry + notifyCh chan struct{} +} + +// Accept 封装tcp的accept操作 +func (l *Listener) Accept() (*TcpConn, error) { + l.wq.EventRegister(l.we, waiter.EventIn) + defer l.wq.EventUnregister(l.we) + for { + ep, wq, err := l.ep.Accept() + if err != nil { + if err == tcpip.ErrWouldBlock { + <-l.notifyCh + continue + } + return nil, fmt.Errorf("%s", err.String()) + } + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + return &TcpConn{ep: ep, + wq: wq, + we: &waitEntry, + notifyCh: notifyCh}, nil + } +} + +func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *Listener { var wq waiter.Queue // 新建一个tcp端 ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) @@ -202,7 +257,7 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Add } waitEntry, notifyCh := waiter.NewChannelEntry(nil) - return &TcpConn{ + return &Listener{ ep: ep, wq: &wq, we: &waitEntry, diff --git a/cmd/tcpclient/main.go b/cmd/tcpclient/main.go index a4f548d..f6cd1a9 100644 --- a/cmd/tcpclient/main.go +++ b/cmd/tcpclient/main.go @@ -17,7 +17,9 @@ func main() { log.Println("连接建立") conn.Write([]byte("helloworld")) log.Println("发送了数据") - conn.Close() + buf := make([]byte, 1024) + conn.Read(buf) + //conn.Close() }() t := time.NewTimer(1000 * time.Millisecond) diff --git a/tcpip/seqnum/seqnum.go b/tcpip/seqnum/seqnum.go index 7aee5f9..82494fb 100644 --- a/tcpip/seqnum/seqnum.go +++ b/tcpip/seqnum/seqnum.go @@ -21,7 +21,8 @@ func (v Value) LessThanEq(w Value) bool { // InRange v ∈ [a, b) func (v Value) InRange(a, b Value) bool { - return a <= v && v < b + //return a <= v && v < b + return v-a < b-a } // InWindows check v in [first, first+size) diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go index 78653fc..361ebdc 100644 --- a/tcpip/transport/tcp/accept.go +++ b/tcpip/transport/tcp/accept.go @@ -251,7 +251,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() if e.state == stateListen { e.acceptedChan <- n - e.waiterQueue.Notify(waiter.EventIn) + e.waiterQueue.Notify(waiter.EventIn) // 通知 Accept() 停止阻塞 } else { n.Close() } diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go index f346ffc..e077b53 100644 --- a/tcpip/transport/tcp/connect.go +++ b/tcpip/transport/tcp/connect.go @@ -215,7 +215,7 @@ func (h *handshake) handleSegment(s *segment) *tcpip.Error { if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { h.sndWnd <<= uint8(h.sndWndScale) } - log.Println(h.sndWnd) + //log.Println(h.sndWnd) switch h.state { case handshakeSynRcvd: @@ -311,8 +311,7 @@ func (h *handshake) execute() *tcpip.Error { } rt.Reset(timeOut) // 重新发送syn|ack报文 - //sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) - log.Println("超时重发了 xdm") + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: case wakerForNewSegment: @@ -469,7 +468,7 @@ func (e *endpoint) handleClose() *tcpip.Error { // handleSegments 从队列中取出 tcp 段数据,然后处理它们。 func (e *endpoint) handleSegments() *tcpip.Error { - log.Println("年轻人的第一条数据") + //log.Println("年轻人的第一条数据") checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { s := e.segmentQueue.dequeue() diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index f5b99e0..950988a 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -66,6 +66,8 @@ type endpoint struct { // address). effectiveNetProtos []tcpip.NetworkProtocolNumber + hardError *tcpip.Error + // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -139,8 +141,49 @@ func (e *endpoint) Close() { log.Println("TODO 在写了 在写了") } +// Read 从tcp的接收队列中读取数据 func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return nil, tcpip.ControlMessages{}, nil + e.mu.RLock() + + e.rcvListMu.Lock() + bufUsed := e.rcvBufUsed + if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + e.rcvListMu.Unlock() + he := e.hardError + e.mu.RUnlock() + if s == stateError { + return buffer.View{}, tcpip.ControlMessages{}, he + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + } + + v, err := e.readLocked() + e.rcvListMu.Unlock() + e.mu.RUnlock() + return v, tcpip.ControlMessages{}, err +} + +// 从tcp的接收队列中读取数据,并从接收队列中删除已读数据 +func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return buffer.View{}, tcpip.ErrClosedForReceive + } + return buffer.View{}, tcpip.ErrWouldBlock + } + s := e.rcvList.Front() + views := s.data.Views() + v := views[s.viewToDeliver] + s.viewToDeliver++ + + if s.viewToDeliver >= len(views) { + e.rcvList.Remove(s) + s.decRef() + } + log.Println("读到了数据", views, v) + // TODO 流量检测 + + return v, nil } func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { @@ -175,8 +218,117 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocol return netProto, nil } +// Connect 这是客户端用的吧 func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { - return nil + return e.connect(address, true, true) +} + +// connect将端点连接到其对等端。在正常的非S/R情况下,新连接应该运行主goroutine并执行握手。 +// 在恢复先前连接的端点时,将被动地创建两端(因此不会进行新的握手);对于应用程序尚未接受的堆栈接受连接, +// 它们将在不运行主goroutine的情况下进行恢复。 +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + defer func() { + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() + + connectingAddr := addr.Addr + + // 检查ipv4是否映射到ipv6 + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + + nicid := addr.NIC + // 判断连接的状态 + switch e.state { + case stateBound: + // If we're already bound to a NIC but the caller is requesting + // that we use a different one now, we cannot proceed. + if e.boundNICID == 0 { + break + } + + if nicid != 0 && nicid != e.boundNICID { + return tcpip.ErrNoRoute + } + + nicid = e.boundNICID + + case stateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID + // (if any) when we find a route. + + case stateConnecting: + // A connection request has already been issued but hasn't + // completed yet. + return tcpip.ErrAlreadyConnecting + + case stateConnected: + // The endpoint is already connected. If caller hasn't been notified yet, return success. + if !e.isConnectNotified { + e.isConnectNotified = true + return nil + } + // Otherwise return that it's already connected. + return tcpip.ErrAlreadyConnected + + case stateError: + return e.hardError + + default: + return tcpip.ErrInvalidEndpointState + } + + // Find a route to the desired destination. + // 根据目标ip查找路由信息 + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + origID := e.id + + netProtos := []tcpip.NetworkProtocolNumber{netProto} + e.id.LocalAddress = r.LocalAddress + e.id.RemoteAddress = r.RemoteAddress + e.id.RemotePort = addr.Port + + if e.id.LocalPort != 0 { + // 记录和检查原端口是否已被使用 + // The endpoint is bound to a port, attempt to register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) + if err != nil { + return err + } + } else { + // TODO 需要添加 + } + + // Remove the port reservation. This can happen when Bind is called + // before Connect: in such a case we don't want to hold on to + // reservations anymore. + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.isPortReserved = false + } + + // 记录该端点的参数 + e.isRegistered = true + e.state = stateConnecting + e.route = r.Clone() + e.boundNICID = nicid + e.effectiveNetProtos = netProtos + e.connectingAddress = connectingAddr + + // TODO 需要添加 + + return tcpip.ErrConnectStarted } func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { @@ -238,7 +390,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { var n *endpoint select { - case n = <-e.acceptedChan: + case n = <-e.acceptedChan: // 外部再次调用后尝试取出ep log.Println("监听者进行一个新连接的分发", n.id) default: return nil, nil, tcpip.ErrWouldBlock @@ -343,7 +495,48 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return waiter.EventErr + result := waiter.EventMask(0) + + e.mu.RLock() + defer e.mu.RUnlock() + + switch e.state { + case stateInitial, stateBound, stateConnecting: + // Ready for nothing. + + case stateClosed, stateError: + // Ready for anything. + result = mask + + case stateListen: + // Check if there's anything in the accepted channel. + if (mask & waiter.EventIn) != 0 { + if len(e.acceptedChan) > 0 { + result |= waiter.EventIn + } + } + + case stateConnected: + // Determine if the endpoint is writable if requested. + if (mask & waiter.EventOut) != 0 { + e.sndBufMu.Lock() + if e.sndClosed || e.sndBufUsed < e.sndBufSize { + result |= waiter.EventOut + } + e.sndBufMu.Unlock() + } + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvListMu.Lock() + if e.rcvBufUsed > 0 || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvListMu.Unlock() + } + } + + return result } func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { @@ -385,6 +578,20 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C } +func (e *endpoint) readyToRead(s *segment) { + e.rcvListMu.Lock() + if s != nil { + s.incRef() + e.rcvBufUsed += s.data.Size() + e.rcvList.PushBack(s) + } else { + e.rcvClosed = true + } + e.rcvListMu.Unlock() + + e.waiterQueue.Notify(waiter.EventIn) +} + // receiveBufferAvailable calculates how many bytes are still available in the // receive buffer. // tcp流量控制:计算未被占用的接收缓存大小 diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go index 1d67a82..75f5b58 100644 --- a/tcpip/transport/tcp/rcv.go +++ b/tcpip/transport/tcp/rcv.go @@ -5,18 +5,81 @@ import ( "netstack/tcpip/seqnum" ) -type receiver struct{} +type receiver struct { + ep *endpoint + rcvNxt seqnum.Value // 准备接收的下一个报文序列号 + closed bool +} // 新建并初始化接收器 func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { - r := &receiver{} + r := &receiver{ + ep: ep, + rcvNxt: irs + 1, + } return r } +// tcp流量控制:判断 segSeq 在窗口內 +func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + // TODO 流量控制 + return true +} + +func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool { + if segLen > 0 { + // 我们期望接收到的序列号范围应该是 seqStart <= rcvNxt < seqEnd, + // 如果不在这个范围内说明我们少了数据段,返回false,表示不能立马消费 + if !r.rcvNxt.InWindows(segSeq, segLen) { + return false + } + // 尝试去除已经确认过的数据 + if segSeq.LessThan(r.rcvNxt) { + log.Println("收到重复数据") + diff := segSeq.Size(r.rcvNxt) + segLen -= diff + segSeq.UpdateForward(diff) + s.sequenceNumber.UpdateForward(diff) + s.data.TrimFront(int(diff)) + } + // 将tcp段插入接收链表,并通知应用层用数据来了 + r.ep.readyToRead(s) + } else if segSeq != r.rcvNxt { // 空数据 还是非顺序到达 丢弃 + return false + } + + // 如果收到 fin 报文 + if s.flagIsSet(flagFin) { + // TODO 处理fin报文 + } + + return true +} + // handleRcvdSegment handles TCP segments directed at the connection managed by // r as they arrive. It is called by the protocol main loop. // 从 handleSegments 接收到tcp段,然后进行处理消费,所谓的消费就是将负载内容插入到接收队列中 func (r *receiver) handleRcvdSegment(s *segment) { - log.Println(s.data) + if r.closed { + return + } + segLen := seqnum.Size(s.data.Size()) + segSeq := s.sequenceNumber + + // TODO tcp流量控制 + // tcp流量控制:判断该数据段的序列号是否在接收窗口内,如果不在,立即返回ack给对端。 + if !r.acceptable(segSeq, segLen) { + r.ep.snd.sendAck() + return + } + + log.Println(s.data, segLen, segSeq) + + // Defer segment processing if it can't be consumed now. + // tcp可靠性:r.consumeSegment 返回值是个bool类型,如果是true,表示已经消费该数据段, + // 如果不是,那么进行下面的处理,插入到 pendingRcvdSegments,且进行堆排序 + if !r.consumeSegment(s, segSeq, segLen) { + return + } } diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go index 3ba6026..6c19ac4 100644 --- a/tcpip/transport/tcp/snd.go +++ b/tcpip/transport/tcp/snd.go @@ -1,6 +1,9 @@ package tcp -import "netstack/tcpip/seqnum" +import ( + "log" + "netstack/tcpip/seqnum" +) type sender struct { } @@ -10,3 +13,7 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s := &sender{} return s } + +func (s *sender) sendAck() { + log.Fatal("TODO 需要发送一个ack") +}