arp基本实现 创建一个网卡对象并绑定到ip1 网卡收到一个arp报文 链路层分发给arp网络端实现 arp到本地缓存中查找 该网卡是否绑定过目标ip地址ip1 没有直接放弃 有就新建一个源与目标逆置并添加了该网卡MAC的arp报文 并包装给链路层

This commit is contained in:
impact-eintr
2022-11-26 18:52:11 +08:00
parent 20b5b3415a
commit d9c0633bf5
12 changed files with 556 additions and 33 deletions

View File

@@ -23,6 +23,7 @@ func (p Prependable) UsedLength() int {
return len(p.buf) - p.usedIdx return len(p.buf) - p.usedIdx
} }
// 从内到外取出报文头的协议
func (p *Prependable) Prepend(size int) []byte { func (p *Prependable) Prepend(size int) []byte {
if size > p.usedIdx { if size > p.usedIdx {
return nil return nil

View File

@@ -1,6 +1,5 @@
package buffer package buffer
type View []byte type View []byte
func NewView(size int) View { func NewView(size int) View {
@@ -21,14 +20,14 @@ func (v *View) CapLength(length int) {
*v = (*v)[:length:length] *v = (*v)[:length:length]
} }
func (v View) ToVectoriseView() VectorisedView { func (v View) ToVectorisedView() VectorisedView {
return NewVectorisedView(len(v), []View{v}) return NewVectorisedView(len(v), []View{v})
} }
// VectorisedView 是使用非连续内存的 View 的矢量化版本 // VectorisedView 是使用非连续内存的 View 的矢量化版本
type VectorisedView struct { type VectorisedView struct {
views []View views []View
size int size int
} }
func NewVectorisedView(size int, views []View) VectorisedView { func NewVectorisedView(size int, views []View) VectorisedView {

View File

@@ -65,6 +65,12 @@ func (a ARP) SetIPv4OverEthernet() {
a[5] = uint8(IPv4AddressSize) a[5] = uint8(IPv4AddressSize)
} }
// HardwareAddressSender从报文中得到arp发送方的硬件地址
func (a ARP) HardwareAddressSender() []byte {
const s = 8
return a[s : s+6]
}
// ProtocolAddressSender从报文中得到arp发送方的协议地址为ipv4地址 // ProtocolAddressSender从报文中得到arp发送方的协议地址为ipv4地址
func (a ARP) ProtocolAddressSender() []byte { func (a ARP) ProtocolAddressSender() []byte {
const s = 8 + 6 // 8 是arp的协议头部 6是本机MAC const s = 8 + 6 // 8 是arp的协议头部 6是本机MAC

View File

@@ -49,6 +49,7 @@ func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv buffer.Vector
// InjectLinkAddr injects an inbound packet with a remote link address. // InjectLinkAddr injects an inbound packet with a remote link address.
func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) { func (e *Endpoint) InjectLinkAddr(protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, vv buffer.VectorisedView) {
// 这里的实现在NIC.go中 由 网卡对象进行数据分发
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil)) e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, "" /* localLinkAddr */, protocol, vv.Clone(nil))
} }

View File

@@ -2,36 +2,36 @@ package fdbased
import ( import (
"fmt" "fmt"
"reflect"
"time"
"math/rand" "math/rand"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/buffer" "netstack/tcpip/buffer"
"netstack/tcpip/header" "netstack/tcpip/header"
"netstack/tcpip/stack" "netstack/tcpip/stack"
"reflect"
"syscall" "syscall"
"testing" "testing"
"time"
) )
const ( const (
mtu = 1500 mtu = 1500
laddr = tcpip.LinkAddress("\x65\x66\x67\x68\x69\x70") laddr = tcpip.LinkAddress("\x65\x66\x67\x68\x69\x70")
raddr = tcpip.LinkAddress("\x71\x72\x73\x74\x75\x76") raddr = tcpip.LinkAddress("\x71\x72\x73\x74\x75\x76")
proto = 10 proto = 10
) )
type packetInfo struct { type packetInfo struct {
raddr tcpip.LinkAddress raddr tcpip.LinkAddress
proto tcpip.NetworkProtocolNumber proto tcpip.NetworkProtocolNumber
contents buffer.View contents buffer.View
} }
type context struct { type context struct {
t *testing.T t *testing.T
fds [2]int fds [2]int
ep stack.LinkEndpoint ep stack.LinkEndpoint
ch chan packetInfo // 信道 ch chan packetInfo // 信道
done chan struct{} // 通知退出 done chan struct{} // 通知退出
} }
func newContext(t *testing.T, opt *Options) *context { func newContext(t *testing.T, opt *Options) *context {
@@ -49,10 +49,10 @@ func newContext(t *testing.T, opt *Options) *context {
ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) // 找到端口实现 ep := stack.FindLinkEndpoint(New(opt)).(*endpoint) // 找到端口实现
c := &context{ c := &context{
t: t, t: t,
fds: fds, fds: fds,
ep:ep, ep: ep,
ch: make(chan packetInfo, 100), ch: make(chan packetInfo, 100),
done: done, done: done,
} }
@@ -79,7 +79,7 @@ func TestFdbased(t *testing.T) {
// Build header // Build header
hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) // 114 hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100) // 114
b := hdr.Prepend(100) // payload b := hdr.Prepend(100) // payload
for i := range b { for i := range b {
b[i] = uint8(rand.Intn(256)) b[i] = uint8(rand.Intn(256))
} }
@@ -91,7 +91,7 @@ func TestFdbased(t *testing.T) {
} }
if err := c.ep.WritePacket(&stack.Route{RemoteLinkAddress: raddr}, hdr, if err := c.ep.WritePacket(&stack.Route{RemoteLinkAddress: raddr}, hdr,
payload.ToVectoriseView(), proto); err != nil { payload.ToVectorisedView(), proto); err != nil {
panic(err) panic(err)
} }

View File

@@ -1,7 +1,172 @@
// 主机的链路层寻址是通过 arp 表来实现的 // 主机的链路层寻址是通过 arp 表来实现的
package arp package arp
const ( import (
ProtocolName = "arp" "log"
ProtocolNumber = "arp" "netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/stack"
) )
const (
ProtocolName = "arp"
ProtocolNumber = header.ARPProtocolNumber
ProtocolAddress = tcpip.Address("arp")
)
// arp endpoint 一个网络层的实现 Implement stack.NetworkEndpoint
type endpoint struct {
nicid tcpip.NICID // arp报文使用的网卡
addr tcpip.Address // 网络层地址
linkEP stack.LinkEndpoint // MAC
linkAddrCache stack.LinkAddressCache // 链路高速缓存
}
func (e *endpoint) DefaultTTL() uint8 {
return 0
}
func (e *endpoint) MTU() uint32 {
lmtu := e.linkEP.MTU()
return lmtu - uint32(e.MaxHeaderLength())
}
func (e *endpoint) NICID() tcpip.NICID {
return e.nicid
}
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &stack.NetworkEndpointID{LocalAddress: ProtocolAddress}
}
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.ARPSize
}
// arp不支持写包
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return tcpip.ErrNotSupported
}
// arp数据包的处理包括arp请求和响应
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
v := vv.First()
h := header.ARP(v)
if !h.IsValid() {
return
}
// 判断操作码类型
switch h.Op() {
case header.ARPRequest:
// 如果是ARP请求
localAddr := tcpip.Address(h.ProtocolAddressTarget())
if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
return // 无效的ARP请求
}
// arp报文所在的网卡绑定了这个地址
hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize) // 以太 + ARP
pkt := header.ARP(hdr.Prepend(header.ARPSize)) // 取出 ARP
pkt.SetIPv4OverEthernet()
pkt.SetOp(header.ARPReply)
copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:]) // 写入本机MAC作为响应 NOTE
// 倒置目标与源 作为回应
copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
log.Println("处理注入的ARP请求 这里将返回一个ARP报文作为响应")
e.linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber) // 往链路层写回消息
// 注意这里的 fallthrough 表示需要继续执行下面分支的代码
// 当收到 arp 请求需要添加到链路地址缓存中
fallthrough // also fill the cache from requests
case header.ARPReply:
// 这里记录ip和mac对应关系也就是arp表
addr := tcpip.Address(h.ProtocolAddressSender())
linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
default:
panic(tcpip.ErrUnknownProtocol)
}
}
func (e *endpoint) Close() {}
// 实现了 stack.NetworkProtocol 和 stack.LinkAddressResolver 两个接口
type protocol struct{}
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
if addr != ProtocolAddress {
return nil, tcpip.ErrBadLocalAddress
}
return &endpoint{
nicid: nicid,
addr: addr,
linkEP: linkEP,
linkAddrCache: linkAddrCache,
}, nil
}
func (p *protocol) MinimumPacketSize() int {
return header.ARPSize
}
func (p *protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
h := header.ARP(v)
return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
}
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// LinkAddressProtocol implements stack.LinkAddressResolver.
func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
return header.IPv4ProtocolNumber
}
// LinkAddressRequest implements stack.LinkAddressResolver.
func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
r := &stack.Route{
RemoteLinkAddress: broadcastMAC,
}
hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
h := header.ARP(hdr.Prepend(header.ARPSize))
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest)
copy(h.HardwareAddressSender(), linkEP.LinkAddress())
copy(h.ProtocolAddressSender(), localAddr)
copy(h.ProtocolAddressTarget(), addr)
return linkEP.WritePacket(r, hdr, buffer.VectorisedView{}, ProtocolNumber)
}
// ResolveStaticAddress implements stack.LinkAddressResolver.
func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
if addr == "\xff\xff\xff\xff" {
return broadcastMAC, true
}
return "", false
}
var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
func init() {
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}

