diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index 3313932..3ae7dbc 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -45,8 +45,8 @@ type NIC struct { mu sync.RWMutex spoofing bool - promiscuous bool // 混杂模式 - primary map[tcpip.NetworkProtocolNumber]*ilist.List + promiscuous bool // 混杂模式 + primary map[tcpip.NetworkProtocolNumber]*ilist.List // 网络协议号:网络端实现 // 网络层端的记录 IP:网络端实现 endpoints map[NetworkEndpointID]*referencedNetworkEndpoint // 子网的记录 @@ -189,6 +189,33 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } +// primaryEndpoint returns the primary endpoint of n for the given network +// protocol. +// 根据网络层协议号找到对应的网络层端 +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + list := n.primary[protocol] + if list == nil { + return nil + } + + for e := list.Front(); e != nil; e = e.Next() { + r := e.(*referencedNetworkEndpoint) + // TODO: allow broadcast address when SO_BROADCAST is set. + switch r.ep.ID().LocalAddress { + case header.IPv4Broadcast, header.IPv4Any: + continue + } + if r.tryIncRef() { + return r + } + } + + return nil +} + // 根据address参数找到对应的网络层端 func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { @@ -350,6 +377,10 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin n.stack.stats.IP.InvalidAddressesReceived.Increment() } +func (n *NIC) ID() tcpip.NICID { + return n.id +} + // 网络端引用 type referencedNetworkEndpoint struct { ilist.Entry diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go index a5bf7f8..d75c854 100644 --- a/tcpip/stack/route.go +++ b/tcpip/stack/route.go @@ -1,6 +1,10 @@ package stack -import "netstack/tcpip" +import ( + "netstack/sleep" + "netstack/tcpip" + "netstack/tcpip/buffer" +) // 贯穿整个协议栈的路由,也就是在链路层和网络层都可以路由 // 如果目标地址是链路层地址,那么在链路层路由, @@ -37,3 +41,73 @@ func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip ref: ref, } } + +// NICID returns the id of the NIC from which this route originates. +func (r *Route) NICID() tcpip.NICID { + return r.ref.ep.NICID() +} + +// MaxHeaderLength forwards the call to the network endpoint's implementation. +func (r *Route) MaxHeaderLength() uint16 { + return r.ref.ep.MaxHeaderLength() +} + +// Stats returns a mutable copy of current stats. +func (r *Route) Stats() tcpip.Stats { + return r.ref.nic.stack.Stats() +} + +// Capabilities returns the link-layer capabilities of the route. +func (r *Route) Capabilities() LinkEndpointCapabilities { + return r.ref.ep.Capabilities() +} + +// RemoveWaker removes a waker that has been added in Resolve(). +func (r *Route) RemoveWaker(waker *sleep.Waker) { + nextAddr := r.NextHop + if nextAddr == "" { + nextAddr = r.RemoteAddress + } + r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) +} + +// IsResolutionRequired returns true if Resolve() must be called to resolve +// the link address before the this route can be written to. +func (r *Route) IsResolutionRequired() bool { + return r.ref.linkCache != nil && r.RemoteLinkAddress == "" +} + +// WritePacket writes the packet through the given route. +func (r *Route) WritePacket(hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + err := r.ref.ep.WritePacket(r, hdr, payload, protocol, ttl) + if err == tcpip.ErrNoRoute { + r.Stats().IP.OutgoingPacketErrors.Increment() + } + return err +} + +// DefaultTTL returns the default TTL of the underlying network endpoint. +func (r *Route) DefaultTTL() uint8 { + return r.ref.ep.DefaultTTL() +} + +// MTU returns the MTU of the underlying network endpoint. +func (r *Route) MTU() uint32 { + return r.ref.ep.MTU() +} + +// Release frees all resources associated with the route. +func (r *Route) Release() { + if r.ref != nil { + r.ref.decRef() + r.ref = nil + } +} + +// Clone Clone a route such that the original one can be released and the new +// one will remain valid. +func (r *Route) Clone() Route { + r.ref.incRef() + return *r +} diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index ea319f5..df7349f 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -1,6 +1,7 @@ package stack import ( + "log" "netstack/sleep" "netstack/tcpip" "netstack/tcpip/ports" @@ -93,6 +94,42 @@ func New(network []string, transport []string, opts Options) *Stack { return s } +func (s *Stack) Stats() tcpip.Stats { + return s.stats +} + +// SetForwarding enables or disables the packet forwarding between NICs. +func (s *Stack) SetForwarding(enable bool) { + // TODO: Expose via /proc/sys/net/ipv4/ip_forward. + s.mu.Lock() + s.forwarding = enable + s.mu.Unlock() +} + +// Forwarding returns if the packet forwarding between NICs is enabled. +func (s *Stack) Forwarding() bool { + // TODO: Expose via /proc/sys/net/ipv4/ip_forward. + s.mu.RLock() + defer s.mu.RUnlock() + return s.forwarding +} + +// SetRouteTable assigns the route table to be used by this stack. It +// specifies which NIC to use for given destination address ranges. +func (s *Stack) SetRouteTable(table []tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + + s.routeTable = table +} + +// GetRouteTable returns the route table which is currently in use. +func (s *Stack) GetRouteTable() []tcpip.Route { + s.mu.Lock() + defer s.mu.Unlock() + return append([]tcpip.Route(nil), s.routeTable...) +} + func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { return s.createNIC(id, "", linkEP, true) } @@ -178,6 +215,48 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi return false, tcpip.ErrUnknownNICID } +// 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息 +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, + netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + for i := range s.routeTable { + //if (id != 0 && id != s.routeTable[i].NIC) || + // (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) { + // continue + //} + + nic := s.nics[s.routeTable[i].NIC] + if nic == nil { + continue + } + + var ref *referencedNetworkEndpoint + if len(localAddr) != 0 { + ref = nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) + } else { + ref = nic.primaryEndpoint(netProto) + } + if ref == nil { + continue + } + + if len(remoteAddr) == 0 { + // If no remote address was provided, then the route + // provided will refer to the link local address. + remoteAddr = ref.ep.ID().LocalAddress // 发回自己? TODO + } + + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, nic.linkEP.LinkAddress(), ref) + r.NextHop = s.routeTable[i].Gateway + log.Println(r.LocalLinkAddress, r.LocalAddress, r.RemoteLinkAddress, r.RemoteAddress, r.NextHop) + return r, nil + } + + return Route{}, tcpip.ErrNoRoute +} + func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { return 0 } diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go index c437f4d..5e44c2f 100644 --- a/tcpip/stack/stack_test.go +++ b/tcpip/stack/stack_test.go @@ -39,7 +39,13 @@ func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { - return nil + b := hdr.Prepend(fakeNetHeaderLen) + copy(b[:4], []byte(r.RemoteAddress)) + copy(b[4:8], []byte(f.id.LocalAddress)) + b[8] = byte(protocol) + log.Println("写入网络层数据 下一层去往链路层", b, payload) + + return f.linkEP.WritePacket(r, hdr, payload, 114514) } func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { @@ -110,7 +116,7 @@ func TestStackBase(t *testing.T) { if err := myStack.CreateNIC(2, id2); err != nil { // 将上面的物理设备抽象成我们的网卡对象 panic(err) } - myStack.AddAddress(1, 114514, "\x0a\xff\x01\x02") // 给网卡对象绑定一个IP地址 可以绑定多个 + myStack.AddAddress(2, 114514, "\x0a\xff\x01\x02") // 给网卡对象绑定一个IP地址 可以绑定多个 buf := buffer.NewView(30) for i := range buf { @@ -127,5 +133,28 @@ func TestStackBase(t *testing.T) { buf[6] = '\x01' buf[7] = '\x01' - ep1.Inject(114514, buf.ToVectoriseView()) + myStack.SetRouteTable([]tcpip.Route{ + {"\x01", "\x01", "\x00", 1}, + {"\x00", "\x01", "\x00", 2}, + }) + + sendTo(t, myStack, tcpip.Address("\x0a\xff\x01\x02")) + + //log.Println(ep1.Drain()) + p := <-ep1.C + log.Println(p) +} + +func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) { + r, err := s.FindRoute(0, "", addr, 114514) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + defer r.Release() + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) + if err := r.WritePacket(hdr, buffer.VectorisedView{}, 10086, 123); err != nil { + t.Errorf("WritePacket failed: %v", err) + return + } } diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index 79ea692..c473557 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -172,7 +172,7 @@ type Route struct { Destination Address // 目标地址 Mask AddressMask // 掩码 Gateway Address // 网关 - MIC NICID // 使用的网卡设备 + NIC NICID // 使用的网卡设备 } // Stats 包含了网络栈的统计信息