UDP 通信的 Connect 和 Read 已完成;明天看 Waiter,它相当于 Linux 内核的事件驱动机制。

当某种事件就绪后会通知 waiter,监听该 waiter 的监听者就能通过它得知事件已经发生,从而不再阻塞。
This commit is contained in:
impact-eintr
2022-12-01 22:36:40 +08:00
parent 3d8ca3c0c8
commit be40f904fc
11 changed files with 741 additions and 123 deletions

View File

@@ -12,36 +12,31 @@ import (
"netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4"
"netstack/tcpip/network/ipv6"
"netstack/tcpip/stack"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"os"
"strings"
"time"
)
var mac = flag.String("mac", "01:01:01:01:01:01", "mac address to use in tap device")
func main() {
flag.Parse()
if len(flag.Args()) != 2 {
log.Fatal("Usage: ", os.Args[0], " <tap-device> <listen-address>")
if len(flag.Args()) < 2 {
log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask>")
}
log.SetFlags(log.Lshortfile | log.LstdFlags)
tapName := flag.Arg(0)
listeAddr := flag.Arg(1)
cidrName := flag.Arg(1)
log.Printf("tap: %v, listeAddr: %v", tapName, listeAddr)
log.Printf("tap: %v, cidrName: %v", tapName, cidrName)
// Parse the mac address.
maddr, err := net.ParseMAC(*mac)
parsedAddr, cidr, err := net.ParseCIDR(cidrName)
if err != nil {
log.Fatalf("Bad MAC address: %v", *mac)
log.Fatalf("Bad cidr: %v", cidrName)
}
parsedAddr := net.ParseIP(listeAddr)
// 解析地址ip地址ipv4或者ipv6地址都支持
var addr tcpip.Address
var proto tcpip.NetworkProtocolNumber
@@ -50,7 +45,7 @@ func main() {
proto = ipv4.ProtocolNumber
} else if parsedAddr.To16() != nil {
addr = tcpip.Address(parsedAddr.To16())
proto = ipv6.ProtocolNumber
//proto = ipv6.ProtocolNumber
} else {
log.Fatalf("Unknown IP type: %v", parsedAddr)
}
@@ -69,17 +64,22 @@ func main() {
}
// 启动tap网卡
_ = tuntap.SetLinkUp(tapName)
// 设置tap网卡IP地址
_ = tuntap.AddIP(tapName, listeAddr)
tuntap.SetLinkUp(tapName)
// 设置路由
tuntap.SetRoute(tapName, cidr.String())
// 获取mac地址
mac, err := tuntap.GetHardwareAddr(tapName)
if err != nil {
panic(err)
}
// 抽象网卡的文件接口
linkID := fdbased.New(&fdbased.Options{
FD: fd,
MTU: 1500,
Address: tcpip.LinkAddress(maddr),
Address: tcpip.LinkAddress(mac),
})
// 新建相关协议的协议栈
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
[]string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{})
@@ -109,37 +109,33 @@ func main() {
},
})
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
TCPServer(conn, &RCV{s, nil, nil})
// 同时监听tcp和udp localPort端口
//tcpEp := tcpListen(s, proto, localPort)
//udpEp := udpListen(s, proto, localPort)
// 关闭监听服务,此时会释放端口
//tcpEp.Close()
//udpEp.Close()
}
go func() {
// 监听udp localPort端口
udpEp := udpListen(s, proto, 9999)
//func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
// var wq waiter.Queue
// // 新建一个tcp端
// ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
// if err != nil {
// log.Fatal(err)
// }
//
// // 绑定IP和端口这里的IP地址为空表示绑定任何IP
// // 此时就会调用端口管理器
// if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil {
// log.Fatal("Bind failed: ", err)
// }
//
// // 开始监听
// if err := ep.Listen(10); err != nil {
// log.Fatal("Listen failed: ", err)
// }
//
// return ep
//}
for {
buf, _, err := udpEp.Read(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
time.Sleep(100 * time.Millisecond)
log.Println("阻塞中")
continue
}
}
log.Println(buf)
break
}
// 关闭监听服务,此时会释放端口
udpEp.Close()
}()
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
rcv := &RCV{
Stack: s,
addr: tcpip.FullAddress{},
}
TCPServer(conn, rcv)
}
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
var wq waiter.Queue
@@ -156,10 +152,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
log.Fatal("Bind failed: ", err)
}
if err := ep.Connect(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}); err != nil {
log.Fatal("Conn failed: ", err)
}
// 注意UDP是无连接的它不需要Listen
return ep
}
@@ -167,6 +159,7 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
// RCV embeds the network stack and carries the state shared by the Handle,
// Bind and Connect methods.
type RCV struct {
*stack.Stack
// ep is the transport endpoint that Bind and Connect operate on.
ep tcpip.Endpoint
// addr is the local address most recently built by Bind and passed to ep.Bind.
addr tcpip.FullAddress
// rcvBuf holds the raw request bytes; Bind decodes the address from
// bytes [3:7] and the big-endian port from bytes [7:9].
rcvBuf []byte
}
@@ -189,6 +182,7 @@ func (r *RCV) Handle(conn net.Conn) {
}
r.ep = ep
r.Bind()
r.Connect()
r.Close()
case "tcp":
default:
@@ -202,12 +196,16 @@ func (r *RCV) Bind() {
return
}
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
addr := tcpip.FullAddress{
NIC: 0,
r.addr = tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(r.rcvBuf[3:7]),
Port: port,
}
r.ep.Bind(addr, nil)
r.ep.Bind(r.addr, nil)
}
// Connect connects r.ep to the hard-coded peer 192.168.1.2:8888 on NIC 1.
//
// The method has no return value, so a failed connect is logged instead of
// being silently dropped.
func (r *RCV) Connect() {
	// "\xc0\xa8\x01\x02" is the 4-byte form of 192.168.1.2.
	remote := tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888}
	if err := r.ep.Connect(remote); err != nil {
		log.Println("Connect failed: ", err)
	}
}
func (r *RCV) Close() {

42
cmd/udp_client/main.go Normal file
View File

@@ -0,0 +1,42 @@
package main
import (
"flag"
"log"
"net"
)
// main sends a single "hello" datagram to the UDP address given by the -a
// flag (default 192.168.1.1:9999) and exits.
func main() {
	addr := flag.String("a", "192.168.1.1:9999", "udp dst address")
	// Without Parse the command line is never read and addr always keeps its
	// default value.
	flag.Parse()

	log.SetFlags(log.Lshortfile | log.LstdFlags)

	udpAddr, err := net.ResolveUDPAddr("udp", *addr)
	if err != nil {
		panic(err)
	}
	log.Println("解析地址")

	// Dialing UDP only records the destination IP and port; no connection is
	// actually established with the peer.
	conn, err := net.DialUDP("udp", nil, udpAddr)
	if err != nil {
		panic(err)
	}
	// Release the socket on every exit path, including the panic below.
	defer conn.Close()
	log.Println("TEST")

	send := []byte("hello")
	if _, err := conn.Write(send); err != nil {
		panic(err)
	}
	log.Printf("send: %s", string(send))

	//recv := make([]byte, 10)
	//rn, _, err := conn.ReadFrom(recv)
	//if err != nil {
	//	panic(err)
	//}
	//log.Printf("recv: %s", string(recv[:rn]))
}

View File

@@ -2,7 +2,10 @@ package header
import "netstack/tcpip"
// 校验和的计算
// Checksum 校验和的计算
// UDP 检验和的计算方法是: 按每 16 位求和得出一个 32 位的数;
// 如果这个 32 位的数,高 16 位不为 0则高 16 位加低 16 位再得到一个 32 位的数;
// 重复第 2 步直到高 16 位为 0将低 16 位取反,得到校验和。
func Checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)