View File

@@ -1 +1,134 @@
package arp_test package arp_test
import (
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/link/channel"
"netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4"
"netstack/tcpip/stack"
"testing"
"time"
)
const (
stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c") // 0a:0a:0b:0b:0c:0c
stackAddr1 = tcpip.Address("\x0a\x00\x00\x01") // 10.0.0.1
stackAddr2 = tcpip.Address("\x0a\x00\x00\x02") // 10.0.0.2
stackAddrBad = tcpip.Address("\x0a\x00\x00\x03") // 10.0.0.3
)
type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
}
func newTestContext(t *testing.T) *testContext {
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, nil, stack.Options{})
const defaultMTU = 65536
id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
if err := s.CreateNIC(1, id); err != nil {
t.Fatalf("CreateNIC failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
t.Fatalf("AddAddress for ipv4 failed: %v", err)
}
if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
t.Fatalf("AddAddress for arp failed: %v", err)
}
s.SetRouteTable([]tcpip.Route{{
Destination: "\x00\x00\x00\x00",
Mask: "\x00\x00\x00\x00",
Gateway: "",
NIC: 1,
}})
return &testContext{
t: t,
s: s,
linkEP: linkEP,
}
}
func (c *testContext) cleanup() {
close(c.linkEP.C)
}
func TestArpBase(t *testing.T) {
c := newTestContext(t)
defer c.cleanup()
const senderMAC = "\x01\x02\x03\x04\x05\x06"
const senderIPv4 = "\x0a\x00\x00\x02"
v := make(buffer.View, header.ARPSize)
h := header.ARP(v)
h.SetIPv4OverEthernet()
h.SetOp(header.ARPRequest) // 一个ARP请求
copy(h.HardwareAddressSender(), senderMAC) // Local MAC
copy(h.ProtocolAddressSender(), senderIPv4) // Local IP
inject := func(addr tcpip.Address) {
copy(h.ProtocolAddressTarget(), addr)
c.linkEP.Inject(arp.ProtocolNumber, v.ToVectorisedView()) // 往链路层注入一个arp报文 链路层将会自动分发它
}
inject(stackAddr1) // target IP 10.0.0.1
select {
case pkt := <-c.linkEP.C:
if pkt.Proto != arp.ProtocolNumber {
t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
}
rep := header.ARP(pkt.Header)
if !rep.IsValid() {
t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
}
if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
t.Errorf("stackAddr1: expected sender to be set")
}
if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("Case #1 Time Out\n")
}
inject(stackAddr2)
select {
case pkt := <-c.linkEP.C:
if pkt.Proto != arp.ProtocolNumber {
t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
}
rep := header.ARP(pkt.Header)
if !rep.IsValid() {
t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
}
if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
t.Errorf("stackAddr2: expected sender to be set")
}
if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
}
case <-time.After(100 * time.Millisecond):
t.Fatalf("Case #2 Time Out\n")
}
inject(stackAddrBad)
select {
case pkt := <-c.linkEP.C:
t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
case <-time.After(100 * time.Millisecond):
// Sleep tests are gross, but this will only potentially flake
// if there's a bug. If there is no bug this will reliably
// succeed.
}
}

