diff --git a/cmd/port/main.go b/cmd/port/main.go index 948ca37..5f6214a 100644 --- a/cmd/port/main.go +++ b/cmd/port/main.go @@ -152,8 +152,9 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) } // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP + // 0.0.0.0:9999 这台机器上的所有ip的9999段端口数据都会使用该传输层实现 // 此时就会调用端口管理器 - if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil { + if err := ep.Bind(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}, nil); err != nil { log.Fatal("Bind failed: ", err) } diff --git a/tcpip/ports/ports.go b/tcpip/ports/ports.go index 3134b31..571b6e1 100644 --- a/tcpip/ports/ports.go +++ b/tcpip/ports/ports.go @@ -1,6 +1,7 @@ package ports import ( + "log" "math" "math/rand" "netstack/tcpip" @@ -87,7 +88,7 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb for _, network := range networks { // 遍历网络协议 desc := portDescriptor{network: network, transport: transport, port: port} // 构造端口描述符 if addrs, ok := s.allocatedPorts[desc]; ok { // 检查端口描述符绑定的ip集合 - if !addrs.isAvailable(addr) { // 该集合中已经有这个ip + if !addrs.isAvailable(addr) { // 该集合中已经有这个ip 或者是"" 也就是 0.0.0.0 return false } } @@ -101,5 +102,64 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) { - return 0, nil + s.mu.Lock() + defer s.mu.Unlock() + // defer log.Println(transport, "成功分配端口", *(&reservedPort)) TODO 这样写就有问题 defer给直接取值了? + defer func() { + log.Println(transport, "成功分配端口", *(&reservedPort)) + }() + + // 指定端口进行绑定 + if port != 0 { + if !s.reserveSpecificPort(networks, transport, addr, port) { + return 0, tcpip.ErrPortInUse // 已经被占用 + } + reservedPort = port + return + } + // 随机分配 + reservedPort, err = s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + return s.reserveSpecificPort(networks, transport, addr, p), nil + }) + return reservedPort, nil +} + +// reserveSpecificPort 尝试根据协议号和IP地址绑定一个端口 +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port) { + return false + } + + // 根据给定的网络层协议号绑定端口 + for _, network := range networks { + desc := portDescriptor{network: network, transport: transport, port: port} // ipv4-udp-9999 + m, ok := s.allocatedPorts[desc] + if !ok { + m = make(bindAddresses) // Set of IP + s.allocatedPorts[desc] = m + } + // 注册该地址被绑定了 + m[addr] = struct{}{} + } + return true +} + +// ReleasePort 释放绑定的端口,以便别的程序复用。 +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, + addr tcpip.Address, port uint16) { + s.mu.Lock() + defer s.mu.Unlock() + + // 删除绑定关系 + for _, network := range networks { + desc := portDescriptor{network, transport, port} + if m, ok := s.allocatedPorts[desc]; ok { + log.Println(transport, "释放", port) + delete(m, addr) + if len(m) == 0 { + delete(s.allocatedPorts, desc) + } + } + } } diff --git a/tcpip/stack/registration.go b/tcpip/stack/registration.go index 55abc6e..ed2f2b3 100644 --- a/tcpip/stack/registration.go +++ b/tcpip/stack/registration.go @@ -181,8 +181,12 @@ type NetworkEndpointID struct { // ==============================传输层相关============================== +// TransportEndpointID 是某个传输层实现的标识 type TransportEndpointID struct { - // TODO + LocalPort uint16 + LocalAddress tcpip.Address + remotePort uint16 + RemoteAddress tcpip.Address } // ControlType 是网络层控制消息的类型 @@ -197,7 +201,7 @@ const ( // TransportEndpoint 传输层实现接口 type TransportEndpoint interface { HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) - HandleControlPacker(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) } // TransportProtocol 传输层协议 TCP OR UDP diff --git a/tcpip/stack/stack.go b/tcpip/stack/stack.go index fb2fcd7..d9ac26c 100644 --- a/tcpip/stack/stack.go +++ b/tcpip/stack/stack.go @@ -162,7 +162,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, if !ok { return nil, tcpip.ErrUnknownProtocol } - return t.proto.NewEndpoint(s, network, waiterQueue) + return t.proto.NewEndpoint(s, network, waiterQueue) // 新建一个传输层实现 } // CreateNIC 根据给定的网卡号 和 链路层设备号 创建一个网卡对象 @@ -300,7 +300,8 @@ func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, } // ===============本机链路层缓存实现================== -// 检查本地是否绑定过该网络层地址 + +// CheckLocalAddress 检查本地是否绑定过该网络层地址 注意 NICID 为0表示寻找本机所有网卡 func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { s.mu.RLock() defer s.mu.RUnlock() @@ -362,3 +363,39 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. s.linkAddrCache.removeWaker(fullAddr, waker) } } + +// RegisterTransportEndpoint 协议栈或者NIC的分流器注册给定传输层端点。 +// 收到的与提供的id匹配的数据包将被传送到给定的端点;指定nic是可选的,但特定于nic的ID优先于全局ID。 +// 最终调用 demuxer.registerEndpoint 函数来实现注册。 +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, + protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + // TODO 需要实现 + return nil +} + +// UnregisterTransportEndpoint removes the endpoint with the given id from the +// stack transport dispatcher. +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, + protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + +} + +// NetworkProtocolInstance returns the protocol instance in the stack for the +// specified network protocol. This method is public for protocol implementers +// and tests to use. +func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol { + if p, ok := s.networkProtocols[num]; ok { + return p + } + return nil +} + +// TransportProtocolInstance returns the protocol instance in the stack for the +// specified transport protocol. This method is public for protocol implementers +// and tests to use. +func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol { + if pState, ok := s.transportProtocols[num]; ok { + return pState.proto + } + return nil +} diff --git a/tcpip/tcpip.go b/tcpip/tcpip.go index 5ed90fc..2d846dd 100644 --- a/tcpip/tcpip.go +++ b/tcpip/tcpip.go @@ -165,8 +165,20 @@ func (l LinkAddress) String() string { type LinkEndpointID uint64 +// TransportProtocolNumber 传输层协议号 type TransportProtocolNumber uint32 +const ( + UDPProtocolNumber = 17 +) + +func (t TransportProtocolNumber) String() string { + if t == UDPProtocolNumber { + return "UDP" + } + return "TCP" +} + type NetworkProtocolNumber uint32 type NICID int32 diff --git a/tcpip/transport/udp/endpoint.go b/tcpip/transport/udp/endpoint.go index 4be3ce8..9fbaf84 100644 --- a/tcpip/transport/udp/endpoint.go +++ b/tcpip/transport/udp/endpoint.go @@ -4,16 +4,28 @@ import ( "log" "netstack/tcpip" "netstack/tcpip/buffer" + "netstack/tcpip/header" "netstack/tcpip/stack" "netstack/waiter" "sync" ) // udp报文结构 当收到udp报文时 会用这个结构来保存udp报文数据 -type udpPacker struct { +type udpPacket struct { + udpPacketEntry // 链表实现 // TODO 需要添加 } +type endpointState int + +// 表示UDP端的状态参数 +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + type endpoint struct { stack *stack.Stack // udp所依赖的用户协议栈 netProto tcpip.NetworkProtocolNumber // udp网络协议号 ipv4/ipv6 @@ -31,18 +43,59 @@ type endpoint struct { rcvTimestamp bool // The following fields are protected by the mu mutex. - mu sync.RWMutex - // TODO 需要添加 + mu sync.RWMutex + sndBufSize int // 发送缓冲区大小 + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID // 绑定的网卡 + regNICID tcpip.NICID // + route stack.Route // 路由? TODO + dstPort uint16 // 目标端口 + v6only bool // 仅支持ipv6 + multicastTTL uint8 // 广播TTL + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // TODO + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). 当前生效的网络层协议 + effectiveNetProtos []tcpip.NetworkProtocolNumber } func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { - log.Println("新建传输层实现") - return &endpoint{} + log.Println("新建一个udp端") + return &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + multicastTTL: 1, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024} } +// Close UDP端的关闭,释放相应的资源 func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + + switch e.state { + case stateBound, stateConnected: + // 释放在协议栈中注册的UDP端 + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + // 释放端口占用 + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } + + // TODO + e.mu.Unlock() } func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { @@ -57,6 +110,37 @@ func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) return 0, tcpip.ControlMessages{}, nil } +// IPV6于IPV4地址的映射 +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + + // Fail if we are bound to an IPv6 address. + if !allowMismatch && len(e.id.LocalAddress) == 16 { + return 0, tcpip.ErrNetworkUnreachable + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + // 源地址用与目标地址使用的ip协议不能不一致 + if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error { return nil } @@ -73,12 +157,96 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { return nil, nil, nil } +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, + id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.id.LocalPort == 0 { // 一个没有绑定过端口的udp端 + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) // 为这个udp端绑定一个端口 + if err != nil { + return id, err + } + id.LocalPort = port + } + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + if err != nil { + // 释放端口 + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + } + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + // 不是初始状态的UDP实现不允许绑定 + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, true) + if err != nil { + return nil + } + + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { // IPv6 && 支持ipv4 & 任意地址 + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + // 不是任意地址的话 需要检验本地网卡是否绑定个这个ip地址 + if len(addr.Addr) != 0 { + if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + // 开始绑定 绑定的时候 传输端ID : srcIP + srcPort + id := stack.TransportEndpointID{ + LocalAddress: addr.Addr, + LocalPort: addr.Port, + } + // 在协议栈中注册该UDP端 + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + // 如果指定了 commit 函数 执行并处理错误 + if commit != nil { + if err := commit(); err != nil { + // Unregister, the commit failed. + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + return err + } + } + + e.id = id + e.regNICID = addr.NIC + e.effectiveNetProtos = netProtos + + // Mark endpoint as bound. + // 标记状态为已绑定 + e.state = stateBound + + return nil +} + // Bind binds the endpoint to a specific local address and port. // Specifying a NIC is optional. // Bind 将该UDP端绑定本地的一个IP+端口 // 例如:绑定本地0.0.0.0的9000端口,那么其他机器给这台机器9000端口发消息,该UDP端就能收到消息了 -func (e *endpoint) Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { - log.Println("绑定端口", address) +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // 执行绑定IP+端口操作 + err := e.bindLocked(addr, commit) + if err != nil { + return err + } + + // 绑定的网卡ID + e.bindNICID = addr.NIC return nil } @@ -101,3 +269,12 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { return nil } + +// 从网络层接收到UDP数据报时的处理函数 +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +} diff --git a/tcpip/transport/udp/protocol.go b/tcpip/transport/udp/protocol.go index 613f236..11ed8c4 100644 --- a/tcpip/transport/udp/protocol.go +++ b/tcpip/transport/udp/protocol.go @@ -1,7 +1,6 @@ package udp import ( - "log" "netstack/tcpip" "netstack/tcpip/buffer" "netstack/tcpip/header" @@ -28,7 +27,6 @@ func (*protocol) Number() tcpip.TransportProtocolNumber { // NewEndpoint creates a new udp endpoint. func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - log.Println("新建udp传输层协议") return newEndpoint(stack, netProto, waiterQueue), nil }