View File

@@ -1,6 +1,10 @@
package header
import "netstack/tcpip"
import (
"encoding/binary"
"fmt"
"netstack/tcpip"
)
const (
udpSrcPort = 0
@@ -36,3 +40,99 @@ const (
// UDPProtocolNumber is UDP's transport protocol number.
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
)
/*
UDP 是 User Datagram Protocol 的简称中文名是用户数据报协议。UDP 只在 IP 数据报服务上增加了一点功能就是复用和分用的功能以及差错检测UDP 主要的特点是:
1. UDP 是无连接的,即发送数据之前不需要建立连接,发送结束也不需要连接释放,因此减少了开销和发送数据之间的延时。
2. UDP 是不可靠传输,尽最大努力交付,因此不需要维护复杂的连接状态。
3. UDP 的数据报是有消息边界的,发送方发送一个报文,接收方就会完整的收到一个报文。
4. UDP 没有拥塞控制网络出现阻塞UDP 是无感知的,也就不会降低发送速度。
5. UDP 支持一对一,一对多,多对一,多对多的通信。
*/
/*
|source Port|destination Port|
| Length | UDP Checksum |
| Data |
*/
// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
return binary.BigEndian.Uint16(b[udpSrcPort:])
}
// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
return binary.BigEndian.Uint16(b[udpDstPort:])
}
// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
return binary.BigEndian.Uint16(b[udpLength:])
}
// Payload returns the data contained in the UDP datagram.
func (b UDP) Payload() []byte {
return b[UDPMinimumSize:]
}
// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
return binary.BigEndian.Uint16(b[udpChecksum:])
}
// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}
// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
binary.BigEndian.PutUint16(b[udpDstPort:], port)
}
// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
// CalculateChecksum returns the checksum of the UDP packet.
//
// partialChecksum is the running checksum of the network-layer pseudo-header
// (without the length) and of the payload; totalLen is the total packet
// length. The length is folded in first, followed by the first
// UDPMinimumSize bytes of b.
func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
	// Serialise the length in network byte order so it can be summed like
	// any other part of the pseudo-header.
	var lenField [2]byte
	binary.BigEndian.PutUint16(lenField[:], totalLen)
	withLen := Checksum(lenField[:], partialChecksum)

	// Finish with the UDP header itself.
	return Checksum(b[:UDPMinimumSize], withLen)
}
// Encode writes the four header fields of u (source port, destination port,
// length, checksum) into b in big-endian byte order.
func (b UDP) Encode(u *UDPFields) {
	be := binary.BigEndian
	be.PutUint16(b[udpSrcPort:], u.SrcPort)
	be.PutUint16(b[udpDstPort:], u.DstPort)
	be.PutUint16(b[udpLength:], u.Length)
	be.PutUint16(b[udpChecksum:], u.Checksum)
}
var udpFmt string = `
|% 16s|% 16s|
|% 16s|% 16s|
%v
`
// String renders the header fields and the payload using udpFmt.
//
// A non-empty payload whose bytes are all identical is abbreviated as
// "<byte> x <count>"; any other payload, including an empty one, is printed
// in full. The method never indexes past the end of the payload, so it is
// safe to call on zero-length datagrams and on datagrams whose Length field
// disagrees with the actual slice length. A slice shorter than the UDP header
// yields a short diagnostic string instead of a panic.
func (b UDP) String() string {
	if len(b) < UDPMinimumSize {
		return fmt.Sprintf("malformed UDP header: %d bytes", len(b))
	}

	payload := b.Payload()

	// Default: print the payload bytes as they are.
	var body interface{} = payload
	if len(payload) > 0 {
		uniform := true
		for _, c := range payload[1:] {
			if c != payload[0] {
				uniform = false
				break
			}
		}
		if uniform {
			// Collapse runs such as ping-style filler into "<byte> x <count>".
			body = fmt.Sprintf("%v x %d", payload[0], len(payload))
		}
	}

	return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()),
		atoi(b.Length()), atoi(b.Checksum()),
		body)
}

