diff --git a/ilist/list.go b/ilist/list.go index 651a9a8..88038aa 100644 --- a/ilist/list.go +++ b/ilist/list.go @@ -11,10 +11,10 @@ type Element interface { Linker } -type ElementMapper struct {} +type ElementMapper struct{} func (ElementMapper) linkerFor(elem Element) Linker { - return elem; + return elem } type List struct { @@ -31,6 +31,10 @@ func (l *List) Empty() bool { return l.head == nil } +func (l *List) Front() Element { + return l.head +} + func (l *List) Back() Element { return l.tail } diff --git a/tcpip/header/ipv4.go b/tcpip/header/ipv4.go new file mode 100644 index 0000000..2495c37 --- /dev/null +++ b/tcpip/header/ipv4.go @@ -0,0 +1,30 @@ +package header + +import "netstack/tcpip" + +type IPv4 []byte + +const ( + // IPv4MinimumSize is the minimum size of a valid IPv4 packet. + IPv4MinimumSize = 20 + + // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given + // that there are only 4 bits to represents the header length in 32-bit + // units, the header cannot exceed 15*4 = 60 bytes. + IPv4MaximumHeaderSize = 60 + + // IPv4AddressSize is the size, in bytes, of an IPv4 address. + IPv4AddressSize = 4 + + // IPv4ProtocolNumber is IPv4's network protocol number. + IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 + + // IPv4Version is the version of the ipv4 protocol. + IPv4Version = 4 + + // IPv4Broadcast is the broadcast address of the IPv4 procotol. + IPv4Broadcast tcpip.Address = "\xff\xff\xff\xff" + + // IPv4Any is the non-routable IPv4 "any" meta address. + IPv4Any tcpip.Address = "\x00\x00\x00\x00" +) diff --git a/tcpip/header/ipv6.go b/tcpip/header/ipv6.go new file mode 100644 index 0000000..6687cd1 --- /dev/null +++ b/tcpip/header/ipv6.go @@ -0,0 +1,23 @@ +package header + +import "netstack/tcpip" + +type IPv6 []byte + +const ( + // IPv6MinimumSize is the minimum size of a valid IPv6 packet. + IPv6MinimumSize = 40 + + // IPv6AddressSize is the size, in bytes, of an IPv6 address. + IPv6AddressSize = 16 + + // IPv6ProtocolNumber is IPv6's network protocol number. + IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd + + // IPv6Version is the version of the ipv6 protocol. + IPv6Version = 6 + + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, + // section 5. + IPv6MinimumMTU = 1280 +) diff --git a/tcpip/stack/nic.go b/tcpip/stack/nic.go index 95ffa30..3313932 100644 --- a/tcpip/stack/nic.go +++ b/tcpip/stack/nic.go @@ -5,7 +5,10 @@ import ( "netstack/ilist" "netstack/tcpip" "netstack/tcpip/buffer" + "netstack/tcpip/header" + "strings" "sync" + "sync/atomic" ) // PrimaryEndpointBehavior 是端点首要行为的枚举 @@ -44,7 +47,7 @@ type NIC struct { spoofing bool promiscuous bool // 混杂模式 primary map[tcpip.NetworkProtocolNumber]*ilist.List - // 网络层端的记录 + // 网络层端的记录 IP:网络端实现 endpoints map[NetworkEndpointID]*referencedNetworkEndpoint // 子网的记录 subnets []tcpip.Subnet @@ -67,6 +70,22 @@ func (n *NIC) attachLinkEndpoint() { n.linkEP.Attach(n) } +// setPromiscuousMode enables or disables promiscuous mode. +// 设备网卡为混杂模式 +func (n *NIC) setPromiscuousMode(enable bool) { + n.mu.Lock() + n.promiscuous = enable + n.mu.Unlock() +} + +// 判断网卡是否开启混杂模式 +func (n *NIC) isPromiscuousMode() bool { + n.mu.RLock() + rv := n.promiscuous + n.mu.RUnlock() + return rv +} + // 在NIC上添加addr地址,注册和初始化网络层协议 // 相当于给网卡添加ip地址 func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, @@ -76,9 +95,56 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip. log.Println("添加失败") return nil, tcpip.ErrUnknownProtocol } - log.Println(netProto.Number(), "添加ip", addr.String()) - // TODO 接着这里实现 22/11/24 21:29 - return nil, nil + log.Printf("基于[%d]协议 为 #%d 网卡 添加IP: %s\n", netProto.Number(), n.id, addr.String()) + + // 比如netProto是ipv4 会调用ipv4.NewEndpoint 新建一个网络端 + ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) + if err != nil { + return nil, err + } + + // 获取网络层端的id 其实就是ip地址 + id := *ep.ID() + if ref, ok := n.endpoints[id]; ok { + // 不是替换 且该id已经存在 + if !replace { + return nil, tcpip.ErrDuplicateAddress + } + n.removeEndpointLocked(ref) // 这里被调用的时候已经上过锁了 NOTE + } + + ref := &referencedNetworkEndpoint{ + refs: 1, + ep: ep, + nic: n, + protocol: protocol, + holdsInsertRef: true, + } + + // Set up cache if link address resolution exists for this protocol. + if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if _, ok := n.stack.linkAddrResolvers[protocol]; ok { + ref.linkCache = n.stack + } + } + + // 注册该网络端 + n.endpoints[id] = ref + + l, ok := n.primary[protocol] + if !ok { + l = &ilist.List{} + n.primary[protocol] = l + } + + switch peb { + case CanBePrimaryEndpoint: + l.PushBack(ref) // 目前走这一支 + case FirstPrimaryEndpoint: + l.PushFront(ref) + } + log.Printf("Network Info: %v \n", ref.ep.ID()) + return ref, nil } func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { @@ -94,8 +160,235 @@ func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber, return err } -func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress, +// 删除一个网络端 +func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { + id := *r.ep.ID() + + // Nothing to do if the reference has already been replaced with a + // different one. + if n.endpoints[id] != r { + return + } + + if r.holdsInsertRef { + panic("Reference count dropped to zero before being removed") + } + + delete(n.endpoints, id) + wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front() + if wasInList { + n.primary[r.protocol].Remove(r) + } + + r.ep.Close() +} + +func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { + n.mu.Lock() + n.removeEndpointLocked(r) + n.mu.Unlock() +} + +// 根据address参数找到对应的网络层端 +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, + peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + id := NetworkEndpointID{address} + + n.mu.RLock() + ref := n.endpoints[id] + if ref != nil && !ref.tryIncRef() { + ref = nil + } + spoofing := n.spoofing + n.mu.RUnlock() + + if ref != nil || !spoofing { + return ref + } + + // Try again with the lock in exclusive mode. If we still can't get the + // endpoint, create a new "temporary" endpoint. It will only exist while + // there's a route through it. + n.mu.Lock() + ref = n.endpoints[id] + if ref == nil || !ref.tryIncRef() { + ref, _ = n.addAddressLocked(protocol, address, peb, true) + if ref != nil { + ref.holdsInsertRef = false + } + } + n.mu.Unlock() + return ref +} + +// AddSubnet adds a new subnet to n, so that it starts accepting packets +// targeted at the given address and network protocol. +// AddSubnet向n添加一个新子网,以便它开始接受针对给定地址和网络协议的数据包。 +func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { + n.mu.Lock() + n.subnets = append(n.subnets, subnet) + n.mu.Unlock() +} + +// RemoveSubnet removes the given subnet from n. +// 从n中删除一个子网 +func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { + n.mu.Lock() + + // Use the same underlying array. + tmp := n.subnets[:0] + for _, sub := range n.subnets { + if sub != subnet { + tmp = append(tmp, sub) + } + } + n.subnets = tmp + + n.mu.Unlock() +} + +// ContainsSubnet reports whether this NIC contains the given subnet. +// 判断 subnet 这个子网是否在该网卡下 +func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { + for _, s := range n.Subnets() { + if s == subnet { + return true + } + } + return false +} + +// Subnets returns the Subnets associated with this NIC. +// 获取该网卡的所有子网 +func (n *NIC) Subnets() []tcpip.Subnet { + n.mu.RLock() + defer n.mu.RUnlock() + sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + for nid := range n.endpoints { + sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) + if err != nil { + // This should never happen as the mask has been carefully crafted to + // match the address. + panic("Invalid endpoint subnet: " + err.Error()) + } + sns = append(sns, sn) + } + return append(sns, n.subnets...) +} + +// 根据协议类型和目标地址,找出关联的Endpoint +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + id := NetworkEndpointID{dst} + + n.mu.RLock() + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + log.Println("找到了目标地址: ", id) + n.mu.RUnlock() + return ref + } + + promiscuous := n.promiscuous + // Check if the packet is for a subnet this NIC cares about. + if !promiscuous { + for _, sn := range n.subnets { + if sn.Contains(dst) { + promiscuous = true + break + } + } + } + n.mu.RUnlock() + if promiscuous { + // Try again with the lock in exclusive mode. If we still can't + // get the endpoint, create a new "temporary" one. It will only + // exist while there's a route through it. + n.mu.Lock() + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.Unlock() + return ref + } + ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true) + n.mu.Unlock() + if err == nil { + ref.holdsInsertRef = false + return ref + } + } + + return nil +} + +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) { // TODO 需要完成逻辑 - log.Println(vv.ToView()) + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.stack.stats.UnknownProtocolRcvdPackets.Increment() + return + } + + if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber { + n.stack.stats.IP.PacketsReceived.Increment() + } + + if len(vv.First()) < netProto.MinimumPacketSize() { + n.stack.stats.MalformedRcvdPackets.Increment() + return + } + src, dst := netProto.ParseAddresses(vv.First()) + log.Printf("设备[%s]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView()) + // 根据网络协议和数据包的目的地址,找到网络端 + // 然后将数据包分发给网络层 + if ref := n.getRef(protocol, dst); ref != nil { + r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref) + r.RemoteLinkAddress = remoteLinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() + + return + } + n.stack.stats.IP.InvalidAddressesReceived.Increment() +} + +// 网络端引用 +type referencedNetworkEndpoint struct { + ilist.Entry + refs int32 // 引用计数 + ep NetworkEndpoint // 网络端实现 + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkCache is set if link address resolution is enabled for this + // protocol. Set to nil otherwise. + linkCache LinkAddressCache + linkAddrCache + + // holdsInsertRef is protected by the NIC's mutex. It indicates whether + // the reference count is biased by 1 due to the insertion of the + // endpoint. It is reset to false when RemoveAddress is called on the + // NIC. + holdsInsertRef bool +} + +func (r *referencedNetworkEndpoint) decRef() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpoint(r) + } +} + +func (r *referencedNetworkEndpoint) incRef() { + atomic.AddInt32(&r.refs, 1) +} + +func (r *referencedNetworkEndpoint) tryIncRef() bool { + for { + v := atomic.LoadInt32(&r.refs) + if v == 0 { + return false + } + + if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { + return true + } + } } diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go index e5ae3ef..dff37e8 100644 --- a/tcpip/stack/registration.go +++ b/tcpip/stack/registration.go @@ -1,8 +1,6 @@ package stack import ( - "log" - "netstack/ilist" "netstack/sleep" "netstack/tcpip" "netstack/tcpip/buffer" @@ -108,13 +106,70 @@ var ( // ==============================网络层相关============================== type NetworkProtocol interface { + // 网络协议版本号 Number() tcpip.NetworkProtocolNumber - // todo 需要添加 + + // MinimumPacketSize returns the minimum valid packet size of this + // network protocol. The stack automatically drops any packets smaller + // than this targeted at this protocol. + MinimumPacketSize() int + + // ParsePorts returns the source and destination addresses stored in a + // packet of this protocol. + ParseAddresses(v buffer.View) (src, dst tcpip.Address) + + // 新建一个网络终端 比如 ipv4 或者 ipv6 的一个实现 + NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, + dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + + // SetOption allows enabling/disabling protocol specific features. + // SetOption returns an error if the option is not supported or the + // provided option value is invalid. + SetOption(option interface{}) *tcpip.Error + + // Option allows retrieving protocol specific option values. + // Option returns an error if the option is not supported or the + // provided option value is invalid. + Option(option interface{}) *tcpip.Error } // NetworkEndpoint是需要由网络层协议(例如,ipv4,ipv6)的端点实现的接口 type NetworkEndpoint interface { - // TODO 需要添加 + // DefaultTTL is the default time-to-live value (or hop limit, in ipv6) + // for this endpoint. + DefaultTTL() uint8 + + // MTU is the maximum transmission unit for this endpoint. This is + // generally calculated as the MTU of the underlying data link endpoint + // minus the network endpoint max header length. + MTU() uint32 + + // Capabilities returns the set of capabilities supported by the + // underlying link-layer endpoint. + Capabilities() LinkEndpointCapabilities + + // MaxHeaderLength returns the maximum size the network (and lower + // level layers combined) headers can have. Higher levels use this + // information to reserve space in the front of the packets they're + // building. + MaxHeaderLength() uint16 + + // WritePacket writes a packet to the given destination address and + // protocol. + WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error + + // ID returns the network protocol endpoint ID. + ID() *NetworkEndpointID + + // NICID returns the id of the NIC this endpoint belongs to. + NICID() tcpip.NICID + + // HandlePacket is called by the link layer when new packets arrive to + // this network endpoint. + HandlePacket(r *Route, vv buffer.VectorisedView) + + // Close is called when the endpoint is reomved from a stack. + Close() } type NetworkEndpointID struct { @@ -137,29 +192,16 @@ type TransportEndpoint interface { } // TODO 需要解读 -type referencedNetworkEndpoint struct { - ilist.Entry - refs int32 - ep NetworkEndpoint - nic *NIC - protocol tcpip.NetworkProtocolNumber +type TransportProtocol interface { +} - // linkCache is set if link address resolution is enabled for this - // protocol. Set to nil otherwise. - linkCache LinkAddressCache - linkAddrCache - - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool +// TODO 需要解读 +type TransportDispatcher interface { } // 注册一个新的网络协议工厂 func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { networkProtocols[name] = p - log.Println(networkProtocols) } // 注册一个链路层设备 diff --git a/tcpip/stack/route.go b/tcpip/stack/route.go index de96a7d..a5bf7f8 100644 --- a/tcpip/stack/route.go +++ b/tcpip/stack/route.go @@ -25,3 +25,15 @@ type Route struct { // 相关的网络终端 ref *referencedNetworkEndpoint } + +// 根据参数新建一个路由,并关联一个网络层端 +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, + localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route { + return Route{ + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: localLinkAddr, + RemoteAddress: remoteAddr, + ref: ref, + } +} diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index 7e5f2a1..ea319f5 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -1,7 +1,7 @@ package stack import ( - "log" + "netstack/sleep" "netstack/tcpip" "netstack/tcpip/ports" "sync" @@ -49,12 +49,32 @@ type Stack struct { clock tcpip.Clock } -func New(network []string) *Stack { +// Options contains optional Stack configuration. +type Options struct { + // Clock is an optional clock source used for timestampping packets. + // + // If no Clock is specified, the clock source will be time.Now. + Clock tcpip.Clock + + // Stats are optional statistic counters. + Stats tcpip.Stats +} + +func New(network []string, transport []string, opts Options) *Stack { + clock := opts.Clock + if clock == nil { + clock = &tcpip.StdClock{} + } + s := &Stack{ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), nics: make(map[tcpip.NICID]*NIC), + //linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + //PortManager: ports.NewPortManager(), + clock: clock, + stats: opts.Stats.FillIn(), } // 添加指定的网络端协议 必须已经在init中注册过 @@ -62,7 +82,6 @@ func New(network []string) *Stack { // 先检查这个网络协议是否注册过工厂方法 netProtoFactory, ok := networkProtocols[name] if !ok { - log.Println(name) continue // 没有就略过 } netProto := netProtoFactory() // 制造一个该型号协议的示实例 @@ -119,3 +138,59 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt return nic.AddAddressWithOptions(protocol, addr, peb) } + +// AddSubnet adds a subnet range to the specified NIC. +func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + nic.AddSubnet(protocol, subnet) + return nil + } + + return tcpip.ErrUnknownNICID +} + +// RemoveSubnet removes the subnet range from the specified NIC. +func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + nic.RemoveSubnet(subnet) + return nil + } + + return tcpip.ErrUnknownNICID +} + +// ContainsSubnet reports whether the specified NIC contains the specified +// subnet. +func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic, ok := s.nics[id]; ok { + return nic.ContainsSubnet(subnet), nil + } + + return false, tcpip.ErrUnknownNICID +} + +func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { + return 0 +} + +func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + +} + +func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, + protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { + return "", nil, nil +} + +func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + +} diff --git a/tcpip/stack/stack_test.go b/tcpip/stack/stack_test.go index 82188a6..c437f4d 100644 --- a/tcpip/stack/stack_test.go +++ b/tcpip/stack/stack_test.go @@ -10,16 +10,86 @@ import ( ) const ( - defaultMTU = 65536 + fakeNetHeaderLen = 12 + defaultMTU = 65536 ) -type fakeNetworkProtocol struct { +type fakeNetworkEndpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + proto *fakeNetworkProtocol + dispatcher stack.TransportDispatcher + linkEP stack.LinkEndpoint } +func (f *fakeNetworkEndpoint) DefaultTTL() uint8 { + return 123 +} + +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return f.linkEP.Capabilities() +} + +func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { + return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen +} +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, + protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + return nil +} + +func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { + return &f.id +} + +func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { + return f.nicid +} + +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { + log.Println("执行这个函数 接下来它会去向传输层分发数据") +} + +func (f *fakeNetworkEndpoint) Close() {} + +// dst|src|payload +type fakeNetworkProtocol struct{} + func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { return 114514 } +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, + dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return &fakeNetworkEndpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{addr}, + proto: f, + dispatcher: dispatcher, + linkEP: linkEP, + }, nil +} + +func (f *fakeNetworkProtocol) MinimumPacketSize() int { + return fakeNetHeaderLen +} + +func (f *fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[4:8]), tcpip.Address(v[0:4]) +} + +func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + return nil +} + +func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { + return nil +} + func init() { stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { return &fakeNetworkProtocol{} @@ -28,18 +98,34 @@ func init() { func TestStackBase(t *testing.T) { - myStack := stack.New([]string{"fakeNet"}) - id, ep := channel.New(10, defaultMTU, "") // 这是一个物理设备 - log.Println(id) + myStack := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + id1, ep1 := channel.New(10, defaultMTU, "00:15:5d:26:d7:a1") // 这是一个物理设备 - if err := myStack.CreateNIC(1, id); err != nil { // 将上面的物理设备抽象成我们的网卡对象 + if err := myStack.CreateNIC(1, id1); err != nil { // 将上面的物理设备抽象成我们的网卡对象 panic(err) } - myStack.AddAddress(1, 114514, "\x01") // 给网卡对象绑定一个IP地址 可以绑定多个 + myStack.AddAddress(1, 114514, "\x0a\xff\x01\x01") // 给网卡对象绑定一个IP地址 可以绑定多个 + + id2, _ := channel.New(10, defaultMTU, "50:5B:C2:D0:96:57") // 这是一个物理设备 + if err := myStack.CreateNIC(2, id2); err != nil { // 将上面的物理设备抽象成我们的网卡对象 + panic(err) + } + myStack.AddAddress(1, 114514, "\x0a\xff\x01\x02") // 给网卡对象绑定一个IP地址 可以绑定多个 buf := buffer.NewView(30) for i := range buf { - buf[i] = 1 + buf[i] = 0 } - ep.Inject(114514, buf.ToVectoriseView()) + // dst 10.255.1.2 + buf[0] = '\x0a' + buf[1] = '\xff' + buf[2] = '\x01' + buf[3] = '\x02' + // src 10.255.1.1 + buf[4] = '\x0a' + buf[5] = '\xff' + buf[6] = '\x01' + buf[7] = '\x01' + + ep1.Inject(114514, buf.ToVectoriseView()) } diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index ea406ca..79ea692 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -1,8 +1,11 @@ package tcpip import ( + "errors" "fmt" + "reflect" "strings" + "sync/atomic" ) type Error struct { @@ -55,6 +58,12 @@ var ( ErrNoBufferSpace = &Error{msg: "no buffer space available"} ) +// Errors related to Subnet +var ( + errSubnetLengthMismatch = errors.New("subnet length of address and mask differ") + errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask") +) + // Clock 提供当前的时间戳 type Clock interface { NowNanoseconds() int64 @@ -83,6 +92,70 @@ type Subnet struct { mask AddressMask } +// NewSubnet creates a new Subnet, checking that the address and mask are the same length. +func NewSubnet(a Address, m AddressMask) (Subnet, error) { + if len(a) != len(m) { + return Subnet{}, errSubnetLengthMismatch + } + for i := 0; i < len(a); i++ { + if a[i]&^m[i] != 0 { + return Subnet{}, errSubnetAddressMasked + } + } + return Subnet{a, m}, nil +} + +// Contains returns true iff the address is of the same length and matches the +// subnet address and mask. +func (s *Subnet) Contains(a Address) bool { + if len(a) != len(s.address) { + return false + } + for i := 0; i < len(a); i++ { + if a[i]&s.mask[i] != s.address[i] { + return false + } + } + return true +} + +// ID returns the subnet ID. +func (s *Subnet) ID() Address { + return s.address +} + +// Bits returns the number of ones (network bits) and zeros (host bits) in the +// subnet mask. +func (s *Subnet) Bits() (ones int, zeros int) { + for _, b := range []byte(s.mask) { + for i := uint(0); i < 8; i++ { + if b&(1<= 0; j-- { + if b&(1<