From 0eff0e912faff11a8c7d5a7ee6fbd05c37e5b80e Mon Sep 17 00:00:00 2001 From: impact-eintr Date: Mon, 12 Dec 2022 15:46:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9C=AC=E5=9C=B0=E7=8E=AF?= =?UTF-8?q?=E5=9B=9E=E7=BD=91=E5=8D=A1=E8=AE=BE=E5=A4=87=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/netstack/main.go | 58 +++++++++++++++++++++++++++--- cmd/tcpclient/main.go | 1 + tcpip/link/loopback/loopback.go | 62 +++++++++++++++++++++++++++++++++ tcpip/transport/tcp/README.md | 27 ++++++++++++++ tcpip/transport/tcp/accept.go | 9 ++--- tcpip/transport/tcp/connect.go | 33 ++++++++++++++---- tcpip/transport/tcp/endpoint.go | 49 ++++++++++++++++++++++++-- tcpip/transport/tcp/rcv.go | 7 +++- 8 files changed, 229 insertions(+), 17 deletions(-) create mode 100644 tcpip/link/loopback/loopback.go diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index 54ae1a5..4764c84 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -7,7 +7,9 @@ import ( "net" "netstack/logger" "netstack/tcpip" + "netstack/tcpip/header" "netstack/tcpip/link/fdbased" + "netstack/tcpip/link/loopback" "netstack/tcpip/link/tuntap" "netstack/tcpip/network/arp" "netstack/tcpip/network/ipv4" @@ -95,12 +97,16 @@ func main() { ResolutionRequired: true, }) + _ = linkID + + loopbackLinkID := loopback.New() + // 新建相关协议的协议栈 s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{}) // 新建抽象的网卡 - if err := s.CreateNamedNIC(1, "vnic1", linkID); err != nil { + if err := s.CreateNamedNIC(1, "vnic1", loopbackLinkID); err != nil { log.Fatal(err) } @@ -124,9 +130,12 @@ func main() { }, }) + done := make(chan struct{}, 2) + //logger.SetFlags(logger.TCP) go func() { // echo server listener := tcpListen(s, proto, addr, localPort) + done <- struct{}{} conn, err := listener.Accept() if err != nil { log.Println(err) @@ -135,10 +144,10 @@ func main() { for { buf := make([]byte, 1024) if _, err := conn.Read(buf); err != nil { - log.Fatal(err) + log.Println(err) + break } fmt.Println(string(buf)) - //if string(buf) != "" { // conn.Write([]byte("Server echo")) //} } @@ -147,11 +156,52 @@ func main() { select {} }() + <-done + go func() { + port := localPort + _, err := Dial(s, header.IPv4ProtocolNumber, addr, port) + if err != nil { + log.Fatal(err) + } + }() + + close(done) + c := make(chan os.Signal) signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) <-c } +// Dial 呼叫tcp服务端 +func Dial(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, port int) (*TcpConn, error) { + remote := tcpip.FullAddress{ + Addr: addr, + Port: uint16(port), + } + var wq waiter.Queue + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + // 新建一个tcp端 + ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) + if err != nil { + return nil, fmt.Errorf("%s", err.String()) + } + err = ep.Connect(remote) + if err != nil { + if err == tcpip.ErrConnectStarted { + <-notifyCh + } else { + return nil, fmt.Errorf("%s", err.String()) + } + } + + return &TcpConn{ + ep: ep, + wq: &wq, + we: &waitEntry, + notifyCh: notifyCh}, nil +} + +// TcpConn 一条tcp连接 type TcpConn struct { raddr tcpip.FullAddress ep tcpip.Endpoint @@ -235,7 +285,7 @@ func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Add // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: addr, Port: uint16(localPort)}, nil); err != nil { + if err := ep.Bind(tcpip.FullAddress{NIC: 1, Addr: "", Port: uint16(localPort)}, nil); err != nil { log.Fatal("Bind failed: ", err) } diff --git a/cmd/tcpclient/main.go b/cmd/tcpclient/main.go index cfa6a1d..d5b8564 100644 --- a/cmd/tcpclient/main.go +++ b/cmd/tcpclient/main.go @@ -12,6 +12,7 @@ func main() { fmt.Println("err : ", err) return } + conn.Write([]byte("hello world")) //buf := make([]byte, 1024) //conn.Read(buf) if err = conn.Close(); err != nil { diff --git a/tcpip/link/loopback/loopback.go b/tcpip/link/loopback/loopback.go new file mode 100644 index 0000000..b898a54 --- /dev/null +++ b/tcpip/link/loopback/loopback.go @@ -0,0 +1,62 @@ +package loopback + +import ( + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/stack" +) + +type endpoint struct { + dispatcher stack.NetworkDispatcher +} + +func New() tcpip.LinkEndpointID { + return stack.RegisterLinkEndpoint(&endpoint{}) +} + +func (e *endpoint) MTU() uint32 { + return 65536 +} + +// Capabilities返回链路层端点支持的功能集。 +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityChecksumOffload | stack.CapabilitySaveRestore | stack.CapabilityLoopback +} + +// MaxHeaderLength 返回数据链接(和较低级别的图层组合)标头可以具有的最大大小。 +// 较高级别使用此信息来保留它们正在构建的数据包前面预留空间。 +func (e *endpoint) MaxHeaderLength() uint16 { + return 0 +} + +// 本地链路层地址 +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +// 要参与透明桥接,LinkEndpoint实现应调用eth.Encode, +// 并将header.EthernetFields.SrcAddr设置为r.LocalLinkAddress(如果已提供)。 +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.NetworkProtocolNumber) *tcpip.Error { + views := make([]buffer.View, 1, 1+len(payload.Views())) + views[0] = hdr.View() + views = append(views, payload.Views()...) + vv := buffer.NewVectorisedView(len(views[0])+payload.Size(), views) + + // Because we're immediately turning around and writing the packet back to the + // rx path, we intentionally don't preserve the remote and local link + // addresses from the stack.Route we're passed. + e.dispatcher.DeliverNetworkPacket(e, "" /* remoteLinkAddr */, "" /* localLinkAddr */, protocol, vv) + + return nil +} + +// Attach 将数据链路层端点附加到协议栈的网络层调度程序。 +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +// 是否已经添加了网络层调度器 +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} diff --git a/tcpip/transport/tcp/README.md b/tcpip/transport/tcp/README.md index a7a5751..a844c48 100644 --- a/tcpip/transport/tcp/README.md +++ b/tcpip/transport/tcp/README.md @@ -75,3 +75,30 @@ TCP 最初只规定了一种选项,即最大报文段长度 MSS(Maximum Segm 7. kind=8 是时间戳选项 该选项提供了较为准确的计算通信双方之间的回路时间(Round Trip Time,RTT)的方法,从而为 TCP 流量控制提供重要信息。 + +## tcp连接的建立 + +![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555574034117.png) + +上面的图片显示了 tcp 的三次握手,但只是简单的降了三次报文的交互,下面讲讲详细的三次握手。先讲三次握手的正常情况,接着我们再讲异常情况。 + +### 正常情况(没有丢包) + +主机 A 的 TCP 向主机 B 的 TCP 发出连接请求,发送 syn 报文段在发送 syn 之前,设置握手状态为 SynSent,还需要做一些准备工作,包括:随机生成 ISN1、计算 MSS、计算接收窗口扩展因子、是否开启 sack。 根据这些参数生成 syn 报文的选项参数,附在 tcp 选项中,然后发送带着这些选项的 syn 报文。 + +主机 B 的 TCP 收到连接请求 syn 报文段后,需要回复 syn+ack 报文因为 tcp 的控制报文需要消耗一个字节的序列号,所以回复的 ack 序列号为 ISN1+1,设置接收窗口,设置握手状态为 SynRcvd,并随机生成 ISN2、计算 MSS、计算接收窗口扩展因子、是否开启 sack。根据这些参数生成 syn+ack 报文的选项参数,附在 tcp 选项中,回复给主机 A。 + +主机 A 的 TCP 收到 syn+ack 报文段后,还要向 B 回复确认和上面一样,tcp 的控制报文需要消耗一个字节的序列号,所以回复的 ack 序列号为 ISN2+1,发送 ack 报文给主机 B。 + +主机 A 的 TCP 通知上层应用进程,连接已经建立,可以发送数据了,当主机 B 的 TCP 收到主机 A 的确认后,也通知上层应用进程,连接建立。 + +### 异常情况(有丢包) + +主机 A 发给主机 B 的 SYN 中途丢失,没有到达主机 B 因为在发送 syn 之前,就设置了超时定时器,如果在一定的时间内没收到回复,就会触发重传,所以主机 A 会周期性超时重传,直到收到主机 B 的确认。重传的周期,一开始默认 1s,每重传一次,变为原来的 2 倍,如果重传周期超过 1 分钟,返回错误,不再尝试重连。 + +主机 B 发给主机 A 的 SYN +ACK 中途丢失,没有到达主机 A 主机 B 会周期性超时重传,直到收到主机 A 的确认,重传的策略和 syn 报文一样,每重传一次,周期变为原来的 2 倍。 + +主机 A 发给主机 B 的 ACK 中途被丢,没有到达主机 B 主机 A 发完 ACK,单方面认为 TCP 为 Established 状态,而 B 显然认为 TCP 为 Active 状态: + +a. 如果此时双方都没有数据发送,主机 B 会周期性超时重传,直到收到 A 的确认,收到之后主机 B 的 TCP 连接也为 Established 状态,双向可以发包。 +b. 如果此时 A 有数据发送,主机 B 收到主机 A 的 Data + ACK,自然会切换为 established 状态,并接受主机 A 的 Data。 \ No newline at end of file diff --git a/tcpip/transport/tcp/accept.go b/tcpip/transport/tcp/accept.go index 4c2866e..7814073 100644 --- a/tcpip/transport/tcp/accept.go +++ b/tcpip/transport/tcp/accept.go @@ -222,13 +222,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head irs := s.sequenceNumber cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) logger.GetInstance().Info(logger.HANDSHAKE, func() { - log.Println("收到一个远端握手申请 SYN seq=", irs, "客户端请携带 标记 iss ", cookie, "+1") + log.Println("收到一个远端握手申请 SYN seq =", irs, "客户端请携带 标记 iss ", cookie, "+1") }) ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) if err != nil { return nil, err } - log.Println("TCP STATE LISTEN") // 以下执行三次握手 @@ -238,10 +237,11 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.Close() return nil, err } - // 标记状态为 handshakeSynRcvd 和 h.flags为 syn+ack - log.Println("TCP STATE SYN_RCVD") h.resetToSynRcvd(cookie, irs, opts) + + log.Println("TCP STATE SYN_RCVD") + // 发送ack报文 接收client返回的ack if err := h.execute(); err != nil { ep.Close() @@ -351,6 +351,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { mayRequeue = false break } + log.Println("TCP STATE LISTEN") e.handleListenSegment(ctx, s) s.decRef() } diff --git a/tcpip/transport/tcp/connect.go b/tcpip/transport/tcp/connect.go index 165899b..ce9c6de 100644 --- a/tcpip/transport/tcp/connect.go +++ b/tcpip/transport/tcp/connect.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "fmt" "log" + "netstack/logger" "netstack/sleep" "netstack/tcpip" "netstack/tcpip/buffer" @@ -74,6 +75,10 @@ const ( maxOptionSize = 40 ) +// 主机 B 的 TCP 收到连接请求 syn 报文段后,需要回复 syn+ack 报文因为 tcp 的控制报文需要消耗一个字节的序列号, +// 所以回复的 ack 序列号为 ISN1+1,设置接收窗口,设置握手状态为 SynRcvd, +// 并随机生成 ISN2、计算 MSS、计算接收窗口扩展因子、是否开启 sack。 +// 根据这些参数生成 syn+ack 报文的选项参数,附在 tcp 选项中,回复给主机 A。 func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { h := handshake{ ep: ep, @@ -98,8 +103,7 @@ func (h *handshake) resetState() *tcpip.Error { h.flags = flagSyn h.ackNum = 0 h.mss = 0 - h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) - log.Println("收到 syn 同步报文 设置tcp状态为 [sent]") + h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) // 随机生成ISN2 return nil } @@ -128,11 +132,11 @@ func (h *handshake) resolveRoute() *tcpip.Error { // Initial action is to resolve route. index := wakerForResolution for { - log.Println(index) switch index { case wakerForResolution: if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { // Either success (err == nil) or failure. + log.Println("没有地址", err) return err } // Resolution not completed. Keep trying... @@ -258,7 +262,7 @@ func (h *handshake) processSegments() *tcpip.Error { // 执行tcp 3次握手,客户端和服务端都是调用该函数来实现三次握手 /* c flag s - | | +生成ISN1 | |生成ISN2 sync_sent|------sync---->|sync_rcvd | | | | @@ -302,7 +306,7 @@ func (h *handshake) execute() *tcpip.Error { case wakerForResend: // NOTE tcp超时重传机制 // 如果是客户端当发送 syn 报文,超过一定的时间未收到回包,触发超时重传 // 如果是服务端当发送 syn+ack 报文,超过一定的时间未收到 ack 回包,触发超时重传 - // 超时时间变为上次的2倍 + // 超时时间变为上次的2倍 如果重传周期超过 1 分钟,返回错误,不再尝试重连 timeOut *= 2 if timeOut > 60*time.Second { return tcpip.ErrTimeout @@ -313,6 +317,8 @@ func (h *handshake) execute() *tcpip.Error { case wakerForNotification: case wakerForNewSegment: + // 对方主机的 TCP 收到 syn+ack 报文段后,还要向 本机 回复确认和上面一样, + // tcp 的控制报文需要消耗一个字节的序列号,所以回复的 ack 序列号为 ISN2+1,发送 ack 报文给本机。 // 处理握手报文 if err := h.processSegments(); err != nil { return err @@ -447,7 +453,7 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise r.Stats().TCP.ResetsSent.Increment() } - log.Printf("send tcp %s segment to %s, seq: |%d|, ack: %d, rcvWnd: %d", + log.Printf("TCP 发送 [%s] 报文片段到 %s, seq: |%d|, ack: %d, rcvWnd: %d", flagString(flags), fmt.Sprintf("%s:%d", id.RemoteAddress, id.RemotePort), seq, ack, rcvWnd) @@ -553,6 +559,21 @@ func (e *endpoint) handleSegments() *tcpip.Error { // protocolMainLoop 是TCP协议的主循环。它在自己的goroutine中运行,负责握手、发送段和处理收到的段 func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + + // 收尾工作 + + // 处理三次握手 + if handshake { + h, err := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + logger.GetInstance().Info(logger.HANDSHAKE, func() { + log.Println("TCP STATE SENT") + }) + if err == nil { + // 执行握手 + err = h.execute() + } + } + // Set up the functions that will be called when the main protocol loop // wakes up. // 触发器的事件,这些函数很重要 diff --git a/tcpip/transport/tcp/endpoint.go b/tcpip/transport/tcp/endpoint.go index 35256ac..09ac301 100644 --- a/tcpip/transport/tcp/endpoint.go +++ b/tcpip/transport/tcp/endpoint.go @@ -337,7 +337,31 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er return err } } else { - // TODO 需要添加 + // 端点还没有本地端口,所以尝试获取一个端口。确保它不会导致本地和远程的相同地址/端口(否则此端点将尝试连接到自身) + // 远端地址和本地地址是否相同 + // NOTE 这段代码值得借鉴 + sameAddr := e.id.LocalAddress == e.id.RemoteAddress + if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.id.RemotePort { // 同样的ip同样的port 打咩捏 + return false, nil + } + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p) { // 端口被占用打咩 + return false, nil + } + id := e.id + id.LocalPort = p + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) { + case nil: + e.id = id + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }); err != nil { + return err + } } // Remove the port reservation. This can happen when Bind is called @@ -356,7 +380,28 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er e.effectiveNetProtos = netProtos e.connectingAddress = connectingAddr - // TODO 需要添加 + // Connect in the restore phase does not perform handshake. Restore its + // connection setting here. + if !handshake { + //e.segmentQueue.mu.Lock() + //for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + // for s := l.Front(); s != nil; s = s.Next() { + // s.id = e.id + // s.route = r.Clone() + // e.sndWaker.Assert() + // } + //} + //e.segmentQueue.mu.Unlock() + //e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) + //e.state = stateConnected + } + + if run { + e.workerRunning = true + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + // tcp的主函数 + go e.protocolMainLoop(handshake) + } return tcpip.ErrConnectStarted } diff --git a/tcpip/transport/tcp/rcv.go b/tcpip/transport/tcp/rcv.go index 09a6d22..3c3ab4f 100644 --- a/tcpip/transport/tcp/rcv.go +++ b/tcpip/transport/tcp/rcv.go @@ -86,6 +86,11 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // 收到 fin,立即回复 ack r.ep.snd.sendAck() // FIXME 不应该是 seq+2 捏 + + // 标记接收器关闭 + // 触发上层应用可以读取 + r.closed = true + r.ep.readyToRead(nil) } return true @@ -108,7 +113,7 @@ func (r *receiver) handleRcvdSegment(s *segment) { return } - log.Println(s.data, segLen, segSeq) + //log.Println(s.data, segLen, segSeq) // Defer segment processing if it can't be consumed now. // tcp可靠性:r.consumeSegment 返回值是个bool类型,如果是true,表示已经消费该数据段,