diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 308835f..882b099 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -147,11 +147,12 @@ func main() { go func() { for { buf := make([]byte, 1024) - if _, err := conn.Read(buf); err != nil { + n, err := conn.Read(buf) + if err != nil { log.Println(err) break } - fmt.Println("data: ", len(buf), string(buf)) + fmt.Println("data: ", n, len(buf), string(buf)) // conn.Write([]byte("Server echo")) //} } @@ -160,16 +161,20 @@ func main() { }() <-done + go func() { port := localPort conn, err := Dial(s, header.IPv4ProtocolNumber, addr, port) if err != nil { log.Fatal(err) } - log.Println("客户端 建立连接") - buf := make([]byte, 1<<21) + log.Printf("客户端 建立连接\n") + + time.Sleep(time.Second) + log.Printf("\n\n客户端 写入数据") + buf := make([]byte, 1<<17) conn.Write(buf) - time.Sleep(3 * time.Second) + time.Sleep(1 * time.Minute) conn.Close() }() @@ -237,7 +242,7 @@ func (conn *TcpConn) Read(rcv []byte) (int, error) { n = cap(rcv) } rcv = append(rcv[:0], buf[:n]...) 
- return n, nil + return len(buf), nil } } diff --git a/tcpip/seqnum/seqnum.go b/tcpip/seqnum/seqnum.go index 53d33ae..36df6ac 100644 --- a/tcpip/seqnum/seqnum.go +++ b/tcpip/seqnum/seqnum.go @@ -21,11 +21,10 @@ func (v Value) LessThanEq(w Value) bool { // InRange v ∈ [a, b) func (v Value) InRange(a, b Value) bool { - //return a <= v && v < b - return v-a < b-a + return v-a < b-a // 注意 uint32(-1) > uint32(0) } -// InWindows check v in [first, first+size) +// InWindow check v in [first, first+size) func (v Value) InWindow(first Value, size Size) bool { return v.InRange(first, first.Add(size)) } diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go index b39209c..bea155f 100644 --- a/tcpip/transport/tcp/connect.go +++ b/tcpip/transport/tcp/connect.go @@ -82,10 +82,10 @@ const ( // 根据这些参数生成 syn+ack 报文的选项参数,附在 tcp 选项中,回复给主机 A。 func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { h := handshake{ - ep: ep, - active: true, // 激活这个管理器 - rcvWnd: rcvWnd, // 初始接收窗口 - // TODO + ep: ep, + active: true, // 激活这个管理器 + rcvWnd: rcvWnd, // 初始接收窗口 + rcvWndScale: FindWndScale(rcvWnd), // 接收窗口扩展因子 } if err := h.resetState(); err != nil { return handshake{}, err @@ -93,6 +93,25 @@ func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { return h, nil } +// FindWndScale determines the window scale to use for the given maximum window +// size. 
+// 因为窗口的大小不能超过序列号范围的一半,即窗口最大2^30, +// so (2^16)*(2^maxWndScale) < 2^30,get maxWndScale = 14 +func FindWndScale(wnd seqnum.Size) int { + if wnd < 0x10000 { + return 0 + } + + max := seqnum.Size(0xffff) + s := 0 + for wnd > max && s < header.MaxWndScale { + s++ + max <<= 1 + } + + return s +} + func (h *handshake) resetState() *tcpip.Error { // 随机一个iss(对方将收到的序号) 防止黑客搞事 b := make([]byte, 4) @@ -301,7 +320,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { func (h *handshake) handleSegment(s *segment) *tcpip.Error { h.sndWnd = s.window if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { - h.sndWnd <<= uint8(h.sndWndScale) + h.sndWnd <<= uint8(h.sndWndScale) // left shift scales the window up by sndWndScale } switch h.state { @@ -619,13 +638,13 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn func (e *endpoint) handleWrite() *tcpip.Error { e.sndBufMu.Lock() - // 得到第一个tcp段 + // 得到第一个tcp段 注意并不是取出只是查看 first := e.sndQueue.Front() if first != nil { // 向发送链表添加元素 e.snd.writeList.PushBackList(&e.sndQueue) // NOTE 更新发送队列下一个发送字节的序号 一次性将链表全部取用 - // 当有新的数据需要发送时会有相关逻辑更新这个数值 + // 当有新的数据需要发送时会有相关逻辑更新这个数值 e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) e.sndBufInQueue = 0 } @@ -684,7 +703,6 @@ func (e *endpoint) handleSegments() *tcpip.Error { // Patch the window size in the segment according to the // send window scale. 
s.window <<= e.snd.sndWndScale - // If the timestamp option is negotiated and the segment // does not carry a timestamp option then the segment // must be dropped as per @@ -713,8 +731,8 @@ func (e *endpoint) handleSegments() *tcpip.Error { // tcp可靠性:累积确认 // 如果发送的最大ack不等于下一个接收的序列号,发送ack - log.Println("============", e.rcv.rcvNxt, e.snd.maxSentAck, "=============") if e.rcv.rcvNxt != e.snd.maxSentAck { + fmt.Printf("\n\n=======ACK=======%d=======ACK======\n\n", e.rcv.rcvNxt-e.snd.maxSentAck) e.snd.sendAck() } diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index 9ba7dd3..7c9b94f 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -818,6 +818,19 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C } +// 当收到ack确认时 需要更新发送确认缓冲占用 +func (e *endpoint) updateSndBufferUsage(v int) { + e.sndBufMu.Lock() + notify := e.sndBufUsed >= e.sndBufSize>>1 + e.sndBufUsed -= v + notify = notify && e.sndBufUsed < e.sndBufSize>>1 + e.sndBufMu.Unlock() + if notify { // 如果缓存中剩余的数据过多是不需要补充的 + log.Fatal("缓存中剩余的数据", e.sndBufUsed, notify) + e.waiterQueue.Notify(waiter.EventOut) + } +} + func (e *endpoint) readyToRead(s *segment) { e.rcvListMu.Lock() if s != nil { @@ -840,14 +853,11 @@ func (e *endpoint) receiveBufferAvailable() int { size := e.rcvBufSize used := e.rcvBufUsed e.rcvListMu.Unlock() - // We may use more bytes than the buffer size when the receive buffer // shrinks. 
if used >= size { return 0 } - - log.Println("Init Recv Windeow Size: ", size-used) return size - used } diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go index 6e776ac..fdb4134 100644 --- a/tcpip/transport/tcp/rcv.go +++ b/tcpip/transport/tcp/rcv.go @@ -58,6 +58,7 @@ func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { r.rcvAcc = acc } + log.Println("-------------", n, acc, r.rcvWndScale) return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale } diff --git a/tcpip/transport/tcp/segment.go b/tcpip/transport/tcp/segment.go index d938b48..a6b718b 100644 --- a/tcpip/transport/tcp/segment.go +++ b/tcpip/transport/tcp/segment.go @@ -114,6 +114,20 @@ func (s *segment) incRef() { atomic.AddInt32(&s.refCnt, 1) } +// logicalLen is the segment length in the sequence number space. It's defined +// as the data length plus one for each of the SYN and FIN bits set. +// 计算tcp段的逻辑长度,包括负载数据的长度,如果有控制标记,需要加1 +func (s *segment) logicalLen() seqnum.Size { + l := seqnum.Size(s.data.Size()) + if s.flagIsSet(flagSyn) { + l++ + } + if s.flagIsSet(flagFin) { + l++ + } + return l +} + func (s *segment) parse() bool { h := header.TCP(s.data.First()) offset := int(h.DataOffset()) diff --git a/tcpip/transport/tcp/snd.go b/tcpip/transport/tcp/snd.go index 1c3af44..c6b54c3 100644 --- a/tcpip/transport/tcp/snd.go +++ b/tcpip/transport/tcp/snd.go @@ -9,6 +9,7 @@ import ( "netstack/tcpip/seqnum" "sync" "time" + "unsafe" ) const ( @@ -253,6 +254,7 @@ func (s *sender) sendSegment(data buffer.VectorisedView, flags byte, seq seqnum. // Remember the max sent ack. 
s.maxSentAck = rcvNxt + log.Println(s.ep.id.LocalPort, "要求扩展窗口", s.sndWnd) return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) } @@ -262,6 +264,62 @@ func (s *sender) handleRcvdSegment(seg *segment) { // 因此发送更多数据。如果需要,这也将重新启用重传计时器。 // 存放当前窗口大小。 s.sndWnd = seg.window + log.Println(s.ep.id.LocalPort, "移动窗口", s.sndWnd) + // 获取确认号 + ack := seg.ackNumber + // 如果ack在最小未确认的seq和sndNxt之间 + if (ack - 1).InRange(s.sndUna, s.sndNxt) { + log.Printf("[...XXXXXX]-[%d|\t%d\t|%d]==>", s.sndNxt, ack-1, s.sndUna) + if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + // TSVal/Ecr values sent by Netstack are at a millisecond + // granularity. + //elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + //s.updateRTO(elapsed) + } + // 获取这次确认的字节数,即 ack - sndUna + acked := s.sndUna.Size(ack) + // 更新下一个未确认的序列号 + s.sndUna = ack + + ackLeft := acked + //originalOutstanding := s.outstanding + // 从发送链表中删除已经确认的数据,发送窗口的滑动。 + //log.Printf("[...XXXXXX]-[%d|\t\t|%d]==>", s.sndNxt, s.sndUna) + for ackLeft > 0 { // 有成功确认的数据 丢弃它们 有剩余数据的话继续发送(根据拥塞策略控制) + seg := s.writeList.Front() + datalen := seg.logicalLen() + + if datalen > ackLeft { + seg.data.TrimFront(int(ackLeft)) + break + } + + log.Println(s.writeNext == seg) + if s.writeNext == seg { + log.Fatal("更新 下一段") + s.writeNext = seg.Next() + } + // 从发送链表中删除已确认的tcp段。 + s.writeList.Remove(seg) + // 因为有一个tcp段确认了,所以 outstanding 减1 + s.outstanding-- + seg.decRef() + ackLeft -= datalen + } + // 当收到ack确认时,需要更新发送缓冲占用 + s.ep.updateSndBufferUsage(int(acked)) + + // 如果发生超时重传时,s.outstanding可能会降到零以下, + // 重置为零但后来得到一个覆盖先前发送数据的确认。 + if s.outstanding < 0 { + s.outstanding = 0 + } + } + + // TODO tcp拥塞控制 + if s.writeList.Front() != nil { + log.Println("确认成功 继续发送") + } s.sendData() } @@ -297,6 +355,7 @@ func (s *sender) sendData() { panic("Netstack queues FIN segments without data.") } if !seg.sequenceNumber.LessThan(end) { + log.Println("暂停数据发送", seg.sequenceNumber, end) break } @@ -309,18 +368,19 @@ func (s *sender) sendData() { // 
如果seg的payload字节数大于available // 将seg进行分段,并且插入到该seg的后面 if seg.data.Size() > available { + log.Println("-------------------------------------分段!!!", seg.data.Size(), available, end) nSeg := seg.clone() - nSeg.data.TrimFront(available) nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) s.writeList.InsertAfter(seg, nSeg) seg.data.CapLength(available) } s.outstanding++ - log.Println("发送窗口一开始是", s.sndWnd, - "最多发送数据", available, dataSent, - "发送端缓存包数量", s.outstanding) segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + log.Println("发送窗口一开始是", s.sndWnd, + "最多发送数据", available, + "缓存数据尾", segEnd, + "发送端缓存包数量", s.outstanding) } if !dataSent { // 上面有个break能跳过这一步 @@ -331,12 +391,18 @@ func (s *sender) sendData() { s.sendSegment(seg.data, seg.flags, seg.sequenceNumber) // 发送一个数据段后,更新sndNxt if s.sndNxt.LessThan(segEnd) { + log.Println("更新sndNxt", s.sndNxt, segEnd) s.sndNxt = segEnd } } - // Remember the next segment we'll write. s.writeNext = seg + if seg != nil { + log.Println("-------------------------------------分段!!!", s.writeNext.data.Size()) + log.Println(unsafe.Pointer(seg), seg.data.Size()) + } + + time.Sleep(200 * time.Millisecond) // TODO 启动定时器 }