mirror of
https://github.com/impact-eintr/netstack.git
synced 2025-10-06 13:26:49 +08:00
udp通信的Connect 和 Read 结束 明天看Waiter 这相当于linux内核的事件驱动机制
当有某种事件就绪后通知waiter 监听着waiter的监听者就能通过waiter得知事件已经发生 从而不再阻塞
This commit is contained in:
@@ -12,36 +12,31 @@ import (
|
|||||||
"netstack/tcpip/link/tuntap"
|
"netstack/tcpip/link/tuntap"
|
||||||
"netstack/tcpip/network/arp"
|
"netstack/tcpip/network/arp"
|
||||||
"netstack/tcpip/network/ipv4"
|
"netstack/tcpip/network/ipv4"
|
||||||
"netstack/tcpip/network/ipv6"
|
|
||||||
"netstack/tcpip/stack"
|
"netstack/tcpip/stack"
|
||||||
"netstack/tcpip/transport/udp"
|
"netstack/tcpip/transport/udp"
|
||||||
"netstack/waiter"
|
"netstack/waiter"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var mac = flag.String("mac", "01:01:01:01:01:01", "mac address to use in tap device")
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
if len(flag.Args()) != 2 {
|
if len(flag.Args()) < 2 {
|
||||||
log.Fatal("Usage: ", os.Args[0], " <tap-device> <listen-address>")
|
log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask>")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.SetFlags(log.Lshortfile | log.LstdFlags)
|
log.SetFlags(log.Lshortfile | log.LstdFlags)
|
||||||
tapName := flag.Arg(0)
|
tapName := flag.Arg(0)
|
||||||
listeAddr := flag.Arg(1)
|
cidrName := flag.Arg(1)
|
||||||
|
|
||||||
log.Printf("tap: %v, listeAddr: %v", tapName, listeAddr)
|
log.Printf("tap: %v, cidrName: %v", tapName, cidrName)
|
||||||
|
|
||||||
// Parse the mac address.
|
parsedAddr, cidr, err := net.ParseCIDR(cidrName)
|
||||||
maddr, err := net.ParseMAC(*mac)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Bad MAC address: %v", *mac)
|
log.Fatalf("Bad cidr: %v", cidrName)
|
||||||
}
|
}
|
||||||
|
|
||||||
parsedAddr := net.ParseIP(listeAddr)
|
|
||||||
|
|
||||||
// 解析地址ip地址,ipv4或者ipv6地址都支持
|
// 解析地址ip地址,ipv4或者ipv6地址都支持
|
||||||
var addr tcpip.Address
|
var addr tcpip.Address
|
||||||
var proto tcpip.NetworkProtocolNumber
|
var proto tcpip.NetworkProtocolNumber
|
||||||
@@ -50,7 +45,7 @@ func main() {
|
|||||||
proto = ipv4.ProtocolNumber
|
proto = ipv4.ProtocolNumber
|
||||||
} else if parsedAddr.To16() != nil {
|
} else if parsedAddr.To16() != nil {
|
||||||
addr = tcpip.Address(parsedAddr.To16())
|
addr = tcpip.Address(parsedAddr.To16())
|
||||||
proto = ipv6.ProtocolNumber
|
//proto = ipv6.ProtocolNumber
|
||||||
} else {
|
} else {
|
||||||
log.Fatalf("Unknown IP type: %v", parsedAddr)
|
log.Fatalf("Unknown IP type: %v", parsedAddr)
|
||||||
}
|
}
|
||||||
@@ -69,17 +64,22 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 启动tap网卡
|
// 启动tap网卡
|
||||||
_ = tuntap.SetLinkUp(tapName)
|
tuntap.SetLinkUp(tapName)
|
||||||
// 设置tap网卡IP地址
|
// 设置路由
|
||||||
_ = tuntap.AddIP(tapName, listeAddr)
|
tuntap.SetRoute(tapName, cidr.String())
|
||||||
|
|
||||||
|
// 获取mac地址
|
||||||
|
mac, err := tuntap.GetHardwareAddr(tapName)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
// 抽象网卡的文件接口
|
// 抽象网卡的文件接口
|
||||||
linkID := fdbased.New(&fdbased.Options{
|
linkID := fdbased.New(&fdbased.Options{
|
||||||
FD: fd,
|
FD: fd,
|
||||||
MTU: 1500,
|
MTU: 1500,
|
||||||
Address: tcpip.LinkAddress(maddr),
|
Address: tcpip.LinkAddress(mac),
|
||||||
})
|
})
|
||||||
|
|
||||||
// 新建相关协议的协议栈
|
// 新建相关协议的协议栈
|
||||||
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
|
s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
|
||||||
[]string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{})
|
[]string{ /*tcp.ProtocolName, */ udp.ProtocolName}, stack.Options{})
|
||||||
@@ -109,37 +109,33 @@ func main() {
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
go func() {
|
||||||
TCPServer(conn, &RCV{s, nil, nil})
|
// 监听udp localPort端口
|
||||||
// 同时监听tcp和udp localPort端口
|
udpEp := udpListen(s, proto, 9999)
|
||||||
//tcpEp := tcpListen(s, proto, localPort)
|
|
||||||
//udpEp := udpListen(s, proto, localPort)
|
|
||||||
// 关闭监听服务,此时会释放端口
|
|
||||||
//tcpEp.Close()
|
|
||||||
//udpEp.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
//func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
|
for {
|
||||||
// var wq waiter.Queue
|
buf, _, err := udpEp.Read(nil)
|
||||||
// // 新建一个tcp端
|
if err != nil {
|
||||||
// ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
|
if err == tcpip.ErrWouldBlock {
|
||||||
// if err != nil {
|
time.Sleep(100 * time.Millisecond)
|
||||||
// log.Fatal(err)
|
log.Println("阻塞中")
|
||||||
// }
|
continue
|
||||||
//
|
}
|
||||||
// // 绑定IP和端口,这里的IP地址为空,表示绑定任何IP
|
}
|
||||||
// // 此时就会调用端口管理器
|
log.Println(buf)
|
||||||
// if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil {
|
break
|
||||||
// log.Fatal("Bind failed: ", err)
|
}
|
||||||
// }
|
// 关闭监听服务,此时会释放端口
|
||||||
//
|
udpEp.Close()
|
||||||
// // 开始监听
|
}()
|
||||||
// if err := ep.Listen(10); err != nil {
|
|
||||||
// log.Fatal("Listen failed: ", err)
|
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
||||||
// }
|
rcv := &RCV{
|
||||||
//
|
Stack: s,
|
||||||
// return ep
|
addr: tcpip.FullAddress{},
|
||||||
//}
|
}
|
||||||
|
TCPServer(conn, rcv)
|
||||||
|
}
|
||||||
|
|
||||||
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
|
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
|
||||||
var wq waiter.Queue
|
var wq waiter.Queue
|
||||||
@@ -156,10 +152,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
|
|||||||
log.Fatal("Bind failed: ", err)
|
log.Fatal("Bind failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ep.Connect(tcpip.FullAddress{NIC: 0, Addr: "", Port: uint16(localPort)}); err != nil {
|
|
||||||
log.Fatal("Conn failed: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注意UDP是无连接的,它不需要Listen
|
// 注意UDP是无连接的,它不需要Listen
|
||||||
return ep
|
return ep
|
||||||
}
|
}
|
||||||
@@ -167,6 +159,7 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
|
|||||||
type RCV struct {
|
type RCV struct {
|
||||||
*stack.Stack
|
*stack.Stack
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
addr tcpip.FullAddress
|
||||||
rcvBuf []byte
|
rcvBuf []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,6 +182,7 @@ func (r *RCV) Handle(conn net.Conn) {
|
|||||||
}
|
}
|
||||||
r.ep = ep
|
r.ep = ep
|
||||||
r.Bind()
|
r.Bind()
|
||||||
|
r.Connect()
|
||||||
r.Close()
|
r.Close()
|
||||||
case "tcp":
|
case "tcp":
|
||||||
default:
|
default:
|
||||||
@@ -202,12 +196,16 @@ func (r *RCV) Bind() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
||||||
addr := tcpip.FullAddress{
|
r.addr = tcpip.FullAddress{
|
||||||
NIC: 0,
|
NIC: 1,
|
||||||
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
||||||
Port: port,
|
Port: port,
|
||||||
}
|
}
|
||||||
r.ep.Bind(addr, nil)
|
r.ep.Bind(r.addr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RCV) Connect() {
|
||||||
|
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *RCV) Close() {
|
func (r *RCV) Close() {
|
||||||
|
42
cmd/udp_client/main.go
Normal file
42
cmd/udp_client/main.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var (
|
||||||
|
addr = flag.String("a", "192.168.1.1:9999", "udp dst address")
|
||||||
|
)
|
||||||
|
|
||||||
|
log.SetFlags(log.Lshortfile | log.LstdFlags)
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", *addr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Println("解析地址")
|
||||||
|
|
||||||
|
// 建立UDP连接(只是填息了目的IP和端口,并未真正的建立连接)
|
||||||
|
conn, err := net.DialUDP("udp", nil, udpAddr)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Println("TEST")
|
||||||
|
|
||||||
|
send := []byte("hello")
|
||||||
|
|
||||||
|
if _, err := conn.Write(send); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Printf("send: %s", string(send))
|
||||||
|
|
||||||
|
//recv := make([]byte, 10)
|
||||||
|
//rn, _, err := conn.ReadFrom(recv)
|
||||||
|
//if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
//}
|
||||||
|
//log.Printf("recv: %s", string(recv[:rn]))
|
||||||
|
}
|
@@ -2,7 +2,10 @@ package header
|
|||||||
|
|
||||||
import "netstack/tcpip"
|
import "netstack/tcpip"
|
||||||
|
|
||||||
// 校验和的计算
|
// Checksum 校验和的计算
|
||||||
|
// UDP 检验和的计算方法是: 按每 16 位求和得出一个 32 位的数;
|
||||||
|
// 如果这个 32 位的数,高 16 位不为 0,则高 16 位加低 16 位再得到一个 32 位的数;
|
||||||
|
// 重复第 2 步直到高 16 位为 0,将低 16 位取反,得到校验和。
|
||||||
func Checksum(buf []byte, initial uint16) uint16 {
|
func Checksum(buf []byte, initial uint16) uint16 {
|
||||||
v := uint32(initial)
|
v := uint32(initial)
|
||||||
|
|
||||||
|
@@ -1,6 +1,10 @@
|
|||||||
package header
|
package header
|
||||||
|
|
||||||
import "netstack/tcpip"
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"netstack/tcpip"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
udpSrcPort = 0
|
udpSrcPort = 0
|
||||||
@@ -36,3 +40,99 @@ const (
|
|||||||
// UDPProtocolNumber is UDP's transport protocol number.
|
// UDPProtocolNumber is UDP's transport protocol number.
|
||||||
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
|
UDPProtocolNumber tcpip.TransportProtocolNumber = 17
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/*
|
||||||
|
UDP 是 User Datagram Protocol 的简称,中文名是用户数据报协议。UDP 只在 IP 数据报服务上增加了一点功能,就是复用和分用的功能以及差错检测,UDP 主要的特点是:
|
||||||
|
|
||||||
|
1. UDP 是无连接的,即发送数据之前不需要建立连接,发送结束也不需要连接释放,因此减少了开销和发送数据之间的延时。
|
||||||
|
2. UDP 是不可靠传输,尽最大努力交付,因此不需要维护复杂的连接状态。
|
||||||
|
3. UDP 的数据报是有消息边界的,发送方发送一个报文,接收方就会完整的收到一个报文。
|
||||||
|
4. UDP 没有拥塞控制,网络出现阻塞,UDP 是无感知的,也就不会降低发送速度。
|
||||||
|
5. UDP 支持一对一,一对多,多对一,多对多的通信。
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
|source Port|destination Port|
|
||||||
|
| Length | UDP Checksum |
|
||||||
|
| Data |
|
||||||
|
*/
|
||||||
|
|
||||||
|
// SourcePort returns the "source port" field of the udp header.
|
||||||
|
func (b UDP) SourcePort() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(b[udpSrcPort:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// DestinationPort returns the "destination port" field of the udp header.
|
||||||
|
func (b UDP) DestinationPort() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(b[udpDstPort:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length returns the "length" field of the udp header.
|
||||||
|
func (b UDP) Length() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(b[udpLength:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Payload returns the data contained in the UDP datagram.
|
||||||
|
func (b UDP) Payload() []byte {
|
||||||
|
return b[UDPMinimumSize:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checksum returns the "checksum" field of the udp header.
|
||||||
|
func (b UDP) Checksum() uint16 {
|
||||||
|
return binary.BigEndian.Uint16(b[udpChecksum:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSourcePort sets the "source port" field of the udp header.
|
||||||
|
func (b UDP) SetSourcePort(port uint16) {
|
||||||
|
binary.BigEndian.PutUint16(b[udpSrcPort:], port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDestinationPort sets the "destination port" field of the udp header.
|
||||||
|
func (b UDP) SetDestinationPort(port uint16) {
|
||||||
|
binary.BigEndian.PutUint16(b[udpDstPort:], port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetChecksum sets the "checksum" field of the udp header.
|
||||||
|
func (b UDP) SetChecksum(checksum uint16) {
|
||||||
|
binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateChecksum calculates the checksum of the udp packet, given the total
|
||||||
|
// length of the packet and the checksum of the network-layer pseudo-header
|
||||||
|
// (excluding the total length) and the checksum of the payload.
|
||||||
|
func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
|
||||||
|
// Add the length portion of the checksum to the pseudo-checksum.
|
||||||
|
tmp := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(tmp, totalLen)
|
||||||
|
checksum := Checksum(tmp, partialChecksum)
|
||||||
|
|
||||||
|
// Calculate the rest of the checksum.
|
||||||
|
return Checksum(b[:UDPMinimumSize], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes all the fields of the udp header.
|
||||||
|
func (b UDP) Encode(u *UDPFields) {
|
||||||
|
binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
|
||||||
|
binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
|
||||||
|
binary.BigEndian.PutUint16(b[udpLength:], u.Length)
|
||||||
|
binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
var udpFmt string = `
|
||||||
|
|% 16s|% 16s|
|
||||||
|
|% 16s|% 16s|
|
||||||
|
%v
|
||||||
|
`
|
||||||
|
|
||||||
|
func (b UDP) String() string {
|
||||||
|
for i := range b.Payload() {
|
||||||
|
if i != int(b.Length()-8-1) && b.Payload()[i]^b.Payload()[i+1] != 0 {
|
||||||
|
return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()),
|
||||||
|
atoi(b.Length()), atoi(b.Checksum()),
|
||||||
|
b.Payload())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Sprintf(udpFmt, atoi(b.SourcePort()), atoi(b.DestinationPort()),
|
||||||
|
atoi(b.Length()), atoi(b.Checksum()),
|
||||||
|
fmt.Sprintf("%v x %d", b.Payload()[0], b.Length()-8))
|
||||||
|
}
|
||||||
|
@@ -60,7 +60,7 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
|
|||||||
id: id,
|
id: id,
|
||||||
name: name,
|
name: name,
|
||||||
linkEP: ep,
|
linkEP: ep,
|
||||||
demux: nil, // TODO 需要处理
|
demux: newTransportDemuxer(stack), // NOTE 注册网卡自己的传输层分流器
|
||||||
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
|
primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List),
|
||||||
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
|
endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
|
||||||
}
|
}
|
||||||
@@ -302,6 +302,75 @@ func (n *NIC) Subnets() []tcpip.Subnet {
|
|||||||
return append(sns, n.subnets...)
|
return append(sns, n.subnets...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。
|
||||||
|
// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。
|
||||||
|
// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它,
|
||||||
|
// 当前实现的网络层协议有 arp、ipv4 和 ipv6。
|
||||||
|
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
|
||||||
|
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
|
||||||
|
netProto, ok := n.stack.networkProtocols[protocol]
|
||||||
|
if !ok {
|
||||||
|
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
|
||||||
|
n.stack.stats.IP.PacketsReceived.Increment()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(vv.First()) < netProto.MinimumPacketSize() {
|
||||||
|
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
src, dst := netProto.ParseAddresses(vv.First())
|
||||||
|
log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte {
|
||||||
|
if len(vv.ToView()) > 64 {
|
||||||
|
return vv.ToView()[:64]
|
||||||
|
}
|
||||||
|
return vv.ToView()
|
||||||
|
}())
|
||||||
|
// 根据网络协议和数据包的目的地址,找到网络端
|
||||||
|
// 然后将数据包分发给网络层
|
||||||
|
if ref := n.getRef(protocol, dst); ref != nil {
|
||||||
|
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
|
||||||
|
r.RemoteLinkAddress = remoteLinkAddr
|
||||||
|
ref.ep.HandlePacket(&r, vv)
|
||||||
|
ref.decRef()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if n.stack.Forwarding() {
|
||||||
|
r, err := n.stack.FindRoute(0, "", dst, protocol)
|
||||||
|
if err != nil {
|
||||||
|
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer r.Release()
|
||||||
|
|
||||||
|
r.LocalLinkAddress = n.linkEP.LinkAddress()
|
||||||
|
r.RemoteLinkAddress = remoteLinkAddr
|
||||||
|
|
||||||
|
// Found a NIC.
|
||||||
|
n := r.ref.nic
|
||||||
|
n.mu.RLock()
|
||||||
|
ref, ok := n.endpoints[NetworkEndpointID{dst}]
|
||||||
|
n.mu.RUnlock()
|
||||||
|
if ok && ref.tryIncRef() {
|
||||||
|
ref.ep.HandlePacket(&r, vv)
|
||||||
|
ref.decRef()
|
||||||
|
} else {
|
||||||
|
// n doesn't have a destination endpoint.
|
||||||
|
// Send the packet out of n.
|
||||||
|
hdr := buffer.NewPrependableFromView(vv.First())
|
||||||
|
vv.RemoveFirst()
|
||||||
|
n.linkEP.WritePacket(&r, hdr, vv, protocol)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
||||||
|
}
|
||||||
|
|
||||||
// 根据协议类型和目标地址,找出关联的Endpoint
|
// 根据协议类型和目标地址,找出关联的Endpoint
|
||||||
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
|
func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint {
|
||||||
id := NetworkEndpointID{dst}
|
id := NetworkEndpointID{dst}
|
||||||
@@ -344,57 +413,49 @@ func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *r
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeliverNetworkPacket 当 NIC 从物理接口接收数据包时,将调用函数 DeliverNetworkPacket,用来分发网络层数据包。
|
|
||||||
// 比如 protocol 是 arp 协议号,那么会找到arp.HandlePacket来处理数据报。
|
|
||||||
// 简单来说就是根据网络层协议和目的地址来找到相应的网络层端,将网络层数据发给它,
|
|
||||||
// 当前实现的网络层协议有 arp、ipv4 和 ipv6。
|
|
||||||
func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr, localLinkAddr tcpip.LinkAddress,
|
|
||||||
protocol tcpip.NetworkProtocolNumber, vv buffer.VectorisedView) {
|
|
||||||
netProto, ok := n.stack.networkProtocols[protocol]
|
|
||||||
if !ok {
|
|
||||||
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if netProto.Number() == header.IPv4ProtocolNumber || netProto.Number() == header.IPv6ProtocolNumber {
|
|
||||||
n.stack.stats.IP.PacketsReceived.Increment()
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(vv.First()) < netProto.MinimumPacketSize() {
|
|
||||||
n.stack.stats.MalformedRcvdPackets.Increment()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
src, dst := netProto.ParseAddresses(vv.First())
|
|
||||||
log.Printf("设备[%v]准备从 [%s] 向 [%s] 分发数据: %v\n", linkEP.LinkAddress(), src, dst, func() []byte {
|
|
||||||
if len(vv.ToView()) > 64 {
|
|
||||||
return vv.ToView()[:64]
|
|
||||||
}
|
|
||||||
return vv.ToView()
|
|
||||||
}())
|
|
||||||
// 根据网络协议和数据包的目的地址,找到网络端
|
|
||||||
// 然后将数据包分发给网络层
|
|
||||||
if ref := n.getRef(protocol, dst); ref != nil {
|
|
||||||
r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref)
|
|
||||||
r.RemoteLinkAddress = remoteLinkAddr
|
|
||||||
ref.ep.HandlePacket(&r, vv)
|
|
||||||
ref.decRef()
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
n.stack.stats.IP.InvalidAddressesReceived.Increment()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeliverTransportPacket delivers packets to the appropriate
|
// DeliverTransportPacket delivers packets to the appropriate
|
||||||
// transport protocol endpoint.
|
// transport protocol endpoint.
|
||||||
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
|
func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView) {
|
||||||
// 先查找协议栈是否注册了该传输层协议
|
// 先查找协议栈是否注册了该传输层协议
|
||||||
_, ok := n.stack.transportProtocols[protocol]
|
state, ok := n.stack.transportProtocols[protocol]
|
||||||
if !ok {
|
if !ok {
|
||||||
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
n.stack.stats.UnknownProtocolRcvdPackets.Increment()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("准备分发传输层数据报", n.stack.transportProtocols)
|
transProto := state.proto
|
||||||
|
// 如果报文长度比该协议最小报文长度还小,那么丢弃它
|
||||||
|
if len(vv.First()) < transProto.MinimumPacketSize() {
|
||||||
|
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 解析报文得到源端口和目的端口
|
||||||
|
srcPort, dstPort, err := transProto.ParsePorts(vv.First())
|
||||||
|
if err != nil {
|
||||||
|
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("准备分发传输层数据报", n.stack.transportProtocols, srcPort, dstPort)
|
||||||
|
id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
|
||||||
|
// 调用分流器,根据传输层协议和传输层id分发数据报文
|
||||||
|
if n.demux.deliverPacket(r, protocol, vv, id) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n.stack.demux.deliverPacket(r, protocol, vv, id) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to deliver to per-stack default handler.
|
||||||
|
if state.defaultHandler != nil {
|
||||||
|
if state.defaultHandler(r, id, vv) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// We could not find an appropriate destination for this packet, so
|
||||||
|
// deliver it to the global handler.
|
||||||
|
if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
|
||||||
|
n.stack.stats.MalformedRcvdPackets.Increment()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeliverTransportControlPacket delivers control packets to the
|
// DeliverTransportControlPacket delivers control packets to the
|
||||||
|
@@ -185,7 +185,7 @@ type NetworkEndpointID struct {
|
|||||||
type TransportEndpointID struct {
|
type TransportEndpointID struct {
|
||||||
LocalPort uint16
|
LocalPort uint16
|
||||||
LocalAddress tcpip.Address
|
LocalAddress tcpip.Address
|
||||||
remotePort uint16
|
RemotePort uint16
|
||||||
RemoteAddress tcpip.Address
|
RemoteAddress tcpip.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -116,12 +116,87 @@ func New(network []string, transport []string, opts Options) *Stack {
|
|||||||
proto: transProto,
|
proto: transProto,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 添加传输层分流器
|
// NOTE 添加协议栈全局传输层分流器
|
||||||
s.demux = newTransportDemuxer(s)
|
s.demux = newTransportDemuxer(s)
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetNetworkProtocolOption allows configuring individual protocol level
|
||||||
|
// options. This method returns an error if the protocol is not supported or
|
||||||
|
// option is not supported by the protocol implementation or the provided value
|
||||||
|
// is incorrect.
|
||||||
|
func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
|
||||||
|
netProto, ok := s.networkProtocols[network]
|
||||||
|
if !ok {
|
||||||
|
return tcpip.ErrUnknownProtocol
|
||||||
|
}
|
||||||
|
return netProto.SetOption(option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NetworkProtocolOption allows retrieving individual protocol level option
|
||||||
|
// values. This method returns an error if the protocol is not supported or
|
||||||
|
// option is not supported by the protocol implementation.
|
||||||
|
// e.g.
|
||||||
|
// var v ipv4.MyOption
|
||||||
|
// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
|
||||||
|
//
|
||||||
|
// if err != nil {
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
|
||||||
|
netProto, ok := s.networkProtocols[network]
|
||||||
|
if !ok {
|
||||||
|
return tcpip.ErrUnknownProtocol
|
||||||
|
}
|
||||||
|
return netProto.Option(option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTransportProtocolOption allows configuring individual protocol level
|
||||||
|
// options. This method returns an error if the protocol is not supported or
|
||||||
|
// option is not supported by the protocol implementation or the provided value
|
||||||
|
// is incorrect.
|
||||||
|
func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
|
||||||
|
transProtoState, ok := s.transportProtocols[transport]
|
||||||
|
if !ok {
|
||||||
|
return tcpip.ErrUnknownProtocol
|
||||||
|
}
|
||||||
|
return transProtoState.proto.SetOption(option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TransportProtocolOption allows retrieving individual protocol level option
|
||||||
|
// values. This method returns an error if the protocol is not supported or
|
||||||
|
// option is not supported by the protocol implementation.
|
||||||
|
// var v tcp.SACKEnabled
|
||||||
|
//
|
||||||
|
// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
|
||||||
|
// ...
|
||||||
|
// }
|
||||||
|
func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
|
||||||
|
transProtoState, ok := s.transportProtocols[transport]
|
||||||
|
if !ok {
|
||||||
|
return tcpip.ErrUnknownProtocol
|
||||||
|
}
|
||||||
|
return transProtoState.proto.Option(option)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTransportProtocolHandler sets the per-stack default handler for the given
|
||||||
|
// protocol.
|
||||||
|
//
|
||||||
|
// It must be called only during initialization of the stack. Changing it as the
|
||||||
|
// stack is operating is not supported.
|
||||||
|
func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, buffer.VectorisedView) bool) {
|
||||||
|
state := s.transportProtocols[p]
|
||||||
|
if state != nil {
|
||||||
|
state.defaultHandler = h
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NowNanoseconds implements tcpip.Clock.NowNanoseconds.
|
||||||
|
func (s *Stack) NowNanoseconds() int64 {
|
||||||
|
return s.clock.NowNanoseconds()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Stack) Stats() tcpip.Stats {
|
func (s *Stack) Stats() tcpip.Stats {
|
||||||
return s.stats
|
return s.stats
|
||||||
}
|
}
|
||||||
@@ -260,19 +335,19 @@ func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpi
|
|||||||
return false, tcpip.ErrUnknownNICID
|
return false, tcpip.ErrUnknownNICID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息
|
// FindRoute 路由查找实现,比如当tcp建立连接时,会用该函数得到路由信息
|
||||||
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address,
|
func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address,
|
||||||
netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
|
netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
for i := range s.routeTable {
|
for i := range s.routeTable {
|
||||||
if (id != 0 && id != s.routeTable[i].NIC) ||
|
if (id != 0 && id != s.routeTable[i].NIC) || // 检查是否是对应的网卡
|
||||||
(len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
|
(len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
nic := s.nics[s.routeTable[i].NIC]
|
nic := s.nics[s.routeTable[i].NIC] // 在协议栈里找到这张网卡
|
||||||
if nic == nil {
|
if nic == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -372,14 +447,34 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.
|
|||||||
// 最终调用 demuxer.registerEndpoint 函数来实现注册。
|
// 最终调用 demuxer.registerEndpoint 函数来实现注册。
|
||||||
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
||||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||||
// TODO 需要实现
|
log.Println("往", nicID, "网卡注册新的传输端")
|
||||||
return nil
|
if nicID == 0 {
|
||||||
|
return s.demux.registerEndpoint(netProtos, protocol, id, ep) // 给协议栈的所有网卡注册传输端
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
nic := s.nics[nicID]
|
||||||
|
if nic == nil {
|
||||||
|
return tcpip.ErrUnknownNICID
|
||||||
|
}
|
||||||
|
return nic.demux.registerEndpoint(netProtos, protocol, id, ep) // 给这张网卡注册传输端
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnregisterTransportEndpoint removes the endpoint with the given id from the
|
// UnregisterTransportEndpoint removes the endpoint with the given id from the
|
||||||
// stack transport dispatcher.
|
// stack transport dispatcher.
|
||||||
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber,
|
||||||
protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
||||||
|
if nicID == 0 {
|
||||||
|
s.demux.unregisterEndpoint(netProtos, protocol, id)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
nic := s.nics[nicID]
|
||||||
|
if nic != nil {
|
||||||
|
nic.demux.unregisterEndpoint(netProtos, protocol, id)
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2,6 +2,7 @@ package stack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"netstack/tcpip"
|
"netstack/tcpip"
|
||||||
|
"netstack/tcpip/buffer"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,6 +24,112 @@ type transportDemuxer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 新建一个分流器
|
// 新建一个分流器
|
||||||
func newTransportDemuxer(stacl *Stack) *transportDemuxer {
|
func newTransportDemuxer(stack *Stack) *transportDemuxer {
|
||||||
|
d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
|
||||||
|
|
||||||
|
for netProto := range stack.networkProtocols {
|
||||||
|
for tranProto := range stack.transportProtocols {
|
||||||
|
d.protocol[protocolIDs{network: netProto, transport: tranProto}] = &transportEndpoints{
|
||||||
|
endpoints: make(map[TransportEndpointID]TransportEndpoint),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerEndpoint 向分发器注册给定端点,以便将与端点ID匹配的数据包传递给它
|
||||||
|
func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber,
|
||||||
|
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||||
|
for i, n := range netProtos {
|
||||||
|
if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
|
||||||
|
d.unregisterEndpoint(netProtos[:i], protocol, id) // 把刚才注册的注销掉
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber,
|
||||||
|
protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
|
||||||
|
eps, ok := d.protocol[protocolIDs{netProto, protocol}] // IPv4:udp
|
||||||
|
if !ok { // 未曾注册过这个传输端集合
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
eps.mu.Lock()
|
||||||
|
defer eps.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := eps.endpoints[id]; ok { // 遍历传输端集合
|
||||||
|
return tcpip.ErrPortInUse
|
||||||
|
}
|
||||||
|
eps.endpoints[id] = ep
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// unregisterEndpoint 使用给定的id注销端点,使其不再接收任何数据包
|
||||||
|
func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber,
|
||||||
|
protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
|
||||||
|
for _, n := range netProtos {
|
||||||
|
if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
|
||||||
|
eps.mu.Lock()
|
||||||
|
delete(eps.endpoints, id)
|
||||||
|
eps.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据传输层的id来找到对应的传输端,再将数据包交给这个传输端处理
|
||||||
|
func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv buffer.VectorisedView, id TransportEndpointID) bool {
|
||||||
|
// 先看看分流器里有没有注册相关协议端,如果没有则返回false
|
||||||
|
eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// 从 eps 中找符合 id 的传输端
|
||||||
|
eps.mu.RLock()
|
||||||
|
ep := d.findEndpointLocked(eps, vv, id)
|
||||||
|
eps.mu.RUnlock()
|
||||||
|
|
||||||
|
if ep == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deliver the packet
|
||||||
|
ep.HandlePacket(r, id, vv)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber,
|
||||||
|
trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据传输层id来找到相应的传输层端
|
||||||
|
func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints,
|
||||||
|
vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
|
||||||
|
if ep := eps.endpoints[id]; ep != nil { // IPv4:udp
|
||||||
|
return ep
|
||||||
|
}
|
||||||
|
// Try to find a match with the id minus the local address.
|
||||||
|
nid := id
|
||||||
|
// 如果上面的 endpoints 没有找到,那么去掉本地ip地址,看看有没有相应的传输层端
|
||||||
|
// 因为有时候传输层监听的时候没有绑定本地ip,也就是 any address,此时的 LocalAddress
|
||||||
|
// 为空。
|
||||||
|
nid.LocalAddress = ""
|
||||||
|
if ep := eps.endpoints[nid]; ep != nil {
|
||||||
|
return ep
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find a match with the id minus the remote part.
|
||||||
|
nid.LocalAddress = id.LocalAddress
|
||||||
|
nid.RemoteAddress = ""
|
||||||
|
nid.RemotePort = 0
|
||||||
|
if ep := eps.endpoints[nid]; ep != nil {
|
||||||
|
return ep
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to find a match with only the local port.
|
||||||
|
nid.LocalAddress = ""
|
||||||
|
return eps.endpoints[nid]
|
||||||
|
}
|
||||||
|
@@ -3,6 +3,7 @@
|
|||||||

|

|
||||||
|
|
||||||
传输层是整个网络体系结构中的关键之一,我们很多编程都是直接和传输层打交道的,我们需要了解以下的概念:
|
传输层是整个网络体系结构中的关键之一,我们很多编程都是直接和传输层打交道的,我们需要了解以下的概念:
|
||||||
|
|
||||||
1. 端口的意义 - 上一章已经介绍过了
|
1. 端口的意义 - 上一章已经介绍过了
|
||||||
2. 无连接 UDP 协议及特点 - 本章介绍
|
2. 无连接 UDP 协议及特点 - 本章介绍
|
||||||
3. 面向连接 TCP 协议及特点 - 下章会介绍
|
3. 面向连接 TCP 协议及特点 - 下章会介绍
|
||||||
@@ -15,4 +16,4 @@
|
|||||||
|
|
||||||
3. 报文差错检测 网络层只对 IP 首部进行差错检测,而传输层对整个报文进行差错检测。
|
3. 报文差错检测 网络层只对 IP 首部进行差错检测,而传输层对整个报文进行差错检测。
|
||||||
|
|
||||||
4. 提供不可靠和可靠通信 网络层只提供了不可靠通信,而在传输层的 TCP 协议提供了可靠通信。
|
4. 提供不可靠和可靠通信 网络层只提供了不可靠通信,而在传输层的 TCP 协议提供了可靠通信。
|
||||||
|
@@ -13,7 +13,13 @@ import (
|
|||||||
// udp报文结构 当收到udp报文时 会用这个结构来保存udp报文数据
|
// udp报文结构 当收到udp报文时 会用这个结构来保存udp报文数据
|
||||||
type udpPacket struct {
|
type udpPacket struct {
|
||||||
udpPacketEntry // 链表实现
|
udpPacketEntry // 链表实现
|
||||||
// TODO 需要添加
|
senderAddress tcpip.FullAddress
|
||||||
|
data buffer.VectorisedView
|
||||||
|
timestamp int64
|
||||||
|
hasTimestamp bool
|
||||||
|
// views is used as buffer for data when its length is large
|
||||||
|
// enough to store a VectorisedView.
|
||||||
|
views [8]buffer.View
|
||||||
}
|
}
|
||||||
|
|
||||||
type endpointState int
|
type endpointState int
|
||||||
@@ -40,7 +46,7 @@ type endpoint struct {
|
|||||||
rcvBufSizeMax int
|
rcvBufSizeMax int
|
||||||
rcvBufSize int
|
rcvBufSize int
|
||||||
rcvClosed bool
|
rcvClosed bool
|
||||||
rcvTimestamp bool
|
rcvTimestamp bool // 通过SetSocket进行设置 是否开启时间戳
|
||||||
|
|
||||||
// The following fields are protected by the mu mutex.
|
// The following fields are protected by the mu mutex.
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
@@ -57,7 +63,7 @@ type endpoint struct {
|
|||||||
// shutdownFlags represent the current shutdown state of the endpoint.
|
// shutdownFlags represent the current shutdown state of the endpoint.
|
||||||
shutdownFlags tcpip.ShutdownFlags
|
shutdownFlags tcpip.ShutdownFlags
|
||||||
|
|
||||||
// TODO
|
multicastMemberships []multicastMembership
|
||||||
|
|
||||||
// effectiveNetProtos contains the network protocols actually in use. In
|
// effectiveNetProtos contains the network protocols actually in use. In
|
||||||
// most cases it will only contain "netProto", but in cases like IPv6
|
// most cases it will only contain "netProto", but in cases like IPv6
|
||||||
@@ -68,6 +74,12 @@ type endpoint struct {
|
|||||||
effectiveNetProtos []tcpip.NetworkProtocolNumber
|
effectiveNetProtos []tcpip.NetworkProtocolNumber
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 多播的成员关系,包括多播地址和网卡ID
|
||||||
|
type multicastMembership struct {
|
||||||
|
nicID tcpip.NICID
|
||||||
|
multicastAddr tcpip.Address
|
||||||
|
}
|
||||||
|
|
||||||
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
|
func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
|
||||||
waiterQueue *waiter.Queue) *endpoint {
|
waiterQueue *waiter.Queue) *endpoint {
|
||||||
log.Println("新建一个udp端")
|
log.Println("新建一个udp端")
|
||||||
@@ -76,8 +88,32 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber,
|
|||||||
netProto: netProto,
|
netProto: netProto,
|
||||||
waiterQueue: waiterQueue,
|
waiterQueue: waiterQueue,
|
||||||
multicastTTL: 1,
|
multicastTTL: 1,
|
||||||
rcvBufSizeMax: 32 * 1024,
|
rcvBufSizeMax: 32 * 1024, // 接收缓存 32k
|
||||||
sndBufSize: 32 * 1024}
|
sndBufSize: 32 * 1024, // 发送缓存 32k
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewConnectedEndpoint creates a new endpoint in the connected state using the
|
||||||
|
// provided route.
|
||||||
|
func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID,
|
||||||
|
waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
|
||||||
|
ep := newEndpoint(stack, r.NetProto, waiterQueue)
|
||||||
|
|
||||||
|
// Register new endpoint so that packets are routed to it.
|
||||||
|
if err := stack.RegisterTransportEndpoint(r.NICID(),
|
||||||
|
[]tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
|
||||||
|
ep.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.id = id
|
||||||
|
ep.route = r.Clone()
|
||||||
|
ep.dstPort = id.RemotePort
|
||||||
|
ep.regNICID = r.NICID()
|
||||||
|
|
||||||
|
ep.state = stateConnected
|
||||||
|
|
||||||
|
return ep, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close UDP端的关闭,释放相应的资源
|
// Close UDP端的关闭,释放相应的资源
|
||||||
@@ -98,8 +134,37 @@ func (e *endpoint) Close() {
|
|||||||
e.mu.Unlock()
|
e.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
|
||||||
return nil, tcpip.ControlMessages{}, nil
|
e.rcvMu.Lock()
|
||||||
|
|
||||||
|
// 如果接收链表为空,即没有任何数据
|
||||||
|
if e.rcvList.Empty() {
|
||||||
|
err := tcpip.ErrWouldBlock
|
||||||
|
if e.rcvClosed {
|
||||||
|
err = tcpip.ErrClosedForReceive
|
||||||
|
}
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
return buffer.View{}, tcpip.ControlMessages{}, err
|
||||||
|
}
|
||||||
|
// 从接收链表中取出最前面的数据报,接着从链表中删除该数据报
|
||||||
|
// 然后减少接收缓存的大小
|
||||||
|
p := e.rcvList.Front()
|
||||||
|
e.rcvList.Remove(p)
|
||||||
|
e.rcvBufSize -= p.data.Size()
|
||||||
|
ts := e.rcvTimestamp
|
||||||
|
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
|
||||||
|
if ts && !p.hasTimestamp {
|
||||||
|
// Linux uses the current time.
|
||||||
|
p.timestamp = e.stack.NowNanoseconds()
|
||||||
|
}
|
||||||
|
if addr != nil {
|
||||||
|
// 赋值发送地址
|
||||||
|
*addr = p.senderAddress
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: ts, Timestamp: p.timestamp}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
|
func (e *endpoint) Write(tcpip.Payload, tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
|
||||||
@@ -141,8 +206,95 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t
|
|||||||
return netProto, nil
|
return netProto, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpoint) Connect(address tcpip.FullAddress) *tcpip.Error {
|
// Connect UDP中调用connect内核仅仅把对端ip&port记录下来. 这样在发送数据的时候无需再次指定
|
||||||
log.Println("连接")
|
// UDP多次调用connect有两种用途:1,指定一个新的ip&port连结. 2,断开和之前的ip&port的连结
|
||||||
|
func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
|
||||||
|
// 目标端口为0是错误的
|
||||||
|
if addr.Port == 0 {
|
||||||
|
// We don't support connecting to port zero.
|
||||||
|
return tcpip.ErrInvalidEndpointState
|
||||||
|
}
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
|
||||||
|
nicid := addr.NIC
|
||||||
|
var localPort uint16
|
||||||
|
// 判断UDP端的状态
|
||||||
|
switch e.state {
|
||||||
|
case stateInitial:
|
||||||
|
// 如果是初始状态,直接下一步
|
||||||
|
case stateBound, stateConnected:
|
||||||
|
localPort = e.id.LocalPort
|
||||||
|
log.Printf("绑定了 %d 的udp端 向[%d]网卡发起连接\n", localPort, nicid)
|
||||||
|
if e.bindNICID == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if nicid != 0 && nicid != e.bindNICID {
|
||||||
|
return tcpip.ErrInvalidEndpointState
|
||||||
|
}
|
||||||
|
nicid = e.bindNICID
|
||||||
|
default:
|
||||||
|
return tcpip.ErrInvalidEndpointState
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查地址的映射,得到相应的协议
|
||||||
|
netProto, err := e.checkV4Mapped(&addr, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Find a route to the desired destination.
|
||||||
|
// 在全局协议栈中查找路由
|
||||||
|
r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer r.Release()
|
||||||
|
|
||||||
|
// 新建一个传输端的标识,包括源IP、源端口、目的IP、目的端口
|
||||||
|
id := stack.TransportEndpointID{
|
||||||
|
LocalAddress: r.LocalAddress,
|
||||||
|
LocalPort: localPort,
|
||||||
|
RemotePort: addr.Port,
|
||||||
|
RemoteAddress: r.RemoteAddress,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置网络层协议,IPV4或IPV6,或两者都有
|
||||||
|
netProtos := []tcpip.NetworkProtocolNumber{netProto}
|
||||||
|
if netProto == header.IPv6ProtocolNumber && !e.v6only {
|
||||||
|
netProtos = []tcpip.NetworkProtocolNumber{
|
||||||
|
header.IPv4ProtocolNumber,
|
||||||
|
header.IPv6ProtocolNumber,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将该UDP端注册到协议栈中
|
||||||
|
id, err = e.registerWithStack(nicid, netProtos, id)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Remove the old registration.
|
||||||
|
// 如果源端口不为0,则尝试在传输层端中删除老的UDP端
|
||||||
|
if e.id.LocalPort != 0 {
|
||||||
|
e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println(e.id, id)
|
||||||
|
|
||||||
|
// 赋值UDP端的属性
|
||||||
|
e.id = id
|
||||||
|
e.route = r.Clone()
|
||||||
|
e.dstPort = addr.Port
|
||||||
|
e.regNICID = nicid
|
||||||
|
e.effectiveNetProtos = netProtos
|
||||||
|
|
||||||
|
// 更改该UDP端的状态为已连接
|
||||||
|
e.state = stateConnected
|
||||||
|
|
||||||
|
// 标志该UDP端可以接收数据了
|
||||||
|
e.rcvMu.Lock()
|
||||||
|
e.rcvReady = true
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +319,7 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ
|
|||||||
}
|
}
|
||||||
id.LocalPort = port
|
id.LocalPort = port
|
||||||
}
|
}
|
||||||
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
|
err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) // 往网卡注册一个绑定了端口的udp端
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 释放端口
|
// 释放端口
|
||||||
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
|
e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
|
||||||
@@ -206,6 +358,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
|
|||||||
LocalAddress: addr.Addr,
|
LocalAddress: addr.Addr,
|
||||||
LocalPort: addr.Port,
|
LocalPort: addr.Port,
|
||||||
}
|
}
|
||||||
|
log.Println("Bind", id)
|
||||||
// 在协议栈中注册该UDP端
|
// 在协议栈中注册该UDP端
|
||||||
id, err = e.registerWithStack(addr.NIC, netProtos, id)
|
id, err = e.registerWithStack(addr.NIC, netProtos, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -229,6 +382,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error
|
|||||||
// 标记状态为已绑定
|
// 标记状态为已绑定
|
||||||
e.state = stateBound
|
e.state = stateBound
|
||||||
|
|
||||||
|
e.rcvMu.Lock()
|
||||||
|
e.rcvReady = true
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -271,9 +428,64 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从网络层接收到UDP数据报时的处理函数
|
// HandlePacket 从网络层接收到UDP数据报时的处理函数
|
||||||
|
// 首先 UDP 端有接收队列的概念,不像网络层接收到数据包立马发送给传输层,
|
||||||
|
// 对于协议栈来说,传输层是最后的一站,接下来的数据就需要交给用户层了,
|
||||||
|
// 但是用户层的行为是不可预知的,不知道用户层何时将数据取走(也就是 UDP Read 过程),
|
||||||
|
// 那么协议栈就实现一个接收队列,将接收的数据去掉 UDP 头部后保存在这个队列中,用户层需要的时候取走就可以了,
|
||||||
|
// 但是队列存数据量是有限制的,这个限制叫接收缓存大小,当接收队列中的数据总和超过这个缓存,那么接下来的这些报文将会被直接丢包。
|
||||||
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
|
func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
|
||||||
|
// Get the header then trim it from the view.
|
||||||
|
hdr := header.UDP(vv.First())
|
||||||
|
if int(hdr.Length()) > vv.Size() {
|
||||||
|
// Malformed packet.
|
||||||
|
// 错误报文
|
||||||
|
e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("udp 正式接收数据", hdr)
|
||||||
|
// 去除UDP首部
|
||||||
|
vv.TrimFront(header.UDPMinimumSize)
|
||||||
|
|
||||||
|
e.rcvMu.Lock()
|
||||||
|
e.stack.Stats().UDP.PacketsReceived.Increment()
|
||||||
|
|
||||||
|
// Drop the packet if our buffer is currently full.
|
||||||
|
// 如果UDP的接收缓存已经满了,那么丢弃报文。
|
||||||
|
if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
|
||||||
|
e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
log.Println("udp 接收缓存不足 丢弃报文")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 接收缓存是否为空
|
||||||
|
wasEmpty := e.rcvBufSize == 0
|
||||||
|
// 新建一个UDP数据包结构 插入到接收链表中
|
||||||
|
pkt := &udpPacket{
|
||||||
|
senderAddress: tcpip.FullAddress{
|
||||||
|
NIC: r.NICID(),
|
||||||
|
Addr: id.RemoteAddress,
|
||||||
|
Port: hdr.SourcePort(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// 复制UDP数据包的用户数据
|
||||||
|
pkt.data = vv.Clone(pkt.views[:]) // 当vv中的数据<=8时 无需再次分配内存
|
||||||
|
// 插入到接收链表中 并增加已经使用的缓存
|
||||||
|
e.rcvList.PushBack(pkt)
|
||||||
|
e.rcvBufSize += vv.Size()
|
||||||
|
|
||||||
|
if e.rcvTimestamp {
|
||||||
|
pkt.timestamp = e.stack.NowNanoseconds()
|
||||||
|
pkt.hasTimestamp = true
|
||||||
|
}
|
||||||
|
|
||||||
|
e.rcvMu.Unlock()
|
||||||
|
// TODO 通知用户层可以读取数据了
|
||||||
|
if wasEmpty {
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
|
// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
|
||||||
|
@@ -38,9 +38,8 @@ func (*protocol) MinimumPacketSize() int {
|
|||||||
// ParsePorts returns the source and destination ports stored in the given udp
|
// ParsePorts returns the source and destination ports stored in the given udp
|
||||||
// packet.
|
// packet.
|
||||||
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
|
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
|
||||||
//h := header.UDP(v)
|
h := header.UDP(v)
|
||||||
//return h.SourcePort(), h.DestinationPort(), nil
|
return h.SourcePort(), h.DestinationPort(), nil
|
||||||
return 0, 0, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
|
// HandleUnknownDestinationPacket handles packets targeted at this protocol but
|
||||||
|
Reference in New Issue
Block a user