View File

@@ -60,7 +60,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
id: id,
name: name,
linkEP: ep,
demux: nil, // TODO 需要处理
demux: newTransportDemuxer(stack), // NOTE 注册网卡自己的传输层分流器
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
}
@@ -302,6 +302,75 @@ func (n *NIC) Subnets() []tcpip.Subnet {
return append(sns, n.subnets...)
}
// DeliverNetworkPacket is called when the NIC receives a packet from the link
// layer and dispatches it to a network-layer endpoint.
//
// It looks up the network protocol registered for protocol, parses the source
// and destination addresses from the first view of vv, and then:
//   - hands the packet to the endpoint returned by getRef for dst, if any;
//   - otherwise, when forwarding is enabled on the stack, finds a route to dst
//     and either delivers the packet to the endpoint on the routed NIC or
//     writes it out of that NIC's link endpoint;
//   - otherwise counts the packet as having an invalid destination address.
//
// Unknown protocols and packets shorter than the protocol's minimum size are
// counted in the stack statistics and dropped.
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
	protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
	netProto, ok := n.stack.networkProtocols[protocol]
	if !ok {
		n.stack.stats.UnknownProtocolRcvdPackets.Increment()
		return
	}

	if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
		n.stack.stats.IP.PacketsReceived.Increment()
	}

	if len(vv.First()) < netProto.MinimumPacketSize() {
		n.stack.stats.MalformedRcvdPackets.Increment()
		return
	}

	src, dst := netProto.ParseAddresses(vv.First())

	// Flatten the packet once for the log line and show at most 64 bytes;
	// ToView is not repeated for the length check and the slice.
	preview := []byte(vv.ToView())
	if len(preview) > 64 {
		preview = preview[:64]
	}
	log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, preview)

	// Find the network endpoint for this protocol and destination address and
	// hand the packet to it.
	if ref := n.getRef(protocol, dst); ref != nil {
		r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
		r.RemoteLinkAddress = remoteLinkAddr
		ref.ep.HandlePacket(&r, vv)
		ref.decRef()
		return
	}

	if n.stack.Forwarding() {
		r, err := n.stack.FindRoute(0, "", dst, protocol)
		if err != nil {
			n.stack.stats.IP.InvalidAddressesReceived.Increment()
			return
		}
		defer r.Release()

		r.LocalLinkAddress = n.linkEP.LinkAddress()
		r.RemoteLinkAddress = remoteLinkAddr

		// Found a NIC. From here on n refers to the NIC chosen by the route,
		// not to the receiving NIC.
		n := r.ref.nic
		n.mu.RLock()
		ref, ok := n.endpoints[NetworkEndpointID{dst}]
		n.mu.RUnlock()
		if ok && ref.tryIncRef() {
			ref.ep.HandlePacket(&r, vv)
			ref.decRef()
		} else {
			// n doesn't have a destination endpoint.
			// Send the packet out of n.
			hdr := buffer.NewPrependableFromView(vv.First())
			vv.RemoveFirst()
			n.linkEP.WritePacket(&r, hdr, vv, protocol)
		}
		return
	}

	n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