153
tcpip/network/ipv4/ipv4.go Normal file
View File

@@ -0,0 +1,153 @@
package ipv4
import (
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"netstack/tcpip/stack"
)
const (
// ProtocolName is the string representation of the ipv4 protocol name.
ProtocolName = "ipv4"
// ProtocolNumber is the ipv4 protocol number.
ProtocolNumber = header.IPv4ProtocolNumber
// maxTotalSize is maximum size that can be encoded in the 16-bit
// TotalLength field of the ipv4 header.
maxTotalSize = 0xffff
// buckets is the number of identifier buckets.
buckets = 2048
)
// IPv4 实现
type endpoint struct {
// 网卡id
nicid tcpip.NICID
// 表示该endpoint的id也是ip地址
id stack.NetworkEndpointID
// 链路端的表示
linkEP stack.LinkEndpoint
// TODO 需要添加
}
// DefaultTTL is the default time-to-live value for this endpoint.
// 默认的TTL值TTL每经过路由转发一次就会减1
func (e *endpoint) DefaultTTL() uint8 {
return 255
}
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus
// the network layer max header length.
// 获取去除ipv4头部后的最大报文长度
func (e *endpoint) MTU() uint32 {
return calculateMTU(e.linkEP.MTU())
}
// Capabilities implements stack.NetworkEndpoint.Capabilities.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
return e.linkEP.Capabilities()
}
// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
return e.nicid
}
// ID returns the ipv4 endpoint ID.
// 获取该网络层端的id也就是ip地址
func (e *endpoint) ID() *stack.NetworkEndpointID {
return &e.id
}
// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
// 链路层和网络层的头部长度
func (e *endpoint) MaxHeaderLength() uint16 {
return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
}
// WritePacket writes a packet to the given destination address and protocol.
// 将传输层的数据封装加上IP头并调用网卡的写入接口写入IP报文
func (e *endpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView,
protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return nil
}
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint.
// 收到ip包的处理
func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
}
// Close cleans up resources associated with the endpoint.
func (e *endpoint) Close() {
}
// 实现NetworkProtocol接口
type protocol struct{}
// NewEndpoint creates a new ipv4 endpoint.
// 根据参数新建一个ipv4端
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
e := &endpoint{
nicid: nicid,
id: stack.NetworkEndpointID{LocalAddress: addr},
linkEP: linkEP,
}
return e, nil
}
// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
// only for tests that short-circuit the stack. Regular use of the protocol is
// done via the stack, which gets a protocol descriptor from the init() function
// below.
func NewProtocol() stack.NetworkProtocol {
return &protocol{}
}
// Number returns the ipv4 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
return ProtocolNumber
}
// MinimumPacketSize returns the minimum valid ipv4 packet size.
func (p *protocol) MinimumPacketSize() int {
return header.IPv4MinimumSize
}
// ParseAddresses implements NetworkProtocol.ParseAddresses.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
//h := header.IPv4(v)
//return h.SourceAddress(), h.DestinationAddress()
return "", ""
}
// SetOption implements NetworkProtocol.SetOption.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// Option implements NetworkProtocol.Option.
func (p *protocol) Option(option interface{}) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
// calculateMTU calculates the network-layer payload MTU based on the link-layer
// payload mtu.
func calculateMTU(mtu uint32) uint32 {
if mtu > maxTotalSize {
mtu = maxTotalSize
}
return mtu - header.IPv4MinimumSize
}
func init() {
stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
return &protocol{}
})
}

