diff --git a/tcpip/buffer/prependable.go b/tcpip/buffer/prependable.go index e2247ca..d7f765f 100644 --- a/tcpip/buffer/prependable.go +++ b/tcpip/buffer/prependable.go @@ -23,6 +23,7 @@ func (p Prependable) UsedLength() int { return len(p.buf) - p.usedIdx } +// 从内到外取出报文头的协议 func (p *Prependable) Prepend(size int) []byte { if size > p.usedIdx { return nil diff --git a/tcpip/buffer/view.go b/tcpip/buffer/view.go index aa43602..e78aef3 100644 --- a/tcpip/buffer/view.go +++ b/tcpip/buffer/view.go @@ -1,6 +1,5 @@ package buffer - type View []byte func NewView(size int) View { @@ -21,14 +20,14 @@ func (v *View) CapLength(length int) { *v = (*v)[:length:length] } -func (v View) ToVectoriseView() VectorisedView { +func (v View) ToVectorisedView() VectorisedView { return NewVectorisedView(len(v), []View{v}) } // VectorisedView 是使用非连续内存的 View 的矢量化版本 type VectorisedView struct { views []View - size int + size int } func NewVectorisedView(size int, views []View) VectorisedView { diff --git a/tcpip/header/arp.go b/tcpip/header/arp.go index a747f3a..faaef1e 100644 --- a/tcpip/header/arp.go +++ b/tcpip/header/arp.go @@ -65,6 +65,12 @@ func (a ARP) SetIPv4OverEthernet() { a[5] = uint8(IPv4AddressSize) } +// HardwareAddressSender从报文中得到arp发送方的硬件地址 +func (a ARP) HardwareAddressSender() []byte { + const s = 8 + return a[s : s+6] +} + // ProtocolAddressSender从报文中得到arp发送方的协议地址,为ipv4地址 func (a ARP) ProtocolAddressSender() []byte { const s = 8 + 6 // 8 是arp的协议头部 6是本机MAC diff --git a/tcpip/link/channel/channel.go b/tcpip/link/channel/channel.go index a5634da..aa5c2c0 100644 --- a/tcpip/link/channel/channel.go +++ b/tcpip/link/channel/channel.go @@ -49,6 +49,7 @@ func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.Vector // InjectLinkAddr injects an inbound packet with a remote link address. func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) { + // 这里的实现在NIC.go中 由 网卡对象进行数据分发 e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil)) } diff --git a/tcpip/link/channel/stack.RegisterLinkEndpoint b/tcpip/link/channel/stack.RegisterLinkEndpoint deleted file mode 100644 index e69de29..0000000 diff --git a/tcpip/link/fdbased/endpoint_test.go b/tcpip/link/fdbased/endpoint_test.go index 3371f81..a1f5180 100644 --- a/tcpip/link/fdbased/endpoint_test.go +++ b/tcpip/link/fdbased/endpoint_test.go @@ -2,36 +2,36 @@ package fdbased import ( "fmt" - "reflect" - "time" "math/rand" "netstack/tcpip" "netstack/tcpip/buffer" "netstack/tcpip/header" "netstack/tcpip/stack" + "reflect" "syscall" "testing" + "time" ) const ( - mtu = 1500 + mtu = 1500 laddr = tcpip.LinkAddress("\x65\x66\x67\x68\x69\x70") raddr = tcpip.LinkAddress("\x71\x72\x73\x74\x75\x76") proto = 10 ) type packetInfo struct { - raddr tcpip.LinkAddress - proto tcpip.NetworkProtocolNumber + raddr tcpip.LinkAddress + proto tcpip.NetworkProtocolNumber contents buffer.View } type context struct { - t *testing.T - fds [2]int - ep stack.LinkEndpoint - ch chan packetInfo // 信道 - done chan struct{} // 通知退出 + t *testing.T + fds [2]int + ep stack.LinkEndpoint + ch chan packetInfo // 信道 + done chan struct{} // 通知退出 } func newContext(t *testing.T, opt *Options) *context { @@ -49,10 +49,10 @@ func newContext(t *testing.T, opt *Options) *context { ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) // 找到端口实现 c := &context{ - t: t, - fds: fds, - ep:ep, - ch: make(chan packetInfo, 100), + t: t, + fds: fds, + ep: ep, + ch: make(chan packetInfo, 100), done: done, } @@ -79,7 +79,7 @@ func TestFdbased(t *testing.T) { // Build header hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) // 114 - b := hdr.Prepend(100) // payload + b := hdr.Prepend(100) // payload for i := range b { b[i] = uint8(rand.Intn(256)) } @@ -91,7 +91,7 @@ func TestFdbased(t *testing.T) { } if err := c.ep.WritePacket(&stack.Route{RemoteLinkAddress: raddr}, hdr, - payload.ToVectoriseView(), proto); err != nil { + payload.ToVectorisedView(), proto); err != nil { panic(err) } diff --git a/tcpip/network/arp/arp.go b/tcpip/network/arp/arp.go index 1a5e7ce..bb09fbd 100644 --- a/tcpip/network/arp/arp.go +++ b/tcpip/network/arp/arp.go @@ -1,7 +1,172 @@ // 主机的链路层寻址是通过 arp 表来实现的 package arp -const ( - ProtocolName = "arp" - ProtocolNumber = "arp" +import ( + "log" + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" ) + +const ( + ProtocolName = "arp" + ProtocolNumber = header.ARPProtocolNumber + ProtocolAddress = tcpip.Address("arp") +) + +// arp endpoint 一个网络层的实现 Implement stack.NetworkEndpoint +type endpoint struct { + nicid tcpip.NICID // arp报文使用的网卡 + addr tcpip.Address // 网络层地址 + linkEP stack.LinkEndpoint // MAC + linkAddrCache stack.LinkAddressCache // 链路高速缓存 +} + +func (e *endpoint) DefaultTTL() uint8 { + return 0 +} + +func (e *endpoint) MTU() uint32 { + lmtu := e.linkEP.MTU() + return lmtu - uint32(e.MaxHeaderLength()) +} + +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &stack.NetworkEndpointID{LocalAddress: ProtocolAddress} +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.ARPSize +} + +// arp不支持写包 +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// arp数据包的处理,包括arp请求和响应 +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { + v := vv.First() + h := header.ARP(v) + if !h.IsValid() { + return + } + + // 判断操作码类型 + switch h.Op() { + case header.ARPRequest: + // 如果是ARP请求 + localAddr := tcpip.Address(h.ProtocolAddressTarget()) + if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 { + return // 无效的ARP请求 + } + + // arp报文所在的网卡绑定了这个地址 + hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) // 以太 + ARP + pkt := header.ARP(hdr.Prepend(header.ARPSize)) // 取出 ARP + pkt.SetIPv4OverEthernet() + pkt.SetOp(header.ARPReply) + copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) // 写入本机MAC作为响应 NOTE + // 倒置目标与源 作为回应 + copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget()) + copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender()) + log.Println("处理注入的ARP请求 这里将返回一个ARP报文作为响应") + e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) // 往链路层写回消息 + // 注意这里的 fallthrough 表示需要继续执行下面分支的代码 + // 当收到 arp 请求需要添加到链路地址缓存中 + fallthrough // also fill the cache from requests + case header.ARPReply: + // 这里记录ip和mac对应关系,也就是arp表 + addr := tcpip.Address(h.ProtocolAddressSender()) + linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) + e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr) + default: + panic(tcpip.ErrUnknownProtocol) + } +} + +func (e *endpoint) Close() {} + +// 实现了 stack.NetworkProtocol 和 stack.LinkAddressResolver 两个接口 +type protocol struct{} + +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, + dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + if addr != ProtocolAddress { + return nil, tcpip.ErrBadLocalAddress + } + return &endpoint{ + nicid: nicid, + addr: addr, + linkEP: linkEP, + linkAddrCache: linkAddrCache, + }, nil +} + +func (p *protocol) MinimumPacketSize() int { + return header.ARPSize +} + +func (p *protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + h := header.ARP(v) + return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress +} + +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// LinkAddressProtocol implements stack.LinkAddressResolver. +func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return header.IPv4ProtocolNumber +} + +// LinkAddressRequest implements stack.LinkAddressResolver. +func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error { + r := &stack.Route{ + RemoteLinkAddress: broadcastMAC, + } + + hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize) + h := header.ARP(hdr.Prepend(header.ARPSize)) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) + copy(h.HardwareAddressSender(), linkEP.LinkAddress()) + copy(h.ProtocolAddressSender(), localAddr) + copy(h.ProtocolAddressTarget(), addr) + + return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) +} + +// ResolveStaticAddress implements stack.LinkAddressResolver. +func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "\xff\xff\xff\xff" { + return broadcastMAC, true + } + return "", false +} + +var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff}) + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/tcpip/network/arp/arp_test.go b/tcpip/network/arp/arp_test.go index 625a510..71ec148 100644 --- a/tcpip/network/arp/arp_test.go +++ b/tcpip/network/arp/arp_test.go @@ -1 +1,134 @@ package arp_test + +import ( + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/link/channel" + "netstack/tcpip/network/arp" + "netstack/tcpip/network/ipv4" + "netstack/tcpip/stack" + "testing" + "time" +) + +const ( + stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") // 0a:0a:0b:0b:0c:0c + stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1 + stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") // 10.0.0.2 + stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") // 10.0.0.3 +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack +} + +func newTestContext(t *testing.T) *testContext { + s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, nil, stack.Options{}) + + const defaultMTU = 65536 + id, linkEP := channel.New(256, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil { + t.Fatalf("AddAddress for ipv4 failed: %v", err) + } + if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil { + t.Fatalf("AddAddress for arp failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{ + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }}) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + close(c.linkEP.C) +} + +func TestArpBase(t *testing.T) { + c := newTestContext(t) + defer c.cleanup() + + const senderMAC = "\x01\x02\x03\x04\x05\x06" + const senderIPv4 = "\x0a\x00\x00\x02" + + v := make(buffer.View, header.ARPSize) + h := header.ARP(v) + h.SetIPv4OverEthernet() + h.SetOp(header.ARPRequest) // 一个ARP请求 + copy(h.HardwareAddressSender(), senderMAC) // Local MAC + copy(h.ProtocolAddressSender(), senderIPv4) // Local IP + + inject := func(addr tcpip.Address) { + copy(h.ProtocolAddressTarget(), addr) + c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) // 往链路层注入一个arp报文 链路层将会自动分发它 + } + + inject(stackAddr1) // target IP 10.0.0.1 + select { + case pkt := <-c.linkEP.C: + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 { + t.Errorf("stackAddr1: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("Case #1 Time Out\n") + } + + inject(stackAddr2) + select { + case pkt := <-c.linkEP.C: + if pkt.Proto != arp.ProtocolNumber { + t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto) + } + rep := header.ARP(pkt.Header) + if !rep.IsValid() { + t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header)) + } + if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 { + t.Errorf("stackAddr2: expected sender to be set") + } + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got) + } + + case <-time.After(100 * time.Millisecond): + t.Fatalf("Case #2 Time Out\n") + } + + inject(stackAddrBad) + select { + case pkt := <-c.linkEP.C: + t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto) + case <-time.After(100 * time.Millisecond): + // Sleep tests are gross, but this will only potentially flake + // if there's a bug. If there is no bug this will reliably + // succeed. + } +} diff --git a/tcpip/network/ipv4/ipv4.go b/tcpip/network/ipv4/ipv4.go new file mode 100644 index 0000000..913c73a --- /dev/null +++ b/tcpip/network/ipv4/ipv4.go @@ -0,0 +1,153 @@ +package ipv4 + +import ( + "netstack/tcpip" + "netstack/tcpip/buffer" + "netstack/tcpip/header" + "netstack/tcpip/stack" +) + +const ( + // ProtocolName is the string representation of the ipv4 protocol name. + ProtocolName = "ipv4" + + // ProtocolNumber is the ipv4 protocol number. + ProtocolNumber = header.IPv4ProtocolNumber + + // maxTotalSize is maximum size that can be encoded in the 16-bit + // TotalLength field of the ipv4 header. + maxTotalSize = 0xffff + + // buckets is the number of identifier buckets. + buckets = 2048 +) + +// IPv4 实现 +type endpoint struct { + // 网卡id + nicid tcpip.NICID + // 表示该endpoint的id,也是ip地址 + id stack.NetworkEndpointID + // 链路端的表示 + linkEP stack.LinkEndpoint + // TODO 需要添加 +} + +// DefaultTTL is the default time-to-live value for this endpoint. +// 默认的TTL值,TTL每经过路由转发一次就会减1 +func (e *endpoint) DefaultTTL() uint8 { + return 255 +} + +// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus +// the network layer max header length. +// 获取去除ipv4头部后的最大报文长度 +func (e *endpoint) MTU() uint32 { + return calculateMTU(e.linkEP.MTU()) +} + +// Capabilities implements stack.NetworkEndpoint.Capabilities. +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return e.linkEP.Capabilities() +} + +// NICID returns the ID of the NIC this endpoint belongs to. +func (e *endpoint) NICID() tcpip.NICID { + return e.nicid +} + +// ID returns the ipv4 endpoint ID. +// 获取该网络层端的id,也就是ip地址 +func (e *endpoint) ID() *stack.NetworkEndpointID { + return &e.id +} + +// MaxHeaderLength returns the maximum length needed by ipv4 headers (and +// underlying protocols). +// 链路层和网络层的头部长度 +func (e *endpoint) MaxHeaderLength() uint16 { + return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize +} + +// WritePacket writes a packet to the given destination address and protocol. +// 将传输层的数据封装加上IP头,并调用网卡的写入接口,写入IP报文 +func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + return nil +} + +// HandlePacket is called by the link layer when new ipv4 packets arrive for +// this endpoint. +// 收到ip包的处理 +func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { +} + +// Close cleans up resources associated with the endpoint. +func (e *endpoint) Close() { +} + +// 实现NetworkProtocol接口 +type protocol struct{} + +// NewEndpoint creates a new ipv4 endpoint. +// 根据参数,新建一个ipv4端 +func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, + dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + e := &endpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{LocalAddress: addr}, + linkEP: linkEP, + } + + return e, nil +} + +// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported +// only for tests that short-circuit the stack. Regular use of the protocol is +// done via the stack, which gets a protocol descriptor from the init() function +// below. +func NewProtocol() stack.NetworkProtocol { + return &protocol{} +} + +// Number returns the ipv4 protocol number. +func (p *protocol) Number() tcpip.NetworkProtocolNumber { + return ProtocolNumber +} + +// MinimumPacketSize returns the minimum valid ipv4 packet size. +func (p *protocol) MinimumPacketSize() int { + return header.IPv4MinimumSize +} + +// ParseAddresses implements NetworkProtocol.ParseAddresses. +func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + //h := header.IPv4(v) + //return h.SourceAddress(), h.DestinationAddress() + return "", "" +} + +// SetOption implements NetworkProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements NetworkProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// calculateMTU calculates the network-layer payload MTU based on the link-layer +// payload mtu. +func calculateMTU(mtu uint32) uint32 { + if mtu > maxTotalSize { + mtu = maxTotalSize + } + return mtu - header.IPv4MinimumSize +} + +func init() { + stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol { + return &protocol{} + }) +} diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index a78aec8..81c133e 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -44,7 +44,7 @@ type NIC struct { demux *transportDemuxer mu sync.RWMutex - spoofing bool + spoofing bool // 欺骗 promiscuous bool // 混杂模式 primary map[tcpip.NetworkProtocolNumber]*ilist.List // 网络协议号:网络端实现 // 网络层端的记录 IP:网络端实现 @@ -95,13 +95,13 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. log.Println("添加失败") return nil, tcpip.ErrUnknownProtocol } - log.Printf("基于[%d]协议 为 #%d 网卡 添加IP: %s\n", netProto.Number(), n.id, addr.String()) // 比如netProto是ipv4 会调用ipv4.NewEndpoint 新建一个网络端 ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) if err != nil { return nil, err } + log.Printf("基于[%d]协议 为 #%d 网卡 添加网络层实现 并绑定地址到: %s\n", netProto.Number(), n.id, ep.ID().LocalAddress) // 获取网络层端的id 其实就是ip地址 id := *ep.ID() @@ -143,7 +143,6 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. case FirstPrimaryEndpoint: l.PushFront(ref) } - log.Printf("Network Info: %v \n", ref.ep.ID()) return ref, nil } @@ -223,7 +222,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A n.mu.RLock() ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { + if ref != nil && !ref.tryIncRef() { // 尝试去使用这个网络端实现 ref = nil } spoofing := n.spoofing @@ -309,7 +308,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r n.mu.RLock() if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - log.Println("找到了目标地址: ", id) + log.Println("找到了目标网络层实现: ", id.LocalAddress) n.mu.RUnlock() return ref } @@ -366,7 +365,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin return } src, dst := netProto.ParseAddresses(vv.First()) - log.Printf("设备[%s]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView()) + log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView()) // 根据网络协议和数据包的目的地址,找到网络端 // 然后将数据包分发给网络层 if ref := n.getRef(protocol, dst); ref != nil { diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index df7349f..01441f6 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -6,6 +6,16 @@ import ( "netstack/tcpip" "netstack/tcpip/ports" "sync" + "time" +) + +const ( + // ageLimit is set to the same cache stale time used in Linux. + ageLimit = 1 * time.Minute + // resolutionTimeout is set to the same ARP timeout used in Linux. + resolutionTimeout = 1 * time.Second + // resolutionAttempts is set to the same ARP retries used in Linux. + resolutionAttempts = 3 ) // TODO 需要解读 @@ -72,7 +82,7 @@ func New(network []string, transport []string, opts Options) *Stack { networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), - //linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), //PortManager: ports.NewPortManager(), clock: clock, stats: opts.Stats.FillIn(), @@ -257,19 +267,66 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, return Route{}, tcpip.ErrNoRoute } +// ===============本机链路层缓存实现================== +// 检查本地是否绑定过该网络层地址 func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { + s.mu.RLock() + defer s.mu.RUnlock() + + if nicid != 0 { + nic := s.nics[nicid] // 先拿到网卡 + if nic == nil { + return 0 + } + + ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) // 看看这张网卡是否绑定过这个地址 + if ref == nil { + return 0 + } + + ref.decRef() // 这个网络端实现使用结束 释放对它的占用 + + return nic.id + } + // Go through all the NICs. + for _, nic := range s.nics { + ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) + if ref != nil { + ref.decRef() + return nic.id + } + } return 0 } func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { - + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + s.linkAddrCache.add(fullAddr, linkAddr) } func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { - return "", nil, nil + s.mu.RLock() + // 获取网卡对象 + nic := s.nics[nicid] + if nic == nil { + s.mu.RUnlock() + return "", nil, tcpip.ErrUnknownNICID + } + s.mu.RUnlock() + + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + // 根据网络层协议号找到对应的地址解析协议 + linkRes := s.linkAddrResolvers[protocol] + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, w) } func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + s.mu.RLock() + defer s.mu.RUnlock() + if nic := s.nics[nicid]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + s.linkAddrCache.removeWaker(fullAddr, waker) + } } diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index c473557..df7d2ce 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -160,6 +160,14 @@ func (s *Subnet) Mask() AddressMask { // 它通常是一个 6 字节的 MAC 地址。 type LinkAddress string // MAC地址 +func (l LinkAddress) String() string { + if len(l) == 6 { + return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", l[0], l[1], l[2], l[3], l[4], l[5]) + } else { + return string(l) + } +} + type LinkEndpointID uint64 type TransportProtocolNumber uint32 @@ -249,7 +257,8 @@ func fillIn(v reflect.Value) { v := v.Field(i) switch v.Kind() { case reflect.Ptr: - if s, ok := v.Addr().Interface().(**StatCounter); ok { + x := v.Addr().Interface() + if s, ok := x.(**StatCounter); ok { if *s == nil { *s = &StatCounter{} } @@ -307,6 +316,6 @@ func (a Address) String() string { } return b.String() default: - return fmt.Sprintf("%x", []byte(a)) + return fmt.Sprintf("%s", string(a)) } }