From 050d5fec9768d77ffbcf1ef876ce2b09281e261b Mon Sep 17 00:00:00 2001 From: impact-eintr Date: Fri, 2 Dec 2022 21:11:41 +0800 Subject: [PATCH] =?UTF-8?q?udp=E5=9F=BA=E6=9C=AC=E5=86=99=E5=AE=8C?= =?UTF-8?q?=E4=BA=86=20=E5=85=B3=E4=BA=8Esocket=E7=9A=84=E4=B8=8D=E5=B0=91?= =?UTF-8?q?=E7=BB=86=E8=8A=82=E8=BF=98=E6=B2=A1=E7=9C=8B=20=E5=8F=A6?= =?UTF-8?q?=E5=A4=96=E5=9C=A8=E7=BD=91=E7=BB=9C=E6=A0=88=E9=9D=99=E7=BD=AE?= =?UTF-8?q?=E4=B8=80=E6=AE=B5=E6=97=B6=E9=97=B4=E5=90=8E=E5=86=8D=E6=AC=A1?= =?UTF-8?q?=E5=8F=91=E8=B5=B7=E8=BF=9E=E6=8E=A5=E5=B0=86=E6=97=A0=E6=B3=95?= =?UTF-8?q?=E6=89=BE=E5=88=B0=E8=B7=AF=E7=94=B1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/netstack/main.go | 87 +++++--- cmd/udp_client/main.go | 25 +-- tcpip/link/fdbased/endpoint.go | 4 +- tcpip/link/tuntap/tuntap.go | 68 +++++- tcpip/network/arp/arp.go | 4 +- tcpip/network/ipv4/ipv4.go | 5 +- tcpip/stack/linkaddrcache.go | 2 +- tcpip/stack/nic.go | 26 ++- tcpip/stack/registration.go | 2 +- tcpip/stack/route.go | 38 ++++ tcpip/stack/stack.go | 36 ++- tcpip/tcpip.go | 50 +++++ tcpip/transport/udp/endpoint.go | 382 +++++++++++++++++++++++++++++++- 13 files changed, 655 insertions(+), 74 deletions(-) diff --git a/cmd/netstack/main.go b/cmd/netstack/main.go index ebc1e11..7916e60 100644 --- a/cmd/netstack/main.go +++ b/cmd/netstack/main.go @@ -10,28 +10,42 @@ import ( "netstack/tcpip/link/tuntap" "netstack/tcpip/network/arp" "netstack/tcpip/network/ipv4" + "netstack/tcpip/network/ipv6" "netstack/tcpip/stack" "netstack/tcpip/transport/udp" "netstack/waiter" "os" + "os/signal" + "strconv" "strings" + "syscall" ) +var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device") + func main() { flag.Parse() - if len(flag.Args()) < 2 { - log.Fatal("Usage: ", os.Args[0], " ") + if len(flag.Args()) != 4 { + log.Fatal("Usage: ", os.Args[0], " ") } log.SetFlags(log.Lshortfile | log.LstdFlags) + tapName := flag.Arg(0) cidrName := flag.Arg(1) + addrName := flag.Arg(2) + portName := flag.Arg(3) - log.Printf("tap: %v, cidrName: %v", tapName, cidrName) + log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName) - parsedAddr, cidr, err := net.ParseCIDR(cidrName) + maddr, err := net.ParseMAC(*mac) if err != nil { - log.Fatalf("Bad cidr: %v", cidrName) + log.Fatalf("Bad MAC address: %v", *mac) + } + + parsedAddr := net.ParseIP(addrName) + if err != nil { + log.Fatalf("Bad addrress: %v", addrName) } // 解析地址ip地址,ipv4或者ipv6地址都支持 @@ -42,11 +56,16 @@ func main() { proto = ipv4.ProtocolNumber } else if parsedAddr.To16() != nil { addr = tcpip.Address(parsedAddr.To16()) - //proto = ipv6.ProtocolNumber + proto = ipv6.ProtocolNumber } else { log.Fatalf("Unknown IP type: %v", parsedAddr) } + localPort, err := strconv.Atoi(portName) + if err != nil { + log.Fatalf("Unable to convert port %v: %v", portName, err) + } + // 虚拟网卡配置 conf := &tuntap.Config{ Name: tapName, @@ -61,22 +80,18 @@ func main() { } // 启动tap网卡 - tuntap.SetLinkUp(tapName) + _ = tuntap.SetLinkUp(tapName) // 设置路由 - tuntap.SetRoute(tapName, cidr.String()) + _ = tuntap.SetRoute(tapName, cidrName) - // 获取mac地址 - mac, err := tuntap.GetHardwareAddr(tapName) - if err != nil { - panic(err) - } - - // 抽象网卡的文件接口 + // 抽象的文件接口 linkID := fdbased.New(&fdbased.Options{ - FD: fd, - MTU: 1500, - Address: tcpip.LinkAddress(mac), + FD: fd, + MTU: 1500, + Address: tcpip.LinkAddress(maddr), + ResolutionRequired: true, }) + // 新建相关协议的协议栈 s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{}) @@ -106,9 +121,9 @@ func main() { }, }) - go func() { + go func() { // echo server // 监听udp localPort端口 - conn := udpListen(s, proto, 9999) + conn := udpListen(s, proto, localPort) for { buf := make([]byte, 1024) @@ -117,13 +132,17 @@ func main() { log.Println(err) break } - log.Println("接收到数据", buf[:n]) + log.Println("接收到数据", string(buf[:n])) + conn.Write([]byte("server echo")) } // 关闭监听服务,此时会释放端口 conn.Close() }() - select {} + c := make(chan os.Signal) + signal.Notify(c, os.Interrupt, os.Kill, syscall.SIGUSR1, syscall.SIGUSR2) + <-c + //conn, _ := net.Listen("tcp", "0.0.0.0:9999") //rcv := &RCV{ // Stack: s, @@ -133,6 +152,7 @@ func main() { } type UdpConn struct { + raddr tcpip.FullAddress ep tcpip.Endpoint wq *waiter.Queue we *waiter.Entry @@ -147,7 +167,7 @@ func (conn *UdpConn) Read(rcv []byte) (int, error) { conn.wq.EventRegister(conn.we, waiter.EventIn) defer conn.wq.EventUnregister(conn.we) for { - buf, _, err := conn.ep.Read(nil) + buf, _, err := conn.ep.Read(&conn.raddr) if err != nil { if err == tcpip.ErrWouldBlock { <-conn.notifyCh @@ -155,8 +175,19 @@ func (conn *UdpConn) Read(rcv []byte) (int, error) { } return 0, fmt.Errorf("%s", err.String()) } - rcv = append(rcv[:0], buf[:cap(rcv)]...) - return len(rcv), nil + n := len(buf) + if n > cap(rcv) { + n = cap(rcv) + } + rcv = append(rcv[:0], buf[:n]...) + return n, nil + } +} + +func (conn *UdpConn) Write(snd []byte) { + _, _, err := conn.ep.Write(tcpip.SlicePayload(snd), tcpip.WriteOptions{To: &conn.raddr}) + if err != nil { + log.Fatal(err) } } @@ -176,5 +207,9 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) } waitEntry, notifyCh := waiter.NewChannelEntry(nil) - return &UdpConn{ep, &wq, &waitEntry, notifyCh} + return &UdpConn{ + ep: ep, + wq: &wq, + we: &waitEntry, + notifyCh: notifyCh} } diff --git a/cmd/udp_client/main.go b/cmd/udp_client/main.go index ff3ea2a..994b44e 100644 --- a/cmd/udp_client/main.go +++ b/cmd/udp_client/main.go @@ -13,31 +13,28 @@ func main() { log.SetFlags(log.Lshortfile | log.LstdFlags) + var err error udpAddr, err := net.ResolveUDPAddr("udp", *addr) if err != nil { panic(err) } - log.Println("解析地址") // 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接) conn, err := net.DialUDP("udp", nil, udpAddr) if err != nil { panic(err) } - log.Println("TEST") - for i := 0; i < 3; i++ { - send := make([]byte, 2048) - if _, err := conn.Write(send); err != nil { - panic(err) - } - log.Printf("send: %s", string(send)) + send := []byte("hello world") + if _, err := conn.Write(send); err != nil { + panic(err) } + log.Printf("send: %s", string(send)) - //recv := make([]byte, 10) - //rn, _, err := conn.ReadFrom(recv) - //if err != nil { - // panic(err) - //} - //log.Printf("recv: %s", string(recv[:rn])) + recv := make([]byte, 32) + rn, _, err := conn.ReadFrom(recv) + if err != nil { + panic(err) + } + log.Printf("recv: %s", string(recv[:rn])) } diff --git a/tcpip/link/fdbased/endpoint.go b/tcpip/link/fdbased/endpoint.go index 80055fb..2829282 100644 --- a/tcpip/link/fdbased/endpoint.go +++ b/tcpip/link/fdbased/endpoint.go @@ -49,7 +49,7 @@ type Options struct { TestLossPacket func(data []byte) bool } -// 根据选项参数创建一个链路层的endpoint,并返回该endpoint的id +// New 根据选项参数创建一个链路层的endpoint,并返回该endpoint的id func New(opts *Options) tcpip.LinkEndpointID { syscall.SetNonblock(opts.FD, true) caps := stack.LinkEndpointCapabilities(0) // 初始化 @@ -203,7 +203,7 @@ func (e *endpoint) dispatch() (bool, *tcpip.Error) { switch p { case header.ARPProtocolNumber, header.IPv4ProtocolNumber: - log.Println("链路层收到报文") + log.Println("链路层收到报文,来自: ", remoteLinkAddr, localLinkAddr) e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv) case header.IPv6ProtocolNumber: // TODO ipv6暂时不感兴趣 diff --git a/tcpip/link/tuntap/tuntap.go b/tcpip/link/tuntap/tuntap.go index aeb0c44..6e7592d 100644 --- a/tcpip/link/tuntap/tuntap.go +++ b/tcpip/link/tuntap/tuntap.go @@ -3,6 +3,7 @@ package tuntap import ( "errors" "fmt" + "log" "os/exec" "syscall" "unsafe" @@ -19,12 +20,12 @@ var ( type rawSockaddr struct { Family uint16 - Data [14]byte + Data [14]byte } type Config struct { Name string // 网卡名 - Mode int // 网卡模式 TUN or TAP + Mode int // 网卡模式 TUN or TAP } // NewNetDev根据配置返回虚拟网卡的文件描述符 @@ -50,7 +51,7 @@ func newTun(name string) (int, error) { } // TAP工作在第三层 -func newTap(name string) (int, error){ +func newTap(name string) (int, error) { return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI) } @@ -62,15 +63,15 @@ func open(name string, flags uint16) (int, error) { } var ifr struct { - name [16]byte + name [16]byte flags uint16 - _ [22]byte + _ [22]byte } copy(ifr.name[:], name) ifr.flags = flags // 通过ioctl系统调用 将fd和虚拟网卡驱动绑定在一起 - _, _ , errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), + _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr))) if errno != 0 { syscall.Close(fd) @@ -101,12 +102,63 @@ func SetRoute(name, cidr string) (err error) { return } +// SetBridge 开启并设置网桥 通过网桥进行通信 +func SetBridge(bridge, tap, addr string) (err error) { + // ip link add br0 type bridge + out, cmdErr := exec.Command("ip", "link", "add", bridge, "type", "bridge").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + out, cmdErr = exec.Command("ip", "link", "set", "dev", bridge, "up").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + // ifconfig br0 192.168.1.66 netmask 255.255.255.0 up + out, cmdErr = exec.Command("ifconfig", bridge, addr, "netmask", "255.255.255.0", "up").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + // ip link seteth0 master br0 + out, cmdErr = exec.Command("ip", "link", "set", "eth0", "master", bridge).CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + // ip link set tap0 master br0 + out, cmdErr = exec.Command("ip", "link", "set", tap, "master", bridge).CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + return +} + +func RemoveBridge(bridge string) (err error) { + + out, cmdErr := exec.Command("ip", "link", "set", "dev", bridge, "down").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + + // ip link add br0 type bridge + out, cmdErr = exec.Command("ip", "link", "del", bridge, "type", "bridge").CombinedOutput() + if cmdErr != nil { + err = fmt.Errorf("%v:%v", cmdErr, string(out)) + log.Println(err) + } + return +} + // AddIP 通过ip命令添加IP地址 func AddIP(name, ip string) (err error) { // ip addr add 192.168.1.1 dev tap0 out, cmdErr := exec.Command("ip", "addr", "add", ip, "dev", name).CombinedOutput() if cmdErr != nil { - err = fmt.Errorf("%v:%v",cmdErr, string(out)) + err = fmt.Errorf("%v:%v", cmdErr, string(out)) return } return @@ -123,7 +175,7 @@ func GetHardwareAddr(name string) (string, error) { var ifreq struct { name [16]byte addr rawSockaddr - _ [8]byte + _ [8]byte } copy(ifreq.name[:], name) diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go index 73c8b48..d4f6111 100644 --- a/tcpip/network/arp/arp.go +++ b/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { // 倒置目标与源 作为回应 copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) - log.Println("处理注入的ARP请求 这里将返回一个ARP报文作为响应") + log.Println("处理注入的ARP请求 这里将返回一个ARP报文作为响应", tcpip.LinkAddress(pkt.HardwareAddressTarget())) e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) // 往链路层写回消息 // 注意这里的 fallthrough 表示需要继续执行下面分支的代码 // 当收到 arp 请求需要添加到链路地址缓存中 @@ -87,7 +87,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { case header.ARPReply: // 这里记录ip和mac对应关系,也就是arp表 addr := tcpip.Address(h.ProtocolAddressSender()) - linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) // 记录远端机的MAC地址 e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) default: panic(tcpip.ErrUnknownProtocol) diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go index 3c41dbd..efc42fd 100644 --- a/tcpip/network/ipv4/ipv4.go +++ b/tcpip/network/ipv4/ipv4.go @@ -108,7 +108,7 @@ func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload b // 写入网卡接口 if protocol == header.ICMPv4ProtocolNumber { - log.Printf("IP 写回ICMP报文 长度: %d\n", hdr.UsedLength()+payload.Size()) + log.Println("IP 写回ICMP报文", header.IPv4(append(ip, payload.ToView()...))) } else { //log.Printf("send ipv4 packet %d bytes, proto: 0x%x", hdr.UsedLength()+payload.Size(), protocol) log.Println(header.IPv4(append(ip, payload.ToView()...))) @@ -132,7 +132,6 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { tlen := int(h.TotalLength()) vv.TrimFront(hlen) vv.CapLength(tlen - hlen) - log.Println(hlen, tlen) // 报文重组 more := (h.Flags() & header.IPv4FlagMoreFragments) != 0 @@ -157,7 +156,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { } r.Stats().IP.PacketsDelivered.Increment() // 根据协议分发到不同处理函数,比如协议时TCP,会进入tcp.HandlePacket - log.Printf("recv ipv4 packet %d bytes, proto: 0x%x", tlen, p) + log.Printf("准备前往 UDP/TCP recv ipv4 packet %d bytes, proto: 0x%x", tlen, p) e.dispatcher.DeliverTransportPacket(r, p, vv) } diff --git a/tcpip/stack/linkaddrcache.go b/tcpip/stack/linkaddrcache.go index 15bd4c6..5393c55 100644 --- a/tcpip/stack/linkaddrcache.go +++ b/tcpip/stack/linkaddrcache.go @@ -202,7 +202,7 @@ func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress // get reports any known link address for k. func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - log.Printf("link addr get linkRes: %#v, addr: %+v", linkRes, k) + log.Println("在arp本地缓存中寻找", k) if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index 21d40ed..9eab217 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -302,6 +302,23 @@ func (n *NIC) Subnets() []tcpip.Subnet { return append(sns, n.subnets...) } +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + r := n.endpoints[NetworkEndpointID{addr}] + if r == nil || !r.holdsInsertRef { + n.mu.Unlock() + return tcpip.ErrBadLocalAddress + } + + r.holdsInsertRef = false + n.mu.Unlock() + + r.decRef() + + return nil +} + // DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。 // 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。 // 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它, @@ -323,7 +340,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin return } src, dst := netProto.ParseAddresses(vv.First()) - log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte { + log.Printf("网卡[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte { if len(vv.ToView()) > 64 { return vv.ToView()[:64] } @@ -334,6 +351,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin if ref := n.getRef(protocol, dst); ref != nil { r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) r.RemoteLinkAddress = remoteLinkAddr + log.Println("准备前往 IP 将本地和远端的MAC、IP 保存在路由中 以便协议栈使用", + r.LocalLinkAddress, r.RemoteLinkAddress, r.LocalAddress, r.RemoteAddress) ref.ep.HandlePacket(&r, vv) ref.decRef() return @@ -377,7 +396,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r n.mu.RLock() if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - log.Println("找到了目标网络层实现: ", id.LocalAddress) + log.Println("找到了目标网络端(绑定过的IP): ", id.LocalAddress) n.mu.RUnlock() return ref } @@ -434,7 +453,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN n.stack.stats.MalformedRcvdPackets.Increment() return } - log.Println("准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort) + log.Println("网卡准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort) id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} // 调用分流器,根据传输层协议和传输层id分发数据报文 if n.demux.deliverPacket(r, protocol, vv, id) { @@ -480,7 +499,6 @@ type referencedNetworkEndpoint struct { // linkCache is set if link address resolution is enabled for this // protocol. Set to nil otherwise. linkCache LinkAddressCache - linkAddrCache // holdsInsertRef is protected by the NIC's mutex. It indicates whether // the reference count is biased by 1 due to the insertion of the diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go index 69831fb..16fe071 100644 --- a/tcpip/stack/registration.go +++ b/tcpip/stack/registration.go @@ -49,7 +49,7 @@ type LinkEndpoint interface { IsAttached() bool } -// LinkAddressResolver 是对可以解析链接地址的 NetworkProtocol 的扩展 TODO 需要解读 +// LinkAddressResolver 是对可以解析链接地址的 NetworkProtocol 的扩展 其实就是ARP type LinkAddressResolver interface { LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go index b4ead51..b84322b 100644 --- a/tcpip/stack/route.go +++ b/tcpip/stack/route.go @@ -4,6 +4,7 @@ import ( "netstack/sleep" "netstack/tcpip" "netstack/tcpip/buffer" + "netstack/tcpip/header" ) // 贯穿整个协议栈的路由,也就是在链路层和网络层都可以路由 @@ -57,11 +58,48 @@ func (r *Route) Stats() tcpip.Stats { return r.ref.nic.stack.Stats() } +// PseudoHeaderChecksum forwards the call to the network endpoint's +// implementation. +// udp或tcp伪首部校验和的计算 +func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 { + return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress) +} + // Capabilities returns the link-layer capabilities of the route. func (r *Route) Capabilities() LinkEndpointCapabilities { return r.ref.ep.Capabilities() } +// Resolve 如有必要,解决尝试解析链接地址的问题。如果地址解析需要阻塞,则返回ErrWouldBlock, +// 例如等待ARP回复。地址解析完成(成功与否)时通知Waker。 +// 如果需要地址解析,则返回ErrNoLinkAddress和通知通道,以阻止顶级调用者。 +// 地址解析完成后,通道关闭(不管成功与否)。 +func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { + if !r.IsResolutionRequired() { + // Nothing to do if there is no cache (which does the resolution on cache miss) or + // link address is already known. + return nil, nil + } + + nextAddr := r.NextHop + if nextAddr == "" { + // Local link address is already known. + if r.RemoteAddress == r.LocalAddress { // 发给自己 + r.RemoteLinkAddress = r.LocalLinkAddress // MAC 就是自己 + return nil, nil + } + nextAddr = r.RemoteAddress // 下一跳是远端机 + } + + // 调用地址解析协议来解析IP地址 + linkAddr, ch, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + if err != nil { + return ch, err + } + r.RemoteLinkAddress = linkAddr + return nil, nil +} + // RemoveWaker removes a waker that has been added in Resolve(). func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr := r.NextHop diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index 9e09d0d..b27b530 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -38,7 +38,7 @@ type transportProtocolState struct { type Stack struct { transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState // 各种传输层协议 networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol // 各种网络层协议 - linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // 各种链接解析器 + linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver // 支持链接层反向解析的网络层协议 demux *transportDemuxer // 传输层的复用器 @@ -103,6 +103,12 @@ func New(network []string, transport []string, opts Options) *Stack { } netProto := netProtoFactory() // 制造一个该型号协议的示实例 s.networkProtocols[netProto.Number()] = netProto // 注册该型号的网络协议 + // 判断该协议是否支持链路层地址解析协议接口,如果支持添加到 s.linkAddrResolvers 中, + // 如:ARP协议会添加 IPV4-ARP 的对应关系 + // 后面需要地址解析协议的时候会更改网络层协议号来找到相应的地址解析协议 + if r, ok := netProto.(LinkAddressResolver); ok { + s.linkAddrResolvers[r.LinkAddressProtocol()] = r // 其实就是说: 声明arp支持地址解析 + } } // 添加指定的传输层协议 必已经在init中注册过 @@ -335,6 +341,19 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi return false, tcpip.ErrUnknownNICID } +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.RemoveAddress(addr) + } + + return tcpip.ErrUnknownNICID +} + // FindRoute 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息 func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) { @@ -354,7 +373,7 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, var ref *referencedNetworkEndpoint if len(localAddr) != 0 { - ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) + ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) // 找到绑定LocalAddr的IP端 } else { ref = nic.primaryEndpoint(netProto) } @@ -426,7 +445,7 @@ func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, } s.mu.RUnlock() - fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} // addr 可能是Remote IP Address // 根据网络层协议号找到对应的地址解析协议 linkRes := s.linkAddrResolvers[protocol] return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, w) @@ -497,3 +516,14 @@ func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) Tra } return nil } + +// JoinGroup joins the given multicast group on the given NIC. +func (s *Stack) JoinGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { + // TODO: notify network of subscription via igmp protocol. + return s.AddAddressWithOptions(nicID, protocol, multicastAddr, NeverPrimaryEndpoint) +} + +// LeaveGroup leaves the given multicast group on the given NIC. +func (s *Stack) LeaveGroup(protocol tcpip.NetworkProtocolNumber, nicID tcpip.NICID, multicastAddr tcpip.Address) *tcpip.Error { + return s.RemoveAddress(nicID, multicastAddr) +} diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index e3c573b..d7b51a9 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -355,6 +355,56 @@ type WriteOptions struct { EndOfRecord bool } +// ErrorOption is used in GetSockOpt to specify that the last error reported by +// the endpoint should be cleared and returned. +type ErrorOption struct{} + +// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6 +// socket is to be restricted to sending and receiving IPv6 packets only. +type V6OnlyOption int + +// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send +// buffer size option. +type SendBufferSizeOption int + +// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the +// receive buffer size option. +type ReceiveBufferSizeOption int + +// SendQueueSizeOption is used in GetSockOpt to specify that the number of +// unread bytes in the output buffer should be returned. +type SendQueueSizeOption int + +// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of +// unread bytes in the input buffer should be returned. +type ReceiveQueueSizeOption int + +// TimestampOption is used by SetSockOpt/GetSockOpt to specify whether +// SO_TIMESTAMP socket control messages are enabled. +type TimestampOption int + +// MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default +// TTL value for multicast messages. The default is 1. +type MulticastTTLOption uint8 + +// MembershipOption is used by SetSockOpt/GetSockOpt as an argument to +// AddMembershipOption and RemoveMembershipOption. +type MembershipOption struct { + NIC NICID + InterfaceAddr Address + MulticastAddr Address +} + +// RemoveMembershipOption is used by SetSockOpt/GetSockOpt to leave a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type RemoveMembershipOption MembershipOption + +// AddMembershipOption is used by SetSockOpt/GetSockOpt to join a multicast +// group identified by the given multicast address, on the interface matching +// the given interface address. +type AddMembershipOption MembershipOption + type Route struct { Destination Address // 目标地址 Mask AddressMask // 掩码 diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go index a3a75c2..1f81250 100644 --- a/tcpip/transport/udp/endpoint.go +++ b/tcpip/transport/udp/endpoint.go @@ -2,6 +2,8 @@ package udp import ( "log" + "math" + "netstack/sleep" "netstack/tcpip" "netstack/tcpip/buffer" "netstack/tcpip/header" @@ -35,9 +37,8 @@ const ( type endpoint struct { stack *stack.Stack // udp所依赖的用户协议栈 netProto tcpip.NetworkProtocolNumber // udp网络协议号 ipv4/ipv6 - waiterQueue *waiter.Queue // TODO 需要解析 + waiterQueue *waiter.Queue // 事件驱动机制 - // TODO 需要解析 // The following fields are used to manage the receive queue, and are // protected by rcvMu. rcvMu sync.Mutex @@ -130,8 +131,29 @@ func (e *endpoint) Close() { e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) } - // TODO + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + // 清空接收链表 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + e.mu.Unlock() + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -167,8 +189,188 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil } -func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { - return 0, nil, nil +// sendUDP sends a UDP segment via the provided network endpoint and under the +// provided identity. +// 增加UDP头部信息,并发送给给网络层 +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { + // Allocate a buffer for the UDP header. + hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + + // Initialize the header. + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + + // 得到报文的长度 + length := uint16(hdr.UsedLength() + data.Size()) + // UDP首部的编码 + udp.Encode(&header.UDPFields{ + SrcPort: localPort, + DstPort: remotePort, + Length: length, + }) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + // 检验和的计算 + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + udp.SetChecksum(^udp.CalculateChecksum(xsum, length)) + } + + // Track count of packets sent. + r.Stats().UDP.PacketsSent.Increment() + + // 将准备好的UDP首部和数据写给网络层 + log.Printf("send udp %d bytes", hdr.UsedLength()+data.Size()) + return r.WritePacket(hdr, data, ProtocolNumber, ttl) +} + +// 写数据之前的准备,如果还是初始状态需要先进性绑定操作。 +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return false, err + } + + return true, nil +} + +// Write 用户层最终调用该函数,发送数据包给对端,即使数据写失败,也不会阻塞。 +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + // NOTE 如果报文长度超过65535,将会超过UDP最大的长度表示,这是不允许的。 + if p.Size() > math.MaxUint16 { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + log.Println("UDP 准备向 路由", to, "写入数据") + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + // 如果设置了关闭写数据,那返回错误 + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + var route *stack.Route + var dstPort uint16 + if to == nil { + // 如果没有指定发送的地址,用UDP端 Connect 得到的路由和目的端口 + route = &e.route + dstPort = e.dstPort + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, given that it may need to + // change in Route.Resolve() call below. + // 如果使用共享路由,则将锁定提升为独占路由,因为它可能需要在下面的Route.Resolve()调用中进行更改。 + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + // 锁定后重新检查状态。 + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { // 如果指定了发送地址和端口 + nicid := to.NIC + // 如果绑定了网卡 + if e.bindNICID != 0 { + if nicid != 0 && nicid != e.bindNICID { + return 0, nil, tcpip.ErrNoRoute // 指定了网卡但udp端没绑定这张网卡 + } + nicid = e.bindNICID // 如果没指定网卡就用这张绑定过的网卡 + } + // 得到目的IP+端口 + toCopy := *to + to = &toCopy + netProto, err := e.checkV4Mapped(to, false) + if err != nil { + return 0, nil, err + } + // Find the enpoint. + // 根据目的地址和协议找到相关路由信息 + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + dstPort = to.Port + } + + // TODO + // 如果路由没有下一跳的链路MAC地址,那么触发相应的机制,来填充该路由信息。 + // 比如:IPV4协议,如果没有目的IP对应的MAC信息,从从ARP缓存中查找信息,找到了直接返回, + // 若没找到,那么发送ARP请求,得到对应的MAC地址。 + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + log.Println("发起arp广播(如果目标是255.255.255.255)或者在本地arp缓存来寻找目标主机 目标路由为", to, route.RemoteAddress) + if ch, err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. Resolution was triggered the background. + // Better luck next time. + route.RemoveWaker(waker) + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + // 得到要发送的数据内容 + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + + ttl := route.DefaultTTL() + // 如果是多播地址,设置ttl + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + } + + // 增加UDP头部信息,并发送出去 + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + return 0, nil, err + } + + return uintptr(len(v)), nil, nil } func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { @@ -409,23 +611,183 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcp } func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil } func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { - return tcpip.FullAddress{}, nil + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil } func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { - return waiter.EventErr + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result } func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + + case tcpip.TimestampOption: + e.rcvMu.Lock() + e.rcvTimestamp = v != 0 + e.rcvMu.Unlock() + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.AddMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.multicastMemberships = append(e.multicastMemberships, multicastMembership{nicID, v.MulticastAddr}) + + case tcpip.RemoveMembershipOption: + nicID := v.NIC + if v.InterfaceAddr != header.IPv4Any { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrNoRoute + } + + // TODO: check that v.MulticastAddr is a multicast address. + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.mu.Lock() + defer e.mu.Unlock() + for i, mem := range e.multicastMemberships { + if mem.nicID == nicID && mem.multicastAddr == v.MulticastAddr { + // Only remove the first match, so that each added membership above is + // paired with exactly 1 removal. + e.multicastMemberships[i] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + break + } + } + } return nil } func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - return nil + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + + case *tcpip.TimestampOption: + e.rcvMu.Lock() + *o = 0 + if e.rcvTimestamp { + *o = 1 + } + e.rcvMu.Unlock() + + case *tcpip.MulticastTTLOption: + e.mu.Lock() + *o = tcpip.MulticastTTLOption(e.multicastTTL) + e.mu.Unlock() + return nil + } + + return tcpip.ErrUnknownProtocolOption } // HandlePacket 从网络层接收到UDP数据报时的处理函数 @@ -482,7 +844,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv } e.rcvMu.Unlock() - // TODO 通知用户层可以读取数据了 + // NOTE 通知用户层可以读取数据了 if wasEmpty { e.waiterQueue.Notify(waiter.EventIn) }