View File

@@ -44,7 +44,7 @@ type NIC struct {
demux *transportDemuxer demux *transportDemuxer
mu sync.RWMutex mu sync.RWMutex
spoofing bool spoofing bool // 欺骗
promiscuous bool // 混杂模式 promiscuous bool // 混杂模式
primary map[tcpip.NetworkProtocolNumber]*ilist.List // 网络协议号:网络端实现 primary map[tcpip.NetworkProtocolNumber]*ilist.List // 网络协议号:网络端实现
// 网络层端的记录 IP:网络端实现 // 网络层端的记录 IP:网络端实现
@@ -95,13 +95,13 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
log.Println("添加失败") log.Println("添加失败")
return nil, tcpip.ErrUnknownProtocol return nil, tcpip.ErrUnknownProtocol
} }
log.Printf("基于[%d]协议 为 #%d 网卡 添加IP: %s\n", netProto.Number(), n.id, addr.String())
// 比如netProto是ipv4 会调用ipv4.NewEndpoint 新建一个网络端 // 比如netProto是ipv4 会调用ipv4.NewEndpoint 新建一个网络端
ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
if err != nil { if err != nil {
return nil, err return nil, err
} }
log.Printf("基于[%d]协议 为 #%d 网卡 添加网络层实现 并绑定地址到: %s\n", netProto.Number(), n.id, ep.ID().LocalAddress)
// 获取网络层端的id 其实就是ip地址 // 获取网络层端的id 其实就是ip地址
id := *ep.ID() id := *ep.ID()
@@ -143,7 +143,6 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
case FirstPrimaryEndpoint: case FirstPrimaryEndpoint:
l.PushFront(ref) l.PushFront(ref)
} }
log.Printf("Network Info: %v \n", ref.ep.ID())
return ref, nil return ref, nil
} }
@@ -223,7 +222,7 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A
n.mu.RLock() n.mu.RLock()
ref := n.endpoints[id] ref := n.endpoints[id]
if ref != nil && !ref.tryIncRef() { if ref != nil && !ref.tryIncRef() { // 尝试去使用这个网络端实现
ref = nil ref = nil
} }
spoofing := n.spoofing spoofing := n.spoofing
@@ -309,7 +308,7 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
n.mu.RLock() n.mu.RLock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
log.Println("找到了目标地址: ", id) log.Println("找到了目标网络层实现: ", id.LocalAddress)
n.mu.RUnlock() n.mu.RUnlock()
return ref return ref
} }
@@ -366,7 +365,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLin
return return
} }
src, dst := netProto.ParseAddresses(vv.First()) src, dst := netProto.ParseAddresses(vv.First())
log.Printf("设备[%s]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView()) log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView())
// 根据网络协议和数据包的目的地址,找到网络端 // 根据网络协议和数据包的目的地址,找到网络端
// 然后将数据包分发给网络层 // 然后将数据包分发给网络层
if ref := n.getRef(protocol, dst); ref != nil { if ref := n.getRef(protocol, dst); ref != nil {

View File

@@ -6,6 +6,16 @@ import (
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/ports" "netstack/tcpip/ports"
"sync" "sync"
"time"
)
const (
// ageLimit is set to the same cache stale time used in Linux.
ageLimit = 1 * time.Minute
// resolutionTimeout is set to the same ARP timeout used in Linux.
resolutionTimeout = 1 * time.Second
// resolutionAttempts is set to the same ARP retries used in Linux.
resolutionAttempts = 3
) )
// TODO 需要解读 // TODO 需要解读
@@ -72,7 +82,7 @@ func New(network []string, transport []string, opts Options) *Stack {
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC), nics: make(map[tcpip.NICID]*NIC),
//linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
//PortManager: ports.NewPortManager(), //PortManager: ports.NewPortManager(),
clock: clock, clock: clock,
stats: opts.Stats.FillIn(), stats: opts.Stats.FillIn(),
@@ -257,19 +267,66 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address,
return Route{}, tcpip.ErrNoRoute return Route{}, tcpip.ErrNoRoute
} }
// ===============本机链路层缓存实现==================
// 检查本地是否绑定过该网络层地址
func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
s.mu.RLock()
defer s.mu.RUnlock()
if nicid != 0 {
nic := s.nics[nicid] // 先拿到网卡
if nic == nil {
return 0
}
ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint) // 看看这张网卡是否绑定过这个地址
if ref == nil {
return 0
}
ref.decRef() // 这个网络端实现使用结束 释放对它的占用
return nic.id
}
// Go through all the NICs.
for _, nic := range s.nics {
ref := nic.findEndpoint(protocol, addr, CanBePrimaryEndpoint)
if ref != nil {
ref.decRef()
return nic.id
}
}
return 0 return 0
} }
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.add(fullAddr, linkAddr)
} }
func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address,
protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
return "", nil, nil s.mu.RLock()
// 获取网卡对象
nic := s.nics[nicid]
if nic == nil {
s.mu.RUnlock()
return "", nil, tcpip.ErrUnknownNICID
}
s.mu.RUnlock()
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
// 根据网络层协议号找到对应的地址解析协议
linkRes := s.linkAddrResolvers[protocol]
return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, w)
} }
func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic := s.nics[nicid]; nic == nil {
fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
s.linkAddrCache.removeWaker(fullAddr, waker)
}
} }