// 根据协议类型和目标地址找出关联的Endpoint
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
id := NetworkEndpointID{dst}
@@ -344,57 +413,49 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
return nil
}
// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket用来分发网络层数据包。
// 比如 protocol 是 arp 协议号那么会找到arp.HandlePacket来处理数据报。
// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它,
// 当前实现的网络层协议有 arp、ipv4 和 ipv6。
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
netProto, ok := n.stack.networkProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
n.stack.stats.IP.PacketsReceived.Increment()
}
if len(vv.First()) < netProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
src, dst := netProto.ParseAddresses(vv.First())
log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte {
if len(vv.ToView()) > 64 {
return vv.ToView()[:64]
}
return vv.ToView()
}())
// 根据网络协议和数据包的目的地址,找到网络端
// 然后将数据包分发给网络层
if ref := n.getRef(protocol, dst); ref != nil {
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
r.RemoteLinkAddress = remoteLinkAddr
ref.ep.HandlePacket(&r, vv)
ref.decRef()
return
}
n.stack.stats.IP.InvalidAddressesReceived.Increment()
}
// DeliverTransportPacket delivers packets to the appropriate
// transport protocol endpoint.
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
// 先查找协议栈是否注册了该传输层协议
_, ok := n.stack.transportProtocols[protocol]
state, ok := n.stack.transportProtocols[protocol]
if !ok {
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
return
}
log.Println("准备分发传输层数据报", n.stack.transportProtocols)
transProto := state.proto
// 如果报文长度比该协议最小报文长度还小,那么丢弃它
if len(vv.First()) < transProto.MinimumPacketSize() {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
// 解析报文得到源端口和目的端口
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
if err != nil {
n.stack.stats.MalformedRcvdPackets.Increment()
return
}
log.Println("准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort)
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
// 调用分流器根据传输层协议和传输层id分发数据报文
if n.demux.deliverPacket(r, protocol, vv, id) {
return
}
if n.stack.demux.deliverPacket(r, protocol, vv, id) {
return
}
// Try to deliver to per-stack default handler.
if state.defaultHandler != nil {
if state.defaultHandler(r, id, vv) {
return
}
}
// We could not find an appropriate destination for this packet, so
// deliver it to the global handler.
if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
n.stack.stats.MalformedRcvdPackets.Increment()
}
}
// DeliverTransportControlPacket delivers control packets to the

View File

