mirror of
https://github.com/impact-eintr/netstack.git
synced 2025-10-08 06:10:04 +08:00
udp通信的Connect 和 Read 结束 明天看Waiter 这相当于linux内核的事件驱动机制
当有某种事件就绪后通知waiter 监听着waiter的监听者就能通过waiter得知事件已经发生 从而不再阻塞
This commit is contained in:
@@ -60,7 +60,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
|
||||
id: id,
|
||||
name: name,
|
||||
linkEP: ep,
|
||||
demux: nil, // TODO 需要处理
|
||||
demux: newTransportDemuxer(stack), // NOTE 注册网卡自己的传输层分流器
|
||||
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
|
||||
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
|
||||
}
|
||||
@@ -302,6 +302,75 @@ func (n *NIC) Subnets() []tcpip.Subnet {
|
||||
return append(sns, n.subnets...)
|
||||
}
|
||||
|
||||
// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。
|
||||
// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。
|
||||
// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它,
|
||||
// 当前实现的网络层协议有 arp、ipv4 和 ipv6。
|
||||
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
|
||||
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
|
||||
netProto, ok := n.stack.networkProtocols[protocol]
|
||||
if !ok {
|
||||
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
|
||||
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
|
||||
n.stack.stats.IP.PacketsReceived.Increment()
|
||||
}
|
||||
|
||||
if len(vv.First()) < netProto.MinimumPacketSize() {
|
||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
src, dst := netProto.ParseAddresses(vv.First())
|
||||
log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte {
|
||||
if len(vv.ToView()) > 64 {
|
||||
return vv.ToView()[:64]
|
||||
}
|
||||
return vv.ToView()
|
||||
}())
|
||||
// 根据网络协议和数据包的目的地址,找到网络端
|
||||
// 然后将数据包分发给网络层
|
||||
if ref := n.getRef(protocol, dst); ref != nil {
|
||||
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
|
||||
r.RemoteLinkAddress = remoteLinkAddr
|
||||
ref.ep.HandlePacket(&r, vv)
|
||||
ref.decRef()
|
||||
return
|
||||
}
|
||||
|
||||
if n.stack.Forwarding() {
|
||||
r, err := n.stack.FindRoute(0, "", dst, protocol)
|
||||
if err != nil {
|
||||
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
||||
return
|
||||
}
|
||||
defer r.Release()
|
||||
|
||||
r.LocalLinkAddress = n.linkEP.LinkAddress()
|
||||
r.RemoteLinkAddress = remoteLinkAddr
|
||||
|
||||
// Found a NIC.
|
||||
n := r.ref.nic
|
||||
n.mu.RLock()
|
||||
ref, ok := n.endpoints[NetworkEndpointID{dst}]
|
||||
n.mu.RUnlock()
|
||||
if ok && ref.tryIncRef() {
|
||||
ref.ep.HandlePacket(&r, vv)
|
||||
ref.decRef()
|
||||
} else {
|
||||
// n doesn't have a destination endpoint.
|
||||
// Send the packet out of n.
|
||||
hdr := buffer.NewPrependableFromView(vv.First())
|
||||
vv.RemoveFirst()
|
||||
n.linkEP.WritePacket(&r, hdr, vv, protocol)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
||||
}
|
||||
|
||||
// 根据协议类型和目标地址,找出关联的Endpoint
|
||||
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
|
||||
id := NetworkEndpointID{dst}
|
||||
@@ -344,57 +413,49 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。
|
||||
// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。
|
||||
// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它,
|
||||
// 当前实现的网络层协议有 arp、ipv4 和 ipv6。
|
||||
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
|
||||
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
|
||||
netProto, ok := n.stack.networkProtocols[protocol]
|
||||
if !ok {
|
||||
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
|
||||
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
|
||||
n.stack.stats.IP.PacketsReceived.Increment()
|
||||
}
|
||||
|
||||
if len(vv.First()) < netProto.MinimumPacketSize() {
|
||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
src, dst := netProto.ParseAddresses(vv.First())
|
||||
log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte {
|
||||
if len(vv.ToView()) > 64 {
|
||||
return vv.ToView()[:64]
|
||||
}
|
||||
return vv.ToView()
|
||||
}())
|
||||
// 根据网络协议和数据包的目的地址,找到网络端
|
||||
// 然后将数据包分发给网络层
|
||||
if ref := n.getRef(protocol, dst); ref != nil {
|
||||
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
|
||||
r.RemoteLinkAddress = remoteLinkAddr
|
||||
ref.ep.HandlePacket(&r, vv)
|
||||
ref.decRef()
|
||||
|
||||
return
|
||||
}
|
||||
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
||||
}
|
||||
|
||||
// DeliverTransportPacket delivers packets to the appropriate
|
||||
// transport protocol endpoint.
|
||||
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
|
||||
// 先查找协议栈是否注册了该传输层协议
|
||||
_, ok := n.stack.transportProtocols[protocol]
|
||||
state, ok := n.stack.transportProtocols[protocol]
|
||||
if !ok {
|
||||
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
log.Println("准备分发传输层数据报", n.stack.transportProtocols)
|
||||
transProto := state.proto
|
||||
// 如果报文长度比该协议最小报文长度还小,那么丢弃它
|
||||
if len(vv.First()) < transProto.MinimumPacketSize() {
|
||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
// 解析报文得到源端口和目的端口
|
||||
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
|
||||
if err != nil {
|
||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||
return
|
||||
}
|
||||
log.Println("准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort)
|
||||
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
|
||||
// 调用分流器,根据传输层协议和传输层id分发数据报文
|
||||
if n.demux.deliverPacket(r, protocol, vv, id) {
|
||||
return
|
||||
}
|
||||
if n.stack.demux.deliverPacket(r, protocol, vv, id) {
|
||||
return
|
||||
}
|
||||
|
||||
// Try to deliver to per-stack default handler.
|
||||
if state.defaultHandler != nil {
|
||||
if state.defaultHandler(r, id, vv) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// We could not find an appropriate destination for this packet, so
|
||||
// deliver it to the global handler.
|
||||
if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
|
||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||
}
|
||||
}
|
||||
|
||||
// DeliverTransportControlPacket delivers control packets to the
|
||||
|
@@ -185,7 +185,7 @@ type NetworkEndpointID struct {
|
||||
type TransportEndpointID struct {
|
||||
LocalPort uint16
|
||||
LocalAddress tcpip.Address
|
||||
remotePort uint16
|
||||
RemotePort uint16
|
||||
RemoteAddress tcpip.Address
|
||||
}
|
||||
|
||||
|
@@ -116,12 +116,87 @@ func New(network []string, transport []string, opts Options) *Stack {
|
||||
proto: transProto,
|
||||
}
|
||||
}
|
||||
// 添加传输层分流器
|
||||
// NOTE 添加协议栈全局传输层分流器
|
||||
s.demux = newTransportDemuxer(s)
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// SetNetworkProtocolOption allows configuring individual protocol level
|
||||
// options. This method returns an error if the protocol is not supported or
|
||||
// option is not supported by the protocol implementation or the provided value
|
||||
// is incorrect.
|
||||
func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
|
||||
netProto, ok := s.networkProtocols[network]
|
||||
if !ok {
|
||||
return tcpip.ErrUnknownProtocol
|
||||
}
|
||||
return netProto.SetOption(option)
|
||||
}
|
||||
|
||||
// NetworkProtocolOption allows retrieving individual protocol level option
|
||||
// values. This method returns an error if the protocol is not supported or
|
||||
// option is not supported by the protocol implementation.
|
||||
// e.g.
|
||||
// var v ipv4.MyOption
|
||||
// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
|
||||
//
|
||||
// if err != nil {
|
||||
// ...
|
||||
// }
|
||||
func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
|
||||
netProto, ok := s.networkProtocols[network]
|
||||
if !ok {
|
||||
return tcpip.ErrUnknownProtocol
|
||||
}
|
||||
return netProto.Option(option)
|
||||
}
|
||||
|
||||
// SetTransportProtocolOption allows configuring individual protocol level
|
||||
// options. This method returns an error if the protocol is not supported or
|
||||
// option is not supported by the protocol implementation or the provided value
|
||||
// is incorrect.
|
||||
func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
|
||||
transProtoState, ok := s.transportProtocols[transport]
|
||||
if !ok {
|
||||
return tcpip.ErrUnknownProtocol
|
||||
}
|
||||
return transProtoState.proto.SetOption(option)
|
||||
}
|
||||
|
||||
// TransportProtocolOption allows retrieving individual protocol level option
|
||||
// values. This method returns an error if the protocol is not supported or
|
||||
// option is not supported by the protocol implementation.
|
||||
// var v tcp.SACKEnabled
|
||||
//
|
||||
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
|
||||
// ...
|
||||
// }
|
||||
func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
|
||||
transProtoState, ok := s.transportProtocols[transport]
|
||||
if !ok {
|
||||
return tcpip.ErrUnknownProtocol
|
||||
}
|
||||
return transProtoState.proto.Option(option)
|
||||
}
|
||||
|
||||
// SetTransportProtocolHandler sets the per-stack default handler for the given
|
||||
// protocol.
|
||||
//
|
||||
// It must be called only during initialization of the stack. Changing it as the
|
||||
// stack is operating is not supported.
|
||||
func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
|
||||
state := s.transportProtocols[p]
|
||||
if state != nil {
|
||||
state.defaultHandler = h
|
||||
}
|
||||
}
|
||||
|
||||
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
|
||||
func (s *Stack) NowNanoseconds() int64 {
|
||||
return s.clock.NowNanoseconds()
|
||||
}
|
||||
|
||||
func (s *Stack) Stats() tcpip.Stats {
|
||||
return s.stats
|
||||
}
|
||||
@@ -260,19 +335,19 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi
|
||||
return false, tcpip.ErrUnknownNICID
|
||||
}
|
||||
|
||||
// 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息
|
||||
// FindRoute 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息
|
||||
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address,
|
||||
netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
for i := range s.routeTable {
|
||||
if (id != 0 && id != s.routeTable[i].NIC) ||
|
||||
if (id != 0 && id != s.routeTable[i].NIC) || // 检查是否是对应的网卡
|
||||
(len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
|
||||
continue
|
||||
}
|
||||
|
||||
nic := s.nics[s.routeTable[i].NIC]
|
||||
nic := s.nics[s.routeTable[i].NIC] // 在协议栈里找到这张网卡
|
||||
if nic == nil {
|
||||
continue
|
||||
}
|
||||
@@ -372,14 +447,34 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
|
||||
// 最终调用 demuxer.registerEndpoint 函数来实现注册。
|
||||
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||
// TODO 需要实现
|
||||
return nil
|
||||
log.Println("往", nicID, "网卡注册新的传输端")
|
||||
if nicID == 0 {
|
||||
return s.demux.registerEndpoint(netProtos, protocol, id, ep) // 给协议栈的所有网卡注册传输端
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
nic := s.nics[nicID]
|
||||
if nic == nil {
|
||||
return tcpip.ErrUnknownNICID
|
||||
}
|
||||
return nic.demux.registerEndpoint(netProtos, protocol, id, ep) // 给这张网卡注册传输端
|
||||
}
|
||||
|
||||
// UnregisterTransportEndpoint removes the endpoint with the given id from the
|
||||
// stack transport dispatcher.
|
||||
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
||||
if nicID == 0 {
|
||||
s.demux.unregisterEndpoint(netProtos, protocol, id)
|
||||
return
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
nic := s.nics[nicID]
|
||||
if nic != nil {
|
||||
nic.demux.unregisterEndpoint(netProtos, protocol, id)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@@ -2,6 +2,7 @@ package stack
|
||||
|
||||
import (
|
||||
"netstack/tcpip"
|
||||
"netstack/tcpip/buffer"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -23,6 +24,112 @@ type transportDemuxer struct {
|
||||
}
|
||||
|
||||
// 新建一个分流器
|
||||
func newTransportDemuxer(stacl *Stack) *transportDemuxer {
|
||||
func newTransportDemuxer(stack *Stack) *transportDemuxer {
|
||||
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
|
||||
|
||||
for netProto := range stack.networkProtocols {
|
||||
for tranProto := range stack.transportProtocols {
|
||||
d.protocol[protocolIDs{network: netProto, transport: tranProto}] = &transportEndpoints{
|
||||
endpoints: make(map[TransportEndpointID]TransportEndpoint),
|
||||
}
|
||||
}
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// registerEndpoint 向分发器注册给定端点,以便将与端点ID匹配的数据包传递给它
|
||||
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber,
|
||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||
for i, n := range netProtos {
|
||||
if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
|
||||
d.unregisterEndpoint(netProtos[:i], protocol, id) // 把刚才注册的注销掉
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber,
|
||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||
eps, ok := d.protocol[protocolIDs{netProto, protocol}] // IPv4:udp
|
||||
if !ok { // 未曾注册过这个传输端集合
|
||||
return nil
|
||||
}
|
||||
|
||||
eps.mu.Lock()
|
||||
defer eps.mu.Unlock()
|
||||
|
||||
if _, ok := eps.endpoints[id]; ok { // 遍历传输端集合
|
||||
return tcpip.ErrPortInUse
|
||||
}
|
||||
eps.endpoints[id] = ep
|
||||
return nil
|
||||
}
|
||||
|
||||
// unregisterEndpoint 使用给定的id注销端点,使其不再接收任何数据包
|
||||
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber,
|
||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
||||
for _, n := range netProtos {
|
||||
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
|
||||
eps.mu.Lock()
|
||||
delete(eps.endpoints, id)
|
||||
eps.mu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 根据传输层的id来找到对应的传输端,再将数据包交给这个传输端处理
|
||||
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
|
||||
// 先看看分流器里有没有注册相关协议端,如果没有则返回false
|
||||
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
// 从 eps 中找符合 id 的传输端
|
||||
eps.mu.RLock()
|
||||
ep := d.findEndpointLocked(eps, vv, id)
|
||||
eps.mu.RUnlock()
|
||||
|
||||
if ep == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Deliver the packet
|
||||
ep.HandlePacket(r, id, vv)
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
|
||||
trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 根据传输层id来找到相应的传输层端
|
||||
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints,
|
||||
vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
|
||||
if ep := eps.endpoints[id]; ep != nil { // IPv4:udp
|
||||
return ep
|
||||
}
|
||||
// Try to find a match with the id minus the local address.
|
||||
nid := id
|
||||
// 如果上面的 endpoints 没有找到,那么去掉本地ip地址,看看有没有相应的传输层端
|
||||
// 因为有时候传输层监听的时候没有绑定本地ip,也就是 any address,此时的 LocalAddress
|
||||
// 为空。
|
||||
nid.LocalAddress = ""
|
||||
if ep := eps.endpoints[nid]; ep != nil {
|
||||
return ep
|
||||
}
|
||||
|
||||
// Try to find a match with the id minus the remote part.
|
||||
nid.LocalAddress = id.LocalAddress
|
||||
nid.RemoteAddress = ""
|
||||
nid.RemotePort = 0
|
||||
if ep := eps.endpoints[nid]; ep != nil {
|
||||
return ep
|
||||
}
|
||||
|
||||
// Try to find a match with only the local port.
|
||||
nid.LocalAddress = ""
|
||||
return eps.endpoints[nid]
|
||||
}
|
||||
|
Reference in New Issue
Block a user