diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index c74df5d..9bac84b 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -12,36 +12,31 @@ import ( "netstack/tcpip/link/tuntap" "netstack/tcpip/network/arp" "netstack/tcpip/network/ipv4" - "netstack/tcpip/network/ipv6" "netstack/tcpip/stack" "netstack/tcpip/transport/udp" "netstack/waiter" "os" "strings" + "time" ) -var mac = flag.String("mac", "01:01:01:01:01:01", "mac address to use in tap device") - func main() { flag.Parse() - if len(flag.Args()) != 2 { - log.Fatal("Usage: ", os.Args[0], " ") + if len(flag.Args()) < 2 { + log.Fatal("Usage: ", os.Args[0], " ") } log.SetFlags(log.Lshortfile | log.LstdFlags) tapName := flag.Arg(0) - listeAddr := flag.Arg(1) + cidrName := flag.Arg(1) - log.Printf("tap: %v, listeAddr: %v", tapName, listeAddr) + log.Printf("tap: %v, cidrName: %v", tapName, cidrName) - // Parse the mac address. - maddr, err := net.ParseMAC(*mac) + parsedAddr, cidr, err := net.ParseCIDR(cidrName) if err != nil { - log.Fatalf("Bad MAC address: %v", *mac) + log.Fatalf("Bad cidr: %v", cidrName) } - parsedAddr := net.ParseIP(listeAddr) - // 解析地址ip地址,ipv4或者ipv6地址都支持 var addr tcpip.Address var proto tcpip.NetworkProtocolNumber @@ -50,7 +45,7 @@ func main() { proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { addr = tcpip.Address(parsedAddr.To16()) - proto = ipv6.ProtocolNumber + //proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", parsedAddr) } @@ -69,17 +64,22 @@ func main() { } // 启动tap网卡 - _ = tuntap.SetLinkUp(tapName) - // 设置tap网卡IP地址 - _ = tuntap.AddIP(tapName, listeAddr) + tuntap.SetLinkUp(tapName) + // 设置路由 + tuntap.SetRoute(tapName, cidr.String()) + + // 获取mac地址 + mac, err := tuntap.GetHardwareAddr(tapName) + if err != nil { + panic(err) + } // 抽象网卡的文件接口 linkID := fdbased.New(&fdbased.Options{ FD: fd, MTU: 1500, - Address: tcpip.LinkAddress(maddr), + Address: tcpip.LinkAddress(mac), }) - // 新建相关协议的协议栈 s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{}) @@ -109,37 +109,33 @@ func main() { }, }) - conn, _ := net.Listen("tcp", "0.0.0.0:9999") - TCPServer(conn, &RCV{s, nil, nil}) - // 同时监听tcp和udp localPort端口 - //tcpEp := tcpListen(s, proto, localPort) - //udpEp := udpListen(s, proto, localPort) - // 关闭监听服务,此时会释放端口 - //tcpEp.Close() - //udpEp.Close() -} + go func() { + // 监听udp localPort端口 + udpEp := udpListen(s, proto, 9999) -//func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { -// var wq waiter.Queue -// // 新建一个tcp端 -// ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq) -// if err != nil { -// log.Fatal(err) -// } -// -// // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP -// // 此时就会调用端口管理器 -// if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil { -// log.Fatal("Bind failed: ", err) -// } -// -// // 开始监听 -// if err := ep.Listen(10); err != nil { -// log.Fatal("Listen failed: ", err) -// } -// -// return ep -//} + for { + buf, _, err := udpEp.Read(nil) + if err != nil { + if err == tcpip.ErrWouldBlock { + time.Sleep(100 * time.Millisecond) + log.Println("阻塞中") + continue + } + } + log.Println(buf) + break + } + // 关闭监听服务,此时会释放端口 + udpEp.Close() + }() + + conn, _ := net.Listen("tcp", "0.0.0.0:9999") + rcv := &RCV{ + Stack: s, + addr: tcpip.FullAddress{}, + } + TCPServer(conn, rcv) +} func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { var wq waiter.Queue @@ -156,10 +152,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) log.Fatal("Bind failed: ", err) } - if err := ep.Connect(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}); err != nil { - log.Fatal("Conn failed: ", err) - } - // 注意UDP是无连接的,它不需要Listen return ep } @@ -167,6 +159,7 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) type RCV struct { *stack.Stack ep tcpip.Endpoint + addr tcpip.FullAddress rcvBuf []byte } @@ -189,6 +182,7 @@ func (r *RCV) Handle(conn net.Conn) { } r.ep = ep r.Bind() + r.Connect() r.Close() case "tcp": default: @@ -202,12 +196,16 @@ func (r *RCV) Bind() { return } port := binary.BigEndian.Uint16(r.rcvBuf[7:9]) - addr := tcpip.FullAddress{ - NIC: 0, + r.addr = tcpip.FullAddress{ + NIC: 1, Addr: tcpip.Address(r.rcvBuf[3:7]), Port: port, } - r.ep.Bind(addr, nil) + r.ep.Bind(r.addr, nil) +} + +func (r *RCV) Connect() { + r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888}) } func (r *RCV) Close() { diff --git a/cmd/udp_client/main.go b/cmd/udp_client/main.go new file mode 100644 index 0000000..78cac59 --- /dev/null +++ b/cmd/udp_client/main.go @@ -0,0 +1,42 @@ +package main + +import ( + "flag" + "log" + "net" +) + +func main() { + var ( + addr = flag.String("a", "192.168.1.1:9999", "udp dst address") + ) + + log.SetFlags(log.Lshortfile | log.LstdFlags) + + udpAddr, err := net.ResolveUDPAddr("udp", *addr) + if err != nil { + panic(err) + } + log.Println("解析地址") + + // 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接) + conn, err := net.DialUDP("udp", nil, udpAddr) + if err != nil { + panic(err) + } + log.Println("TEST") + + send := []byte("hello") + + if _, err := conn.Write(send); err != nil { + panic(err) + } + log.Printf("send: %s", string(send)) + + //recv := make([]byte, 10) + //rn, _, err := conn.ReadFrom(recv) + //if err != nil { + // panic(err) + //} + //log.Printf("recv: %s", string(recv[:rn])) +} diff --git a/tcpip/header/checksum.go b/tcpip/header/checksum.go index 2876224..6a2a00c 100644 --- a/tcpip/header/checksum.go +++ b/tcpip/header/checksum.go @@ -2,7 +2,10 @@ package header import "netstack/tcpip" -// 校验和的计算 +// Checksum 校验和的计算 +// UDP 检验和的计算方法是: 按每 16 位求和得出一个 32 位的数; +// 如果这个 32 位的数,高 16 位不为 0,则高 16 位加低 16 位再得到一个 32 位的数; +// 重复第 2 步直到高 16 位为 0,将低 16 位取反,得到校验和。 func Checksum(buf []byte, initial uint16) uint16 { v := uint32(initial) diff --git a/tcpip/header/udp.go b/tcpip/header/udp.go index 4df3682..62eb194 100644 --- a/tcpip/header/udp.go +++ b/tcpip/header/udp.go @@ -1,6 +1,10 @@ package header -import "netstack/tcpip" +import ( + "encoding/binary" + "fmt" + "netstack/tcpip" +) const ( udpSrcPort = 0 @@ -36,3 +40,99 @@ const ( // UDPProtocolNumber is UDP's transport protocol number. UDPProtocolNumber tcpip.TransportProtocolNumber = 17 ) + +/* +UDP 是 User Datagram Protocol 的简称,中文名是用户数据报协议。UDP 只在 IP 数据报服务上增加了一点功能,就是复用和分用的功能以及差错检测,UDP 主要的特点是: + +1. UDP 是无连接的,即发送数据之前不需要建立连接,发送结束也不需要连接释放,因此减少了开销和发送数据之间的延时。 +2. UDP 是不可靠传输,尽最大努力交付,因此不需要维护复杂的连接状态。 +3. UDP 的数据报是有消息边界的,发送方发送一个报文,接收方就会完整的收到一个报文。 +4. UDP 没有拥塞控制,网络出现阻塞,UDP 是无感知的,也就不会降低发送速度。 +5. UDP 支持一对一,一对多,多对一,多对多的通信。 +*/ + +/* +|source Port|destination Port| +| Length | UDP Checksum | +| Data | +*/ + +// SourcePort returns the "source port" field of the udp header. +func (b UDP) SourcePort() uint16 { + return binary.BigEndian.Uint16(b[udpSrcPort:]) +} + +// DestinationPort returns the "destination port" field of the udp header. +func (b UDP) DestinationPort() uint16 { + return binary.BigEndian.Uint16(b[udpDstPort:]) +} + +// Length returns the "length" field of the udp header. +func (b UDP) Length() uint16 { + return binary.BigEndian.Uint16(b[udpLength:]) +} + +// Payload returns the data contained in the UDP datagram. +func (b UDP) Payload() []byte { + return b[UDPMinimumSize:] +} + +// Checksum returns the "checksum" field of the udp header. +func (b UDP) Checksum() uint16 { + return binary.BigEndian.Uint16(b[udpChecksum:]) +} + +// SetSourcePort sets the "source port" field of the udp header. +func (b UDP) SetSourcePort(port uint16) { + binary.BigEndian.PutUint16(b[udpSrcPort:], port) +} + +// SetDestinationPort sets the "destination port" field of the udp header. +func (b UDP) SetDestinationPort(port uint16) { + binary.BigEndian.PutUint16(b[udpDstPort:], port) +} + +// SetChecksum sets the "checksum" field of the udp header. +func (b UDP) SetChecksum(checksum uint16) { + binary.BigEndian.PutUint16(b[udpChecksum:], checksum) +} + +// CalculateChecksum calculates the checksum of the udp packet, given the total +// length of the packet and the checksum of the network-layer pseudo-header +// (excluding the total length) and the checksum of the payload. +func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 { + // Add the length portion of the checksum to the pseudo-checksum. + tmp := make([]byte, 2) + binary.BigEndian.PutUint16(tmp, totalLen) + checksum := Checksum(tmp, partialChecksum) + + // Calculate the rest of the checksum. + return Checksum(b[:UDPMinimumSize], checksum) +} + +// Encode encodes all the fields of the udp header. +func (b UDP) Encode(u *UDPFields) { + binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort) + binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort) + binary.BigEndian.PutUint16(b[udpLength:], u.Length) + binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum) +} + +var udpFmt string = ` +|% 16s|% 16s| +|% 16s|% 16s| +%v +` + +func (b UDP) String() string { + for i := range b.Payload() { + if i != int(b.Length()-8-1) && b.Payload()[i]^b.Payload()[i+1] != 0 { + return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), + atoi(b.Length()), atoi(b.Checksum()), + b.Payload()) + } + } + return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()), + atoi(b.Length()), atoi(b.Checksum()), + fmt.Sprintf("%v x %d", b.Payload()[0], b.Length()-8)) +} diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index 6c661c1..21d40ed 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -60,7 +60,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { id: id, name: name, linkEP: ep, - demux: nil, // TODO 需要处理 + demux: newTransportDemuxer(stack), // NOTE 注册网卡自己的传输层分流器 primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), } @@ -302,6 +302,75 @@ func (n *NIC) Subnets() []tcpip.Subnet { return append(sns, n.subnets...) } +// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。 +// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。 +// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它, +// 当前实现的网络层协议有 arp、ipv4 和 ipv6。 +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, + protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.stack.stats.UnknownProtocolRcvdPackets.Increment() + return + } + + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { + n.stack.stats.IP.PacketsReceived.Increment() + } + + if len(vv.First()) < netProto.MinimumPacketSize() { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + src, dst := netProto.ParseAddresses(vv.First()) + log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte { + if len(vv.ToView()) > 64 { + return vv.ToView()[:64] + } + return vv.ToView() + }()) + // 根据网络协议和数据包的目的地址,找到网络端 + // 然后将数据包分发给网络层 + if ref := n.getRef(protocol, dst); ref != nil { + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r.RemoteLinkAddress = remoteLinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() + return + } + + if n.stack.Forwarding() { + r, err := n.stack.FindRoute(0, "", dst, protocol) + if err != nil { + n.stack.stats.IP.InvalidAddressesReceived.Increment() + return + } + defer r.Release() + + r.LocalLinkAddress = n.linkEP.LinkAddress() + r.RemoteLinkAddress = remoteLinkAddr + + // Found a NIC. + n := r.ref.nic + n.mu.RLock() + ref, ok := n.endpoints[NetworkEndpointID{dst}] + n.mu.RUnlock() + if ok && ref.tryIncRef() { + ref.ep.HandlePacket(&r, vv) + ref.decRef() + } else { + // n doesn't have a destination endpoint. + // Send the packet out of n. + hdr := buffer.NewPrependableFromView(vv.First()) + vv.RemoveFirst() + n.linkEP.WritePacket(&r, hdr, vv, protocol) + } + return + } + + n.stack.stats.IP.InvalidAddressesReceived.Increment() +} + // 根据协议类型和目标地址,找出关联的Endpoint func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { id := NetworkEndpointID{dst} @@ -344,57 +413,49 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r return nil } -// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。 -// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。 -// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它, -// 当前实现的网络层协议有 arp、ipv4 和 ipv6。 -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, - protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.stack.stats.UnknownProtocolRcvdPackets.Increment() - return - } - - if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { - n.stack.stats.IP.PacketsReceived.Increment() - } - - if len(vv.First()) < netProto.MinimumPacketSize() { - n.stack.stats.MalformedRcvdPackets.Increment() - return - } - src, dst := netProto.ParseAddresses(vv.First()) - log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte { - if len(vv.ToView()) > 64 { - return vv.ToView()[:64] - } - return vv.ToView() - }()) - // 根据网络协议和数据包的目的地址,找到网络端 - // 然后将数据包分发给网络层 - if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) - r.RemoteLinkAddress = remoteLinkAddr - ref.ep.HandlePacket(&r, vv) - ref.decRef() - - return - } - n.stack.stats.IP.InvalidAddressesReceived.Increment() -} - // DeliverTransportPacket delivers packets to the appropriate // transport protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) { // 先查找协议栈是否注册了该传输层协议 - _, ok := n.stack.transportProtocols[protocol] + state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() return } - log.Println("准备分发传输层数据报", n.stack.transportProtocols) + transProto := state.proto + // 如果报文长度比该协议最小报文长度还小,那么丢弃它 + if len(vv.First()) < transProto.MinimumPacketSize() { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + // 解析报文得到源端口和目的端口 + srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + if err != nil { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + log.Println("准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort) + id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} + // 调用分流器,根据传输层协议和传输层id分发数据报文 + if n.demux.deliverPacket(r, protocol, vv, id) { + return + } + if n.stack.demux.deliverPacket(r, protocol, vv, id) { + return + } + // Try to deliver to per-stack default handler. + if state.defaultHandler != nil { + if state.defaultHandler(r, id, vv) { + return + } + } + + // We could not find an appropriate destination for this packet, so + // deliver it to the global handler. + if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + n.stack.stats.MalformedRcvdPackets.Increment() + } } // DeliverTransportControlPacket delivers control packets to the diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go index 71f4bcb..69831fb 100644 --- a/tcpip/stack/registration.go +++ b/tcpip/stack/registration.go @@ -185,7 +185,7 @@ type NetworkEndpointID struct { type TransportEndpointID struct { LocalPort uint16 LocalAddress tcpip.Address - remotePort uint16 + RemotePort uint16 RemoteAddress tcpip.Address } diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index babbc80..9e09d0d 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -116,12 +116,87 @@ func New(network []string, transport []string, opts Options) *Stack { proto: transProto, } } - // 添加传输层分流器 + // NOTE 添加协议栈全局传输层分流器 s.demux = newTransportDemuxer(s) return s } +// SetNetworkProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.SetOption(option) +} + +// NetworkProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// e.g. +// var v ipv4.MyOption +// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v) +// +// if err != nil { +// ... +// } +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.Option(option) +} + +// SetTransportProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.SetOption(option) +} + +// TransportProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// var v tcp.SACKEnabled +// +// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { +// ... +// } +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.Option(option) +} + +// SetTransportProtocolHandler sets the per-stack default handler for the given +// protocol. +// +// It must be called only during initialization of the stack. Changing it as the +// stack is operating is not supported. +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) { + state := s.transportProtocols[p] + if state != nil { + state.defaultHandler = h + } +} + +// NowNanoseconds implements tcpip.Clock.NowNanoseconds. +func (s *Stack) NowNanoseconds() int64 { + return s.clock.NowNanoseconds() +} + func (s *Stack) Stats() tcpip.Stats { return s.stats } @@ -260,19 +335,19 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi return false, tcpip.ErrUnknownNICID } -// 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息 +// FindRoute 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息 func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() for i := range s.routeTable { - if (id != 0 && id != s.routeTable[i].NIC) || + if (id != 0 && id != s.routeTable[i].NIC) || // 检查是否是对应的网卡 (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) { continue } - nic := s.nics[s.routeTable[i].NIC] + nic := s.nics[s.routeTable[i].NIC] // 在协议栈里找到这张网卡 if nic == nil { continue } @@ -372,14 +447,34 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // 最终调用 demuxer.registerEndpoint 函数来实现注册。 func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { - // TODO 需要实现 - return nil + log.Println("往", nicID, "网卡注册新的传输端") + if nicID == 0 { + return s.demux.registerEndpoint(netProtos, protocol, id, ep) // 给协议栈的所有网卡注册传输端 + } + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + return nic.demux.registerEndpoint(netProtos, protocol, id, ep) // 给这张网卡注册传输端 } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + if nicID == 0 { + s.demux.unregisterEndpoint(netProtos, protocol, id) + return + } + s.mu.RLock() + defer s.mu.RUnlock() + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterEndpoint(netProtos, protocol, id) + } } diff --git a/tcpip/stack/transport_demuxer.go b/tcpip/stack/transport_demuxer.go index cc63328..48c2c5d 100644 --- a/tcpip/stack/transport_demuxer.go +++ b/tcpip/stack/transport_demuxer.go @@ -2,6 +2,7 @@ package stack import ( "netstack/tcpip" + "netstack/tcpip/buffer" "sync" ) @@ -23,6 +24,112 @@ type transportDemuxer struct { } // 新建一个分流器 -func newTransportDemuxer(stacl *Stack) *transportDemuxer { +func newTransportDemuxer(stack *Stack) *transportDemuxer { + d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + + for netProto := range stack.networkProtocols { + for tranProto := range stack.transportProtocols { + d.protocol[protocolIDs{network: netProto, transport: tranProto}] = &transportEndpoints{ + endpoints: make(map[TransportEndpointID]TransportEndpoint), + } + } + } + return d +} + +// registerEndpoint 向分发器注册给定端点,以便将与端点ID匹配的数据包传递给它 +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, + protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id) // 把刚才注册的注销掉 + return err + } + } return nil } + +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, + protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] // IPv4:udp + if !ok { // 未曾注册过这个传输端集合 + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + + if _, ok := eps.endpoints[id]; ok { // 遍历传输端集合 + return tcpip.ErrPortInUse + } + eps.endpoints[id] = ep + return nil +} + +// unregisterEndpoint 使用给定的id注销端点,使其不再接收任何数据包 +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, + protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + delete(eps.endpoints, id) + eps.mu.Unlock() + } + } +} + +// 根据传输层的id来找到对应的传输端,再将数据包交给这个传输端处理 +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool { + // 先看看分流器里有没有注册相关协议端,如果没有则返回false + eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] + if !ok { + return false + } + // 从 eps 中找符合 id 的传输端 + eps.mu.RLock() + ep := d.findEndpointLocked(eps, vv, id) + eps.mu.RUnlock() + + if ep == nil { + return false + } + + // Deliver the packet + ep.HandlePacket(r, id, vv) + + return true +} + +func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, + trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { + return false +} + +// 根据传输层id来找到相应的传输层端 +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, + vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { + if ep := eps.endpoints[id]; ep != nil { // IPv4:udp + return ep + } + // Try to find a match with the id minus the local address. + nid := id + // 如果上面的 endpoints 没有找到,那么去掉本地ip地址,看看有没有相应的传输层端 + // 因为有时候传输层监听的时候没有绑定本地ip,也就是 any address,此时的 LocalAddress + // 为空。 + nid.LocalAddress = "" + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + return eps.endpoints[nid] +} diff --git a/tcpip/transport/udp/README.md b/tcpip/transport/udp/README.md index d05bdf2..8ab687b 100644 --- a/tcpip/transport/udp/README.md +++ b/tcpip/transport/udp/README.md @@ -3,6 +3,7 @@ ![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555488741384.png) 传输层是整个网络体系结构中的关键之一,我们很多编程都是直接和传输层打交道的,我们需要了解以下的概念: + 1. 端口的意义 - 上一章已经介绍过了 2. 无连接 UDP 协议及特点 - 本章介绍 3. 面向连接 TCP 协议及特点 - 下章会介绍 @@ -15,4 +16,4 @@ 3. 报文差错检测 网络层只对 IP 首部进行差错检测,而传输层对整个报文进行差错检测。 -4. 提供不可靠和可靠通信 网络层只提供了不可靠通信,而在传输层的 TCP 协议提供了可靠通信。 \ No newline at end of file +4. 提供不可靠和可靠通信 网络层只提供了不可靠通信,而在传输层的 TCP 协议提供了可靠通信。 diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go index a8fa760..556b1cc 100644 --- a/tcpip/transport/udp/endpoint.go +++ b/tcpip/transport/udp/endpoint.go @@ -13,7 +13,13 @@ import ( // udp报文结构 当收到udp报文时 会用这个结构来保存udp报文数据 type udpPacket struct { udpPacketEntry // 链表实现 - // TODO 需要添加 + senderAddress tcpip.FullAddress + data buffer.VectorisedView + timestamp int64 + hasTimestamp bool + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View } type endpointState int @@ -40,7 +46,7 @@ type endpoint struct { rcvBufSizeMax int rcvBufSize int rcvClosed bool - rcvTimestamp bool + rcvTimestamp bool // 通过SetSocket进行设置 是否开启时间戳 // The following fields are protected by the mu mutex. mu sync.RWMutex @@ -57,7 +63,7 @@ type endpoint struct { // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags - // TODO + multicastMemberships []multicastMembership // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -68,6 +74,12 @@ type endpoint struct { effectiveNetProtos []tcpip.NetworkProtocolNumber } +// 多播的成员关系,包括多播地址和网卡ID +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { log.Println("新建一个udp端") @@ -76,8 +88,32 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, netProto: netProto, waiterQueue: waiterQueue, multicastTTL: 1, - rcvBufSizeMax: 32 * 1024, - sndBufSize: 32 * 1024} + rcvBufSizeMax: 32 * 1024, // 接收缓存 32k + sndBufSize: 32 * 1024, // 发送缓存 32k + } +} + +// NewConnectedEndpoint creates a new endpoint in the connected state using the +// provided route. +func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, + waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(stack, r.NetProto, waiterQueue) + + // Register new endpoint so that packets are routed to it. + if err := stack.RegisterTransportEndpoint(r.NICID(), + []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil { + ep.Close() + return nil, err + } + + ep.id = id + ep.route = r.Clone() + ep.dstPort = id.RemotePort + ep.regNICID = r.NICID() + + ep.state = stateConnected + + return ep, nil } // Close UDP端的关闭,释放相应的资源 @@ -98,8 +134,37 @@ func (e *endpoint) Close() { e.mu.Unlock() } -func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return nil, tcpip.ControlMessages{}, nil +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + // 如果接收链表为空,即没有任何数据 + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + // 从接收链表中取出最前面的数据报,接着从链表中删除该数据报 + // 然后减少接收缓存的大小 + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + ts := e.rcvTimestamp + + e.rcvMu.Unlock() + + if ts && !p.hasTimestamp { + // Linux uses the current time. + p.timestamp = e.stack.NowNanoseconds() + } + if addr != nil { + // 赋值发送地址 + *addr = p.senderAddress + } + + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil } func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { @@ -141,8 +206,95 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t return netProto, nil } -func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { - log.Println("连接") +// Connect UDP中调用connect内核仅仅把对端ip&port记录下来. 这样在发送数据的时候无需再次指定 +// UDP多次调用connect有两种用途:1,指定一个新的ip&port连结. 2,断开和之前的ip&port的连结 +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + // 目标端口为0是错误的 + if addr.Port == 0 { + // We don't support connecting to port zero. + return tcpip.ErrInvalidEndpointState + } + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + var localPort uint16 + // 判断UDP端的状态 + switch e.state { + case stateInitial: + // 如果是初始状态,直接下一步 + case stateBound, stateConnected: + localPort = e.id.LocalPort + log.Printf("绑定了 %d 的udp端 向[%d]网卡发起连接\n", localPort, nicid) + if e.bindNICID == 0 { + break + } + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + // 检查地址的映射,得到相应的协议 + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + // Find a route to the desired destination. + // 在全局协议栈中查找路由 + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + // 新建一个传输端的标识,包括源IP、源端口、目的IP、目的端口 + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemotePort: addr.Port, + RemoteAddress: r.RemoteAddress, + } + + // 设置网络层协议,IPV4或IPV6,或两者都有 + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } + } + + // 将该UDP端注册到协议栈中 + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + // Remove the old registration. + // 如果源端口不为0,则尝试在传输层端中删除老的UDP端 + if e.id.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + log.Println(e.id, id) + + // 赋值UDP端的属性 + e.id = id + e.route = r.Clone() + e.dstPort = addr.Port + e.regNICID = nicid + e.effectiveNetProtos = netProtos + + // 更改该UDP端的状态为已连接 + e.state = stateConnected + + // 标志该UDP端可以接收数据了 + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + return nil } @@ -167,7 +319,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) // 往网卡注册一个绑定了端口的udp端 if err != nil { // 释放端口 e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) @@ -206,6 +358,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error LocalAddress: addr.Addr, LocalPort: addr.Port, } + log.Println("Bind", id) // 在协议栈中注册该UDP端 id, err = e.registerWithStack(addr.NIC, netProtos, id) if err != nil { @@ -229,6 +382,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error // 标记状态为已绑定 e.state = stateBound + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + return nil } @@ -271,9 +428,64 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil } -// 从网络层接收到UDP数据报时的处理函数 +// HandlePacket 从网络层接收到UDP数据报时的处理函数 +// 首先 UDP 端有接收队列的概念,不像网络层接收到数据包立马发送给传输层, +// 对于协议栈来说,传输层是最后的一站,接下来的数据就需要交给用户层了, +// 但是用户层的行为是不可预知的,不知道用户层何时将数据取走(也就是 UDP Read 过程), +// 那么协议栈就实现一个接收队列,将接收的数据去掉 UDP 头部后保存在这个队列中,用户层需要的时候取走就可以了, +// 但是队列存数据量是有限制的,这个限制叫接收缓存大小,当接收队列中的数据总和超过这个缓存,那么接下来的这些报文将会被直接丢包。 func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + // 错误报文 + e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + return + } + log.Println("udp 正式接收数据", hdr) + // 去除UDP首部 + vv.TrimFront(header.UDPMinimumSize) + + e.rcvMu.Lock() + e.stack.Stats().UDP.PacketsReceived.Increment() + + // Drop the packet if our buffer is currently full. + // 如果UDP的接收缓存已经满了,那么丢弃报文。 + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.rcvMu.Unlock() + log.Println("udp 接收缓存不足 丢弃报文") + return + } + + // 接收缓存是否为空 + wasEmpty := e.rcvBufSize == 0 + // 新建一个UDP数据包结构 插入到接收链表中 + pkt := &udpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + Port: hdr.SourcePort(), + }, + } + // 复制UDP数据包的用户数据 + pkt.data = vv.Clone(pkt.views[:]) // 当vv中的数据<=8时 无需再次分配内存 + // 插入到接收链表中 并增加已经使用的缓存 + e.rcvList.PushBack(pkt) + e.rcvBufSize += vv.Size() + + if e.rcvTimestamp { + pkt.timestamp = e.stack.NowNanoseconds() + pkt.hasTimestamp = true + } + + e.rcvMu.Unlock() + // TODO 通知用户层可以读取数据了 + if wasEmpty { + + } } // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. diff --git a/tcpip/transport/udp/protocol.go b/tcpip/transport/udp/protocol.go index 780eb06..e35ef8d 100644 --- a/tcpip/transport/udp/protocol.go +++ b/tcpip/transport/udp/protocol.go @@ -38,9 +38,8 @@ func (*protocol) MinimumPacketSize() int { // ParsePorts returns the source and destination ports stored in the given udp // packet. func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { - //h := header.UDP(v) - //return h.SourcePort(), h.DestinationPort(), nil - return 0, 0, nil + h := header.UDP(v) + return h.SourcePort(), h.DestinationPort(), nil } // HandleUnknownDestinationPacket handles packets targeted at this protocol but