@@ -185,7 +185,7 @@ type NetworkEndpointID struct {
type TransportEndpointID struct {
LocalPort uint16
LocalAddress tcpip.Address
remotePort uint16
RemotePort uint16
RemoteAddress tcpip.Address
}

View File

@@ -116,12 +116,87 @@ func New(network []string, transport []string, opts Options) *Stack {
proto: transProto,
}
}
// 添加传输层分流器
// NOTE 添加协议栈全局传输层分流器
s.demux = newTransportDemuxer(s)
return s
}
// SetNetworkProtocolOption configures a protocol-level option on the given
// network protocol. It returns tcpip.ErrUnknownProtocol when the protocol is
// not registered with the stack; otherwise it returns whatever the protocol's
// SetOption reports (for example when the option is unsupported or the value
// is invalid).
func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
	if netProto, found := s.networkProtocols[network]; found {
		return netProto.SetOption(option)
	}
	return tcpip.ErrUnknownProtocol
}
// NetworkProtocolOption retrieves a protocol-level option value from the given
// network protocol. It returns tcpip.ErrUnknownProtocol when the protocol is
// not registered with the stack; otherwise it returns whatever the protocol's
// Option reports.
//
// Example:
//
//	var v ipv4.MyOption
//	err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
//	if err != nil {
//		...
//	}
func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
	if netProto, found := s.networkProtocols[network]; found {
		return netProto.Option(option)
	}
	return tcpip.ErrUnknownProtocol
}
// SetTransportProtocolOption configures a protocol-level option on the given
// transport protocol. It returns tcpip.ErrUnknownProtocol when the protocol is
// not registered with the stack; otherwise it returns whatever the protocol's
// SetOption reports.
func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
	if state, found := s.transportProtocols[transport]; found {
		return state.proto.SetOption(option)
	}
	return tcpip.ErrUnknownProtocol
}
// TransportProtocolOption retrieves a protocol-level option value from the
// given transport protocol. It returns tcpip.ErrUnknownProtocol when the
// protocol is not registered with the stack; otherwise it returns whatever the
// protocol's Option reports.
//
// Example:
//
//	var v tcp.SACKEnabled
//	if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
//		...
//	}
func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
	if state, found := s.transportProtocols[transport]; found {
		return state.proto.Option(option)
	}
	return tcpip.ErrUnknownProtocol
}
// SetTransportProtocolHandler installs h as the per-stack default handler for
// transport protocol p. Nothing happens when p is not registered.
//
// It must be called only while the stack is being initialised; changing the
// handler on a running stack is not supported.
func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
	if state := s.transportProtocols[p]; state != nil {
		state.defaultHandler = h
	}
}
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
func (s *Stack) NowNanoseconds() int64 {
return s.clock.NowNanoseconds()
}
func (s *Stack) Stats() tcpip.Stats {
return s.stats
}
@@ -260,19 +335,19 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi
return false, tcpip.ErrUnknownNICID
}
// 路由查找实现比如当tcp建立连接时会用该函数得到路由信息
// FindRoute 路由查找实现比如当tcp建立连接时会用该函数得到路由信息
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address,
netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
s.mu.RLock()
defer s.mu.RUnlock()
for i := range s.routeTable {
if (id != 0 && id != s.routeTable[i].NIC) ||
if (id != 0 && id != s.routeTable[i].NIC) || // 检查是否是对应的网卡
(len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
continue
}
nic := s.nics[s.routeTable[i].NIC]
nic := s.nics[s.routeTable[i].NIC] // 在协议栈里找到这张网卡
if nic == nil {
continue
}
@@ -372,14 +447,34 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
// 最终调用 demuxer.registerEndpoint 函数来实现注册。
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
// TODO 需要实现
return nil
log.Println("往", nicID, "网卡注册新的传输端")
if nicID == 0 {
return s.demux.registerEndpoint(netProtos, protocol, id, ep) // 给协议栈的所有网卡注册传输端
}
s.mu.RLock()
defer s.mu.RUnlock()
nic := s.nics[nicID]
if nic == nil {
return tcpip.ErrUnknownNICID
}
return nic.demux.registerEndpoint(netProtos, protocol, id, ep) // 给这张网卡注册传输端
}
// UnregisterTransportEndpoint removes the endpoint with the given id from a
// transport demuxer: the stack-wide one when nicID is 0, otherwise the demuxer
// of the NIC with that ID. An unknown non-zero nicID is ignored.
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
	protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
	if nicID != 0 {
		// The NIC table is read under the stack's read lock, which is held
		// until the per-NIC unregistration has finished.
		s.mu.RLock()
		defer s.mu.RUnlock()

		if nic := s.nics[nicID]; nic != nil {
			nic.demux.unregisterEndpoint(netProtos, protocol, id)
		}
		return
	}

	s.demux.unregisterEndpoint(netProtos, protocol, id)
}

View File

