网卡对象 绑定IP地址 然后向网卡对象写入数据 数据中将包含dst和src

This commit is contained in:
impact-eintr
2022-11-25 19:25:54 +08:00
parent 4589d971fd
commit 2312813aac
10 changed files with 808 additions and 43 deletions

View File

@@ -5,7 +5,10 @@ import (
"netstack/ilist"
"netstack/tcpip"
"netstack/tcpip/buffer"
"netstack/tcpip/header"
"strings"
"sync"
"sync/atomic"
)
// PrimaryEndpointBehavior 是端点首要行为的枚举
@@ -44,7 +47,7 @@ type NIC struct {
spoofing bool
promiscuous bool // 混杂模式
primary map[tcpip.NetworkProtocolNumber]*ilist.List
// 网络层端的记录
// 网络层端的记录 IP:网络端实现
endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
// 子网的记录
subnets []tcpip.Subnet
@@ -67,6 +70,22 @@ func (n *NIC) attachLinkEndpoint() {
n.linkEP.Attach(n)
}
// setPromiscuousMode enables or disables promiscuous mode.
// 设备网卡为混杂模式
func (n *NIC) setPromiscuousMode(enable bool) {
n.mu.Lock()
n.promiscuous = enable
n.mu.Unlock()
}
// 判断网卡是否开启混杂模式
func (n *NIC) isPromiscuousMode() bool {
n.mu.RLock()
rv := n.promiscuous
n.mu.RUnlock()
return rv
}
// 在NIC上添加addr地址注册和初始化网络层协议
// 相当于给网卡添加ip地址
func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address,
@@ -76,9 +95,56 @@ func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.
log.Println("添加失败")
return nil, tcpip.ErrUnknownProtocol
}
log.Println(netProto.Number(), "添加ip", addr.String())
// TODO 接着这里实现 22/11/24 21:29
return nil, nil
log.Printf("基于[%d]协议 为 #%d 网卡 添加IP: %s\n", netProto.Number(), n.id, addr.String())
// 比如netProto是ipv4 会调用ipv4.NewEndpoint 新建一个网络端
ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
if err != nil {
return nil, err
}
// 获取网络层端的id 其实就是ip地址
id := *ep.ID()
if ref, ok := n.endpoints[id]; ok {
// 不是替换 且该id已经存在
if !replace {
return nil, tcpip.ErrDuplicateAddress
}
n.removeEndpointLocked(ref) // 这里被调用的时候已经上过锁了 NOTE
}
ref := &referencedNetworkEndpoint{
refs: 1,
ep: ep,
nic: n,
protocol: protocol,
holdsInsertRef: true,
}
// Set up cache if link address resolution exists for this protocol.
if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
if _, ok := n.stack.linkAddrResolvers[protocol]; ok {
ref.linkCache = n.stack
}
}
// 注册该网络端
n.endpoints[id] = ref
l, ok := n.primary[protocol]
if !ok {
l = &ilist.List{}
n.primary[protocol] = l
}
switch peb {
case CanBePrimaryEndpoint:
l.PushBack(ref) // 目前走这一支
case FirstPrimaryEndpoint:
l.PushFront(ref)
}
log.Printf("Network Info: %v \n", ref.ep.ID())
return ref, nil
}
func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
@@ -94,8 +160,235 @@ func (n *NIC) AddAddressWithOptions(protocol tcpip.NetworkProtocolNumber,
return err
}
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, dstLinkAddr, srcLinkAddr tcpip.LinkAddress,
// 删除一个网络端
func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
id := *r.ep.ID()
// Nothing to do if the reference has already been replaced with a
// different one.
if n.endpoints[id] != r {
return
}
if r.holdsInsertRef {
panic("Reference count dropped to zero before being removed")
}
delete(n.endpoints, id)
wasInList := r.Next() != nil || r.Prev() != nil || r == n.primary[r.protocol].Front()
if wasInList {
n.primary[r.protocol].Remove(r)
}
r.ep.Close()
}
func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
n.mu.Lock()
n.removeEndpointLocked(r)
n.mu.Unlock()
}
// 根据address参数找到对应的网络层端
func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address,
peb PrimaryEndpointBehavior) *referencedNetworkEndpoint {
id := NetworkEndpointID{address}
n.mu.RLock()
ref := n.endpoints[id]
if ref != nil && !ref.tryIncRef() {
ref = nil
}
spoofing := n.spoofing
n.mu.RUnlock()
if ref != nil || !spoofing {
return ref
}
// Try again with the lock in exclusive mode. If we still can't get the
// endpoint, create a new "temporary" endpoint. It will only exist while
// there's a route through it.
n.mu.Lock()
ref = n.endpoints[id]
if ref == nil || !ref.tryIncRef() {
ref, _ = n.addAddressLocked(protocol, address, peb, true)
if ref != nil {
ref.holdsInsertRef = false
}
}
n.mu.Unlock()
return ref
}
// AddSubnet adds a new subnet to n, so that it starts accepting packets
// targeted at the given address and network protocol.
// AddSubnet向n添加一个新子网以便它开始接受针对给定地址和网络协议的数据包。
func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
n.mu.Lock()
n.subnets = append(n.subnets, subnet)
n.mu.Unlock()
}
// RemoveSubnet removes the given subnet from n.
// 从n中删除一个子网
func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) {
n.mu.Lock()
// Use the same underlying array.
tmp := n.subnets[:0]
for _, sub := range n.subnets {
if sub != subnet {
tmp = append(tmp, sub)
}
}
n.subnets = tmp
n.mu.Unlock()
}
// ContainsSubnet reports whether this NIC contains the given subnet.
// 判断 subnet 这个子网是否在该网卡下
func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool {
for _, s := range n.Subnets() {
if s == subnet {
return true
}
}
return false
}
// Subnets returns the Subnets associated with this NIC.
// 获取该网卡的所有子网
func (n *NIC) Subnets() []tcpip.Subnet {
n.mu.RLock()
defer n.mu.RUnlock()
sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
for nid := range n.endpoints {
sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
if err != nil {
// This should never happen as the mask has been carefully crafted to
// match the address.
panic("Invalid endpoint subnet: " + err.Error())
}
sns = append(sns, sn)
}
return append(sns, n.subnets...)
}
// 根据协议类型和目标地址找出关联的Endpoint
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
id := NetworkEndpointID{dst}
n.mu.RLock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
log.Println("找到了目标地址: ", id)
n.mu.RUnlock()
return ref
}
promiscuous := n.promiscuous
// Check if the packet is for a subnet this NIC cares about.
if !promiscuous {
for _, sn := range n.subnets {
if sn.Contains(dst) {
promiscuous = true
break
}
}
}
n.mu.RUnlock()
if promiscuous {
// Try again with the lock in exclusive mode. If we still can't
// get the endpoint, create a new "temporary" one. It will only
// exist while there's a route through it.
n.mu.Lock()
if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() {
n.mu.Unlock()
return ref
}
ref, err := n.addAddressLocked(protocol, dst, CanBePrimaryEndpoint, true)
n.mu.Unlock()
if err == nil {
ref.holdsInsertRef = false
return ref
}
}
return nil
}
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
// TODO 需要完成逻辑
log.Println(vv.ToView())
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
if len(vv.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(vv.First())
log.Printf("设备[%s]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, vv.ToView())
// 根据网络协议和数据包的目的地址,找到网络端
// 然后将数据包分发给网络层
if ref := n.getRef(protocol, dst); ref != nil {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
r.RemoteLinkAddress = remoteLinkAddr
ref.ep.HandlePacket(&r, vv)
ref.decRef()
return
}
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
// 网络端引用
type referencedNetworkEndpoint struct {
ilist.Entry
refs int32 // 引用计数
ep NetworkEndpoint // 网络端实现
nic *NIC
protocol tcpip.NetworkProtocolNumber
// linkCache is set if link address resolution is enabled for this
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
linkAddrCache
// holdsInsertRef is protected by the NIC's mutex. It indicates whether
// the reference count is biased by 1 due to the insertion of the
// endpoint. It is reset to false when RemoveAddress is called on the
// NIC.
holdsInsertRef bool
}
func (r *referencedNetworkEndpoint) decRef() {
if atomic.AddInt32(&r.refs, -1) == 0 {
r.nic.removeEndpoint(r)
}
}
func (r *referencedNetworkEndpoint) incRef() {
atomic.AddInt32(&r.refs, 1)
}
func (r *referencedNetworkEndpoint) tryIncRef() bool {
for {
v := atomic.LoadInt32(&r.refs)
if v == 0 {
return false
}
if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
return true
}
}
}

View File

@@ -1,8 +1,6 @@
package stack
import (
"log"
"netstack/ilist"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/buffer"
@@ -108,13 +106,70 @@ var (
// ==============================网络层相关==============================
type NetworkProtocol interface {
// 网络协议版本号
Number() tcpip.NetworkProtocolNumber
// todo 需要添加
// MinimumPacketSize returns the minimum valid packet size of this
// network protocol. The stack automatically drops any packets smaller
// than this targeted at this protocol.
MinimumPacketSize() int
// ParsePorts returns the source and destination addresses stored in a
// packet of this protocol.
ParseAddresses(v buffer.View) (src, dst tcpip.Address)
// 新建一个网络终端 比如 ipv4 或者 ipv6 的一个实现
NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache,
dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
// SetOption allows enabling/disabling protocol specific features.
// SetOption returns an error if the option is not supported or the
// provided option value is invalid.
SetOption(option interface{}) *tcpip.Error
// Option allows retrieving protocol specific option values.
// Option returns an error if the option is not supported or the
// provided option value is invalid.
Option(option interface{}) *tcpip.Error
}
// NetworkEndpoint是需要由网络层协议例如ipv4ipv6的端点实现的接口
type NetworkEndpoint interface {
// TODO 需要添加
// DefaultTTL is the default time-to-live value (or hop limit, in ipv6)
// for this endpoint.
DefaultTTL() uint8
// MTU is the maximum transmission unit for this endpoint. This is
// generally calculated as the MTU of the underlying data link endpoint
// minus the network endpoint max header length.
MTU() uint32
// Capabilities returns the set of capabilities supported by the
// underlying link-layer endpoint.
Capabilities() LinkEndpointCapabilities
// MaxHeaderLength returns the maximum size the network (and lower
// level layers combined) headers can have. Higher levels use this
// information to reserve space in the front of the packets they're
// building.
MaxHeaderLength() uint16
// WritePacket writes a packet to the given destination address and
// protocol.
WritePacket(r *Route, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error
// ID returns the network protocol endpoint ID.
ID() *NetworkEndpointID
// NICID returns the id of the NIC this endpoint belongs to.
NICID() tcpip.NICID
// HandlePacket is called by the link layer when new packets arrive to
// this network endpoint.
HandlePacket(r *Route, vv buffer.VectorisedView)
// Close is called when the endpoint is reomved from a stack.
Close()
}
type NetworkEndpointID struct {
@@ -137,29 +192,16 @@ type TransportEndpoint interface {
}
// TODO 需要解读
type referencedNetworkEndpoint struct {
ilist.Entry
refs int32
ep NetworkEndpoint
nic *NIC
protocol tcpip.NetworkProtocolNumber
type TransportProtocol interface {
}
// linkCache is set if link address resolution is enabled for this
// protocol. Set to nil otherwise.
linkCache LinkAddressCache
linkAddrCache
// holdsInsertRef is protected by the NIC's mutex. It indicates whether
// the reference count is biased by 1 due to the insertion of the
// endpoint. It is reset to false when RemoveAddress is called on the
// NIC.
holdsInsertRef bool
// TODO 需要解读
type TransportDispatcher interface {
}
// 注册一个新的网络协议工厂
func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
networkProtocols[name] = p
log.Println(networkProtocols)
}
// 注册一个链路层设备

View File

@@ -25,3 +25,15 @@ type Route struct {
// 相关的网络终端
ref *referencedNetworkEndpoint
}
// 根据参数新建一个路由,并关联一个网络层端
func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address,
localLinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint) Route {
return Route{
NetProto: netProto,
LocalAddress: localAddr,
LocalLinkAddress: localLinkAddr,
RemoteAddress: remoteAddr,
ref: ref,
}
}

View File

@@ -1,7 +1,7 @@
package stack
import (
"log"
"netstack/sleep"
"netstack/tcpip"
"netstack/tcpip/ports"
"sync"
@@ -49,12 +49,32 @@ type Stack struct {
clock tcpip.Clock
}
func New(network []string) *Stack {
// Options contains optional Stack configuration.
type Options struct {
// Clock is an optional clock source used for timestampping packets.
//
// If no Clock is specified, the clock source will be time.Now.
Clock tcpip.Clock
// Stats are optional statistic counters.
Stats tcpip.Stats
}
func New(network []string, transport []string, opts Options) *Stack {
clock := opts.Clock
if clock == nil {
clock = &tcpip.StdClock{}
}
s := &Stack{
transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
nics: make(map[tcpip.NICID]*NIC),
//linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
//PortManager: ports.NewPortManager(),
clock: clock,
stats: opts.Stats.FillIn(),
}
// 添加指定的网络端协议 必须已经在init中注册过
@@ -62,7 +82,6 @@ func New(network []string) *Stack {
// 先检查这个网络协议是否注册过工厂方法
netProtoFactory, ok := networkProtocols[name]
if !ok {
log.Println(name)
continue // 没有就略过
}
netProto := netProtoFactory() // 制造一个该型号协议的示实例
@@ -119,3 +138,59 @@ func (s *Stack) AddAddressWithOptions(id tcpip.NICID, protocol tcpip.NetworkProt
return nic.AddAddressWithOptions(protocol, addr, peb)
}
// AddSubnet adds a subnet range to the specified NIC.
func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
nic.AddSubnet(protocol, subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
// RemoveSubnet removes the subnet range from the specified NIC.
func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
nic.RemoveSubnet(subnet)
return nil
}
return tcpip.ErrUnknownNICID
}
// ContainsSubnet reports whether the specified NIC contains the specified
// subnet.
func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
if nic, ok := s.nics[id]; ok {
return nic.ContainsSubnet(subnet), nil
}
return false, tcpip.ErrUnknownNICID
}
func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
return 0
}
func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
}
func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address,
protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
return "", nil, nil
}
func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
}

View File

@@ -10,16 +10,86 @@ import (
)
const (
defaultMTU = 65536
fakeNetHeaderLen = 12
defaultMTU = 65536
)
type fakeNetworkProtocol struct {
type fakeNetworkEndpoint struct {
nicid tcpip.NICID
id stack.NetworkEndpointID
proto *fakeNetworkProtocol
dispatcher stack.TransportDispatcher
linkEP stack.LinkEndpoint
}
func (f *fakeNetworkEndpoint) DefaultTTL() uint8 {
return 123
}
func (f *fakeNetworkEndpoint) MTU() uint32 {
return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
}
func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return f.linkEP.Capabilities()
}
func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
}
func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr buffer.Prependable, payload buffer.VectorisedView,
protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error {
return nil
}
func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
return &f.id
}
func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
return f.nicid
}
func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) {
log.Println("执行这个函数 接下来它会去向传输层分发数据")
}
func (f *fakeNetworkEndpoint) Close() {}
// dst|src|payload
type fakeNetworkProtocol struct{}
func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
return 114514
}
func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache,
dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
return &fakeNetworkEndpoint{
nicid: nicid,
id: stack.NetworkEndpointID{addr},
proto: f,
dispatcher: dispatcher,
linkEP: linkEP,
}, nil
}
func (f *fakeNetworkProtocol) MinimumPacketSize() int {
return fakeNetHeaderLen
}
func (f *fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
return tcpip.Address(v[4:8]), tcpip.Address(v[0:4])
}
func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
return nil
}
func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
return nil
}
func init() {
stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
@@ -28,18 +98,34 @@ func init() {
func TestStackBase(t *testing.T) {
myStack := stack.New([]string{"fakeNet"})
id, ep := channel.New(10, defaultMTU, "") // 这是一个物理设备
log.Println(id)
myStack := stack.New([]string{"fakeNet"}, nil, stack.Options{})
id1, ep1 := channel.New(10, defaultMTU, "00:15:5d:26:d7:a1") // 这是一个物理设备
if err := myStack.CreateNIC(1, id); err != nil { // 将上面的物理设备抽象成我们的网卡对象
if err := myStack.CreateNIC(1, id1); err != nil { // 将上面的物理设备抽象成我们的网卡对象
panic(err)
}
myStack.AddAddress(1, 114514, "\x01") // 给网卡对象绑定一个IP地址 可以绑定多个
myStack.AddAddress(1, 114514, "\x0a\xff\x01\x01") // 给网卡对象绑定一个IP地址 可以绑定多个
id2, _ := channel.New(10, defaultMTU, "50:5B:C2:D0:96:57") // 这是一个物理设备
if err := myStack.CreateNIC(2, id2); err != nil { // 将上面的物理设备抽象成我们的网卡对象
panic(err)
}
myStack.AddAddress(1, 114514, "\x0a\xff\x01\x02") // 给网卡对象绑定一个IP地址 可以绑定多个
buf := buffer.NewView(30)
for i := range buf {
buf[i] = 1
buf[i] = 0
}
ep.Inject(114514, buf.ToVectoriseView())
// dst 10.255.1.2
buf[0] = '\x0a'
buf[1] = '\xff'
buf[2] = '\x01'
buf[3] = '\x02'
// src 10.255.1.1
buf[4] = '\x0a'
buf[5] = '\xff'
buf[6] = '\x01'
buf[7] = '\x01'
ep1.Inject(114514, buf.ToVectoriseView())
}