diff --git a/cmd/netstack/TcpConn.go b/cmd/netstack/TcpConn.go index 8739118..2e74f56 100644 --- a/cmd/netstack/TcpConn.go +++ b/cmd/netstack/TcpConn.go @@ -10,6 +10,7 @@ import ( "netstack/tcpip/stack" "netstack/tcpip/transport/tcp" "netstack/waiter" + "time" ) // Dial 呼叫tcp服务端 @@ -36,6 +37,11 @@ func Dial(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, } } + ep.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) + ep.SetSockOpt(tcpip.KeepaliveIntervalOption(75 * time.Second)) + ep.SetSockOpt(tcpip.KeepaliveIdleOption(30 * time.Second)) // 30s的探活心跳 + ep.SetSockOpt(tcpip.KeepaliveCountOption(9)) + return &TcpConn{ ep: ep, wq: &wq, diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 66ecc8f..6f23ada 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -147,23 +147,30 @@ func main() { } log.Println("服务端 建立连接") - go func() { + go func(*TcpConn) { cnt := 0 - time.Sleep(2 * time.Second) + time.Sleep(1 * time.Second) for { // 一个慢读者 才能体现出网络的情况 buf := make([]byte, 1024) n, err := conn.Read(buf) if err != nil { - log.Println(n, err) + // TODO 添加一个 error 表明无法继续读取 对端要求关闭 break } cnt+=n logger.NOTICE("服务端读取了数据", fmt.Sprintf("n: %d, cnt: %d", n, cnt), string(buf)) - //conn.Write([]byte("Hello Client")) } - }() + + // 我端收到了 fin 关闭读 继续写 + conn.Write([]byte("Bye Client")) + // 我端向对端发一个终止报文 + conn.ep.Close() + log.Println("服务端 结束读取") + + }(conn) } + }() go func() { @@ -174,31 +181,23 @@ func main() { if err != nil { log.Fatal(err) } - log.Printf("客户端 建立连接\n") - conn.SetSockOpt(tcpip.KeepaliveEnabledOption(1)) - conn.SetSockOpt(tcpip.KeepaliveIntervalOption(75 * time.Second)) - conn.SetSockOpt(tcpip.KeepaliveIdleOption(30 * time.Second)) // 30s的探活心跳 - conn.SetSockOpt(tcpip.KeepaliveCountOption(9)) + log.Printf("客户端 建立连接\n\n客户端 写入数据\n") - log.Printf("\n\n客户端 写入数据") - - cnt := 0 - for i := 0; i < 20; i++ { + for i := 0; i < 3; i++ { conn.Write(make([]byte, 1<<(5))) - cnt += 1<<(5) - //buf := make([]byte, 1024) - //n, err := conn.Read(buf) - //if 
err != nil { - // log.Println(err) - // break - //} - //logger.NOTICE(string(buf[:n])) } - logger.NOTICE("写完了", fmt.Sprintf("共计写入: %d", cnt)) - conn.Close() + + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + log.Println(err) + return + } + logger.NOTICE(string(buf[:n])) + }() //l, err := net.Listen("tcp", "127.0.0.1:9999") diff --git a/tcpip/link/loopback/loopback.go b/tcpip/link/loopback/loopback.go index 79c7e82..9934430 100644 --- a/tcpip/link/loopback/loopback.go +++ b/tcpip/link/loopback/loopback.go @@ -1,8 +1,6 @@ package loopback import ( - "fmt" - "netstack/logger" "netstack/tcpip" "netstack/tcpip/buffer" "netstack/tcpip/stack" @@ -51,10 +49,10 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b //time.Sleep(time.Duration(rand.Intn(50)+50) * time.Millisecond) e.count++ - if e.count == 6 { // 丢掉客户端写入的第二个包 - logger.NOTICE(fmt.Sprintf("统计 %d 丢掉这个报文", e.count)) - return nil - } + //if e.count == 6 { // 丢掉客户端写入的第二个包 + // logger.NOTICE(fmt.Sprintf("统计 %d 丢掉这个报文", e.count)) + // return nil + //} // Because we're immediately turning around and writing the packet back to the // rx path, we intentionally don't preserve the remote and local link // addresses from the stack.Route we're passed. diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index 692e3f0..1f7f776 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -7,6 +7,7 @@ import ( "netstack/waiter" "reflect" "strings" + "sync" "sync/atomic" "time" ) @@ -638,3 +639,36 @@ func (a Address) String() string { return fmt.Sprintf("%s", string(a)) } } + + +// danglingEndpointsMu protects access to danglingEndpoints. +// 一个摇摆不定的端点 一个行将就木的端点 +var danglingEndpointsMu sync.Mutex + +// danglingEndpoints tracks all dangling endpoints no longer owned by the app. +var danglingEndpoints = make(map[Endpoint]struct{}) + +// GetDanglingEndpoints returns all dangling endpoints. 
+func GetDanglingEndpoints() []Endpoint { + danglingEndpointsMu.Lock() // lock before len(): the map is shared with Add/DeleteDanglingEndpoint + es := make([]Endpoint, 0, len(danglingEndpoints)) + for e := range danglingEndpoints { + es = append(es, e) + } + danglingEndpointsMu.Unlock() + return es +} + +// AddDanglingEndpoint adds a dangling endpoint. +func AddDanglingEndpoint(e Endpoint) { + danglingEndpointsMu.Lock() + danglingEndpoints[e] = struct{}{} + danglingEndpointsMu.Unlock() +} + +// DeleteDanglingEndpoint removes a dangling endpoint. +func DeleteDanglingEndpoint(e Endpoint) { + danglingEndpointsMu.Lock() + delete(danglingEndpoints, e) + danglingEndpointsMu.Unlock() +} diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go index 6642d4d..05f26a2 100644 --- a/tcpip/transport/tcp/connect.go +++ b/tcpip/transport/tcp/connect.go @@ -172,12 +172,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Resolution not completed. Keep trying... case wakerForNotification: - // TODO - //n := h.ep.fetchNotifications() - //if n&notifyClose != 0 { - // h.ep.route.RemoveWaker(resolutionWaker) - // return tcpip.ErrAborted - //} + n := h.ep.fetchNotifications() + if n&notifyClose != 0 { + h.ep.route.RemoveWaker(resolutionWaker) + return tcpip.ErrAborted + } //if n&notifyDrain != 0 { // close(h.ep.drainDone) // <-h.ep.undrain @@ -608,8 +607,8 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise logger.GetInstance().Info(logger.TCP, func() { }) - log.Printf("TCP 发送 [%s] 报文片段到 %s, seq: %d, ack: %d, 可接收rcvWnd: %d", - flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), + log.Printf("TCP :%d 发送 [%s] 报文片段到 %s, seq: %d, ack: %d, 可接收rcvWnd: %d", + id.LocalPort, flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), seq, ack, rcvWnd) return r.WritePacket(hdr, data, ProtocolNumber, ttl) @@ -703,12 +702,32 @@ func (e *endpoint) handleClose() *tcpip.Error { e.handleWrite() // Mark send side as closed. 
- // 标记发送器关闭 + // 标记发送器关闭 标记过之后 e.rcv.closed && e.snd.closed 主循环将会退出 e.snd.closed = true return nil } +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { + e.sendRaw(buffer.VectorisedView{}, flagRst|flagAck, e.snd.sndUna, e.rcv.rcvNxt, 0) + e.state = stateError + e.hardError = err +} + +func (e *endpoint) completeWorkerLocked() { + e.workerRunning = false // 标记当前goroutine已经停运 + if e.workerCleanup { + //if e.id.LocalPort != 9999 { + // logger.NOTICE("客户端开始清理资源") + // log.Println(e.snd.sndUna , e.snd.sndNxtList) + //} else { + // logger.NOTICE("服务端开始清理资源") + // log.Println(e.snd.sndUna , e.snd.sndNxtList) + //} + e.cleanupLocked() + } +} + // handleSegments 从队列中取出 tcp 段数据,然后处理它们。 func (e *endpoint) handleSegments() *tcpip.Error { checkRequeue := true @@ -753,7 +772,7 @@ func (e *endpoint) handleSegments() *tcpip.Error { // information." // 处理tcp数据段,同时给接收器和发送器 // 为何要给发送器传接收到的数据段呢?主要是为了滑动窗口的滑动和拥塞控制处理 - e.rcv.handleRcvdSegment(s) + e.rcv.handleRcvdSegment(s) // 在收到fin报文后 将不再接受任何报文 e.snd.handleRcvdSegment(s) } s.decRef() // 该segment处理完成 @@ -829,11 +848,22 @@ func (e *endpoint) disableKeepaliveTimer() { // protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段 func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + var closeTimer *time.Timer + var closeWaker sleep.Waker // 收尾工作 // 收尾的一些工作 epilogue := func() { // e.mu is expected to be hold upon entering this section. 
+ if e.snd != nil { + e.snd.resendTimer.cleanup() // 放弃所有重发报文 + } + + if closeTimer != nil { + closeTimer.Stop() // 正常结束 MainLoop + } + + e.completeWorkerLocked() // TODO 需要添加 e.mu.Unlock() @@ -907,6 +937,12 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { w: &e.newSegmentWaker, f: e.handleSegments, }, + { + w: &closeWaker, + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted // 如果在3s内没有正常结束四次挥手 将强制结束连接 + }, + }, { w: &e.snd.resendWaker, f: func() *tcpip.Error { @@ -939,12 +975,18 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } if n&notifyReset != 0 { - + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() } - //if n&notifyClose != 0 && closeTimer == nil { - // - //} + if n&notifyClose != 0 && closeTimer == nil { + // Reset the connection 3 seconds after the + // endpoint has been closed. + closeTimer = time.AfterFunc(3*time.Second, func() { + closeWaker.Assert() + }) + } if n&notifyDrain != 0 { } @@ -977,12 +1019,32 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { } e.rcvListMu.Unlock() - // TODO 需要添加 workerCleanup + e.mu.RLock() + if e.workerCleanup { + e.notifyProtocolGoroutine(notifyClose) + } + e.mu.RUnlock() + // 主循环,处理tcp报文 // 要使这个主循环结束,也就是tcp连接完全关闭,得同时满足三个条件: // 1,接收器关闭了 2,发送器关闭了 3,下一个未确认的序列号等于添加到发送列表的下一个段的序列号 - //for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + // + // 对于服务端而言: + // 1. 在收到 FIN 报文后 在handleSegment 中处理报文时 handleRcvSegment -> consumeSegment 将e.rcv.closed + // 2. 在用户层主动调用 Close 时 在Shutdown 中 唤醒 e.sndCloseWaker 执行 handleClose 将e.snd.closed + // 3. 在用户层主动调用 Close 后 将会发送给 客户端 一个 FIN 报文 当收到正确的客户端ack时 + // 如果 e.snd.sndUna == e.snd.sndNxtList 也就是没有可以发送的数据了 服务端就可以退出了 + // + // 对于客户端而言: + // 1. 
应用层主动调用了 Close -> Shutdown 唤醒e.sndCloseWaker 执行 handleClose + // 将e.snd.closed snd设置close并不是关闭写 将会发送给 服务端 一个 FIN 报文 当收到正确的服务端端ack时 客户端不直接退出 + // 而是等待服务端的后续数据 并且去回复对应的ack 但是服务端并不会去消费这些ack + // NOTE 这里仅仅是不通知上层用户程序消费 底层的重发机制什么的都还在工作 因此仍然是可靠传输 + // 2. 当客户端收到来自服务端的 FIN 报文的时候 在handleSegment 中处理报文时 + // handleRcvSegment -> consumeSegment 将e.rcv.closed + // 3. 在收完完 FIN 报文后 e.snd.sndUna == e.snd.sndNxtList 也就是没有可以发送的数据了 客户端就可以退出了 + // for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { e.workMu.Unlock() // s.Fetch 会返回事件的index,比如 v=0 的话, @@ -992,7 +1054,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { e.workMu.Lock() if err := funcs[v].f(); err != nil { e.mu.Lock() - //e.resetConnectionLocked(err) + e.resetConnectionLocked(err) // Lock released below. epilogue() log.Println(err) diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index dfb87a1..ce5e227 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -122,6 +122,11 @@ type endpoint struct { // workerRunning specifies if a worker goroutine is running. workerRunning bool + // workerCleanup specifies if the worker goroutine must perform cleanup + // before exitting. This can only be set to true when workerRunning is + // also true, and they're both protected by the mutex. + workerCleanup bool + // sendTSOk is used to indicate when the TS Option has been negotiated. // When sendTSOk is true every non-RST segment should carry a TS as per // RFC7323#section-1.1 @@ -277,6 +282,8 @@ func (e *endpoint) Close() { // for reuse after Close() is called. If also registered, it means this // is a listening socket, so we must unregister as well otherwise the // next user would fail in Listen() when trying to register. 
+ // 释放绑定端口 客户端释放随机绑定的port + // 注销协议栈中的端点 if e.isPortReserved { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) e.isPortReserved = false @@ -287,10 +294,44 @@ func (e *endpoint) Close() { } } - logger.TODO("添加清理资源的逻辑") + tcpip.AddDanglingEndpoint(e) + if !e.workerRunning { // workerRunning 监听者 客户端 tcp连接 都会设置 + e.cleanupLocked() + } else { + e.workerCleanup = true // 在端点调用了 Close 后将会走这个分支 + e.notifyProtocolGoroutine(notifyClose) + } e.mu.Unlock() } +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by + // the client. + if e.acceptedChan != nil { // 监听者 + close(e.acceptedChan) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + e.acceptedChan = nil + } + e.workerCleanup = false + + // 注销掉这个端点 + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + // 释放掉这个路由 + e.route.Release() + tcpip.DeleteDanglingEndpoint(e) +} + // Read 从tcp的接收队列中读取数据 func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.mu.RLock() @@ -596,8 +637,9 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { defer e.mu.Unlock() e.shutdownFlags |= flags + switch e.state { - case stateConnected: // 客户端关闭 + case stateConnected: // tcp连接关闭 // 不能直接关闭读数据包,因为关闭连接的时候四次挥手还需要读取报文。 if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { e.rcvListMu.Lock() @@ -605,7 +647,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.rcvListMu.Unlock() if rcvBufUsed > 0 { // 如果接收队列中还有数据 通知对端RESET - logger.TODO("通知对端RESET") + e.notifyProtocolGoroutine(notifyReset) return nil } } @@ -617,6 +659,7 @@ func (e *endpoint) 
Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { break } + // 发送一个 FIN 报文 告知对面关闭上层用户程序 // Queue fin segment. s := newSegmentFromView(&e.route, e.id, nil) e.sndQueue.PushBack(s) @@ -627,8 +670,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { // 触发调用 handleClose e.sndCloseWaker.Assert() - case stateListen: // 服务端关闭 - logger.FIXME("添加服务端关闭逻辑") + case stateListen: // 监听器关闭 + logger.FIXME("添加监听器关闭逻辑") default: return tcpip.ErrNotConnected } diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go index 4fb0bdd..e619bce 100644 --- a/tcpip/transport/tcp/rcv.go +++ b/tcpip/transport/tcp/rcv.go @@ -181,8 +181,9 @@ func (r *receiver) handleRcvdSegment(s *segment) { // tcp的可靠性:通过使用当前段,我们可能填补了序列号域中的间隙,该间隙允许现在使用待处理段。 // 所以试着去消费等待处理段。 + // 当进行关闭操作的时候 只关写 不关读 for !r.closed && r.pendingRcvdSegments.Len() > 0 { - //log.Fatal("出现空隙端", r.pendingRcvdSegments.Len()) + //log.Fatal("出现空隙段", r.pendingRcvdSegments.Len()) s := r.pendingRcvdSegments[0] segLen := seqnum.Size(s.data.Size()) segSeq := s.sequenceNumber diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go index 86f95f6..5272442 100644 --- a/tcpip/transport/tcp/snd.go +++ b/tcpip/transport/tcp/snd.go @@ -20,7 +20,7 @@ const ( // InitialCwnd is the initial congestion window. // 初始拥塞窗口大小 - InitialCwnd = 4 + InitialCwnd = 10 // nDupAckThreshold is the number of duplicate ACK's required // before fast-retransmit is entered. @@ -593,14 +593,14 @@ func (s *sender) sendData() { if !dataSent { // 没有成功发送任何数据 dataSent = true - // TODO + s.ep.disableKeepaliveTimer() } // 发送包 开始计算RTT s.sendSegment(seg.data, seg.flags, seg.sequenceNumber) // 发送一个数据段后,更新sndNxt if s.sndNxt.LessThan(segEnd) { - log.Println("更新sndNxt", s.sndNxt, " 为 ", segEnd, "下一次发送的数据头为", segEnd) + log.Println(s.ep.id.LocalPort, " 更新sndNxt", s.sndNxt, " 为 ", segEnd, "下一次发送的数据头为", segEnd) s.sndNxt = segEnd } }