@@ -2,6 +2,7 @@ package stack
import (
"netstack/tcpip"
"netstack/tcpip/buffer"
"sync"
)
@@ -23,6 +24,112 @@ type transportDemuxer struct {
}
// 新建一个分流器
func newTransportDemuxer(stacl *Stack) *transportDemuxer {
func newTransportDemuxer(stack *Stack) *transportDemuxer {
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
for netProto := range stack.networkProtocols {
for tranProto := range stack.transportProtocols {
d.protocol[protocolIDs{network: netProto, transport: tranProto}] = &transportEndpoints{
endpoints: make(map[TransportEndpointID]TransportEndpoint),
}
}
}
return d
}
// registerEndpoint registers ep under id for every network protocol in
// netProtos so that packets matching id are delivered to it. If registration
// fails for one protocol, the registrations already made by this call are
// rolled back and that error is returned.
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber,
	protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
	for done := 0; done < len(netProtos); done++ {
		err := d.singleRegisterEndpoint(netProtos[done], protocol, id, ep)
		if err == nil {
			continue
		}
		// Undo the protocols registered before the failing one.
		d.unregisterEndpoint(netProtos[:done], protocol, id)
		return err
	}
	return nil
}
// singleRegisterEndpoint registers ep under id for one (network, transport)
// protocol pair. It returns tcpip.ErrPortInUse when id is already registered
// for that pair. When the demuxer has no endpoint set for the pair, nothing
// is registered and nil is returned.
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber,
	protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
	key := protocolIDs{network: netProto, transport: protocol} // e.g. IPv4:udp
	eps, known := d.protocol[key]
	if !known {
		return nil
	}

	eps.mu.Lock()
	defer eps.mu.Unlock()

	if _, taken := eps.endpoints[id]; taken {
		return tcpip.ErrPortInUse
	}
	eps.endpoints[id] = ep
	return nil
}
// unregisterEndpoint removes the endpoint registered under id for each network
// protocol in netProtos, so it no longer receives packets. Protocol pairs the
// demuxer does not know are skipped.
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber,
	protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
	for _, netProto := range netProtos {
		eps, known := d.protocol[protocolIDs{network: netProto, transport: protocol}]
		if !known {
			continue
		}
		eps.mu.Lock()
		delete(eps.endpoints, id)
		eps.mu.Unlock()
	}
}
// deliverPacket looks up the transport endpoint matching id for the route's
// network protocol and the given transport protocol and passes the packet to
// its HandlePacket. It returns false when the protocol pair is not known to
// the demuxer or no endpoint matches, and true once the packet was handed over.
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
	eps, known := d.protocol[protocolIDs{network: r.NetProto, transport: protocol}]
	if !known {
		return false
	}

	// Only the lookup runs under the read lock; the endpoint's handler is
	// called after the lock has been released.
	eps.mu.RLock()
	target := d.findEndpointLocked(eps, vv, id)
	eps.mu.RUnlock()

	if target != nil {
		target.HandlePacket(r, id, vv)
		return true
	}
	return false
}
// deliverControlPacket is currently a stub: it ignores all of its arguments
// and always returns false.
func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
return false
}
// findEndpointLocked returns the endpoint in eps that best matches id, or nil
// when there is none. The caller is expected to hold eps.mu.
//
// Candidates are tried from most to least specific:
//  1. the full id;
//  2. the id without the local address (endpoint bound to "any address");
//  3. the id without the remote address and port;
//  4. the local port only.
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints,
	vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
	candidates := [...]TransportEndpointID{
		id,
		{LocalPort: id.LocalPort, RemotePort: id.RemotePort, RemoteAddress: id.RemoteAddress},
		{LocalPort: id.LocalPort, LocalAddress: id.LocalAddress},
		{LocalPort: id.LocalPort},
	}
	for _, candidate := range candidates {
		if ep := eps.endpoints[candidate]; ep != nil {
			return ep
		}
	}
	return nil
}

View File