View File

@@ -160,6 +160,14 @@ func (s *Subnet) Mask() AddressMask {
// 它通常是一个 6 字节的 MAC 地址。 // 它通常是一个 6 字节的 MAC 地址。
type LinkAddress string // MAC地址 type LinkAddress string // MAC地址
func (l LinkAddress) String() string {
if len(l) == 6 {
return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", l[0], l[1], l[2], l[3], l[4], l[5])
} else {
return string(l)
}
}
type LinkEndpointID uint64 type LinkEndpointID uint64
type TransportProtocolNumber uint32 type TransportProtocolNumber uint32
@@ -249,7 +257,8 @@ func fillIn(v reflect.Value) {
v := v.Field(i) v := v.Field(i)
switch v.Kind() { switch v.Kind() {
case reflect.Ptr: case reflect.Ptr:
if s, ok := v.Addr().Interface().(**StatCounter); ok { x := v.Addr().Interface()
if s, ok := x.(**StatCounter); ok {
if *s == nil { if *s == nil {
*s = &StatCounter{} *s = &StatCounter{}
} }
@@ -307,6 +316,6 @@ func (a Address) String() string {
} }
return b.String() return b.String()
default: default:
return fmt.Sprintf("%x", []byte(a)) return fmt.Sprintf("%s", string(a))
} }
} }