@@ -3,6 +3,7 @@
![img](https://doc.shiyanlou.com/document-uid949121labid10418timestamp1555488741384.png)
传输层是整个网络体系结构中的关键之一,我们很多编程都是直接和传输层打交道的,我们需要了解以下的概念:
1. 端口的意义 - 上一章已经介绍过了
2. 无连接 UDP 协议及特点 - 本章介绍
3. 面向连接 TCP 协议及特点 - 下章会介绍

View File

@@ -13,7 +13,13 @@ import (
// udp报文结构 当收到udp报文时 会用这个结构来保存udp报文数据
type udpPacket struct {
udpPacketEntry // 链表实现
// TODO 需要添加
senderAddress tcpip.FullAddress
data buffer.VectorisedView
timestamp int64
hasTimestamp bool
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
views [8]buffer.View
}
type endpointState int
@@ -40,7 +46,7 @@ type endpoint struct {
rcvBufSizeMax int
rcvBufSize int
rcvClosed bool
rcvTimestamp bool
rcvTimestamp bool // 通过SetSocket进行设置 是否开启时间戳
// The following fields are protected by the mu mutex.
mu sync.RWMutex
@@ -57,7 +63,7 @@ type endpoint struct {
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
// TODO
multicastMemberships []multicastMembership
// effectiveNetProtos contains the network protocols actually in use. In
// most cases it will only contain "netProto", but in cases like IPv6
@@ -68,6 +74,12 @@ type endpoint struct {
effectiveNetProtos []tcpip.NetworkProtocolNumber
}
// multicastMembership describes one multicast group membership: the
// multicast address and the ID of the NIC it applies to.
type multicastMembership struct {
nicID tcpip.NICID
multicastAddr tcpip.Address
}
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
waiterQueue *waiter.Queue) *endpoint {
log.Println("新建一个udp端")
@@ -76,8 +88,32 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
netProto: netProto,
waiterQueue: waiterQueue,
multicastTTL: 1,
rcvBufSizeMax: 32 * 1024,
sndBufSize: 32 * 1024}
rcvBufSizeMax: 32 * 1024, // 接收缓存 32k
sndBufSize: 32 * 1024, // 发送缓存 32k
}
}
// NewConnectedEndpoint creates a UDP endpoint that is already in the connected
// state for the provided route and id.
//
// The endpoint is registered with the stack on the route's NIC for the route's
// network protocol so that packets are routed to it. If registration fails the
// endpoint is closed and the error is returned.
func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID,
	waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
	ep := newEndpoint(stack, r.NetProto, waiterQueue)

	netProtos := []tcpip.NetworkProtocolNumber{r.NetProto}
	if err := stack.RegisterTransportEndpoint(r.NICID(), netProtos, ProtocolNumber, id, ep); err != nil {
		ep.Close()
		return nil, err
	}

	// Record the connection parameters on the endpoint.
	ep.id, ep.dstPort = id, id.RemotePort
	ep.route = r.Clone()
	ep.regNICID = r.NICID()
	ep.state = stateConnected

	return ep, nil
}
// Close UDP端的关闭释放相应的资源
@@ -98,8 +134,37 @@ func (e *endpoint) Close() {
e.mu.Unlock()
}
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
return nil, tcpip.ControlMessages{}, nil
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
e.rcvMu.Lock()
// 如果接收链表为空,即没有任何数据
if e.rcvList.Empty() {
err := tcpip.ErrWouldBlock
if e.rcvClosed {
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
return buffer.View{}, tcpip.ControlMessages{}, err
}
// 从接收链表中取出最前面的数据报,接着从链表中删除该数据报
// 然后减少接收缓存的大小
p := e.rcvList.Front()
e.rcvList.Remove(p)
e.rcvBufSize -= p.data.Size()
ts := e.rcvTimestamp
e.rcvMu.Unlock()
if ts && !p.hasTimestamp {
// Linux uses the current time.
p.timestamp = e.stack.NowNanoseconds()
}
if addr != nil {
// 赋值发送地址
*addr = p.senderAddress
}
return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
}
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
@@ -141,8 +206,95 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
return netProto, nil
}
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
log.Println("连接")
// Connect UDP中调用connect内核仅仅把对端ip&port记录下来. 这样在发送数据的时候无需再次指定
// UDP多次调用connect有两种用途:1,指定一个新的ip&port连结. 2,断开和之前的ip&port的连结
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// 目标端口为0是错误的
if addr.Port == 0 {
// We don't support connecting to port zero.
return tcpip.ErrInvalidEndpointState
}
e.mu.Lock()
defer e.mu.Unlock()
nicid := addr.NIC
var localPort uint16
// 判断UDP端的状态
switch e.state {
case stateInitial:
// 如果是初始状态,直接下一步
case stateBound, stateConnected:
localPort = e.id.LocalPort
log.Printf("绑定了 %d 的udp端 向[%d]网卡发起连接\n", localPort, nicid)
if e.bindNICID == 0 {
break
}
if nicid != 0 && nicid != e.bindNICID {
return tcpip.ErrInvalidEndpointState
}
nicid = e.bindNICID
default:
return tcpip.ErrInvalidEndpointState
}
// 检查地址的映射,得到相应的协议
netProto, err := e.checkV4Mapped(&addr, false)
if err != nil {
return err
}
// Find a route to the desired destination.
// 在全局协议栈中查找路由
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
if err != nil {
return err
}
defer r.Release()
// 新建一个传输端的标识包括源IP、源端口、目的IP、目的端口
id := stack.TransportEndpointID{
LocalAddress: r.LocalAddress,
LocalPort: localPort,
RemotePort: addr.Port,
RemoteAddress: r.RemoteAddress,
}
// 设置网络层协议IPV4或IPV6或两者都有
netProtos := []tcpip.NetworkProtocolNumber{netProto}
if netProto == header.IPv6ProtocolNumber && !e.v6only {
netProtos = []tcpip.NetworkProtocolNumber{
header.IPv4ProtocolNumber,
header.IPv6ProtocolNumber,
}
}
// 将该UDP端注册到协议栈中
id, err = e.registerWithStack(nicid, netProtos, id)
if err != nil {
return err
}
// Remove the old registration.
// 如果源端口不为0则尝试在传输层端中删除老的UDP端
if e.id.LocalPort != 0 {
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
}
log.Println(e.id, id)
// 赋值UDP端的属性
e.id = id
e.route = r.Clone()
e.dstPort = addr.Port
e.regNICID = nicid
e.effectiveNetProtos = netProtos
// 更改该UDP端的状态为已连接
e.state = stateConnected
// 标志该UDP端可以接收数据了
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
@@ -167,7 +319,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
}
id.LocalPort = port
}
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) // 往网卡注册一个绑定了端口的udp端
if err != nil {
// 释放端口
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
@@ -206,6 +358,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
LocalAddress: addr.Addr,
LocalPort: addr.Port,
}
log.Println("Bind", id)
// 在协议栈中注册该UDP端
id, err = e.registerWithStack(addr.NIC, netProtos, id)
if err != nil {
@@ -229,6 +382,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
// 标记状态为已绑定
e.state = stateBound
e.rcvMu.Lock()
e.rcvReady = true
e.rcvMu.Unlock()
return nil
}
@@ -271,9 +428,64 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
// HandlePacket is called when a UDP datagram arrives from the network layer.
//
// The endpoint keeps a receive list: the UDP header is trimmed off and the
// remaining data is queued until the user reads it (see Read). The amount of
// queued data is tracked in rcvBufSize; once it has reached rcvBufSizeMax,
// further datagrams are dropped.
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
// Get the header then trim it from the view.
hdr := header.UDP(vv.First())
if int(hdr.Length()) > vv.Size() {
// Malformed packet: the header claims more bytes than were received.
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
return
}
log.Println("udp 正式接收数据", hdr)
// Strip the UDP header so only the payload remains in vv.
vv.TrimFront(header.UDPMinimumSize)
e.rcvMu.Lock()
e.stack.Stats().UDP.PacketsReceived.Increment()
// Drop the packet if the endpoint is not ready to receive, has been closed
// for receiving, or its receive buffer is currently full.
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
e.rcvMu.Unlock()
log.Println("udp 接收缓存不足 丢弃报文")
return
}
// Remember whether the receive buffer was empty before this packet.
wasEmpty := e.rcvBufSize == 0
// Build the packet record that will be appended to the receive list.
pkt := &udpPacket{
senderAddress: tcpip.FullAddress{
NIC: r.NICID(),
Addr: id.RemoteAddress,
Port: hdr.SourcePort(),
},
}
// Copy the payload into the packet record.
// NOTE(review): pkt.views is passed to Clone as scratch space, presumably so
// that no extra allocation is needed when vv has at most 8 views — confirm
// against buffer.VectorisedView.Clone.
pkt.data = vv.Clone(pkt.views[:])
// Append to the receive list and account for the buffer space used.
e.rcvList.PushBack(pkt)
e.rcvBufSize += vv.Size()
if e.rcvTimestamp {
pkt.timestamp = e.stack.NowNanoseconds()
pkt.hasTimestamp = true
}
e.rcvMu.Unlock()
// TODO: notify the user layer that data can be read. The branch below is
// intentionally empty until that notification is implemented.
if wasEmpty {
}
}
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.

View File

@@ -38,9 +38,8 @@ func (*protocol) MinimumPacketSize() int {
// ParsePorts returns the source and destination ports stored in the given udp
// packet.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
//h := header.UDP(v)
//return h.SourcePort(), h.DestinationPort(), nil
return 0, 0, nil
h := header.UDP(v)
return h.SourcePort(), h.DestinationPort(), nil
}
// HandleUnknownDestinationPacket handles packets targeted at this protocol but