封装对udp链接的读操作 将其对外暴露为一个阻塞操作

This commit is contained in:
impact-eintr
2022-12-02 12:35:52 +08:00
parent be40f904fc
commit 98bccec410
6 changed files with 128 additions and 86 deletions

View File

@@ -1,13 +1,11 @@
package main package main
import ( import (
"encoding/binary"
"flag" "flag"
"io" "fmt"
"log" "log"
"net" "net"
"netstack/tcpip" "netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/link/fdbased" "netstack/tcpip/link/fdbased"
"netstack/tcpip/link/tuntap" "netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp" "netstack/tcpip/network/arp"
@@ -17,7 +15,6 @@ import (
"netstack/waiter" "netstack/waiter"
"os" "os"
"strings" "strings"
"time"
) )
func main() { func main() {
@@ -111,33 +108,59 @@ func main() {
go func() { go func() {
// 监听udp localPort端口 // 监听udp localPort端口
udpEp := udpListen(s, proto, 9999) conn := udpListen(s, proto, 9999)
for { for {
buf, _, err := udpEp.Read(nil) buf := make([]byte, 1024)
n, err := conn.Read(buf)
if err != nil { if err != nil {
if err == tcpip.ErrWouldBlock { log.Println(err)
time.Sleep(100 * time.Millisecond) break
log.Println("阻塞中")
continue
}
} }
log.Println(buf) log.Println("接收到数据", buf[:n])
break
} }
// 关闭监听服务,此时会释放端口 // 关闭监听服务,此时会释放端口
udpEp.Close() conn.Close()
}() }()
conn, _ := net.Listen("tcp", "0.0.0.0:9999") select {}
rcv := &RCV{ //conn, _ := net.Listen("tcp", "0.0.0.0:9999")
Stack: s, //rcv := &RCV{
addr: tcpip.FullAddress{}, // Stack: s,
} // addr: tcpip.FullAddress{},
TCPServer(conn, rcv) //}
//TCPServer(conn, rcv)
} }
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint { type UdpConn struct {
ep tcpip.Endpoint
wq *waiter.Queue
we *waiter.Entry
notifyCh chan struct{}
}
func (conn *UdpConn) Close() {
conn.ep.Close()
}
func (conn *UdpConn) Read(rcv []byte) (int, error) {
conn.wq.EventRegister(conn.we, waiter.EventIn)
defer conn.wq.EventUnregister(conn.we)
for {
buf, _, err := conn.ep.Read(nil)
if err != nil {
if err == tcpip.ErrWouldBlock {
<-conn.notifyCh
continue
}
return 0, fmt.Errorf("%s", err.String())
}
rcv = append(rcv[:0], buf...)
return len(rcv), nil
}
}
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) *UdpConn {
var wq waiter.Queue var wq waiter.Queue
// 新建一个udp端 // 新建一个udp端
ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq) ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
@@ -152,62 +175,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
log.Fatal("Bind failed: ", err) log.Fatal("Bind failed: ", err)
} }
// 注意UDP是无连接的它不需要Listen waitEntry, notifyCh := waiter.NewChannelEntry(nil)
return ep return &UdpConn{ep, &wq, &waitEntry, notifyCh}
}
type RCV struct {
*stack.Stack
ep tcpip.Endpoint
addr tcpip.FullAddress
rcvBuf []byte
}
var transportPool = make(map[uint64]tcpip.Endpoint)
func (r *RCV) Handle(conn net.Conn) {
var err error
r.rcvBuf, err = io.ReadAll(conn)
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
panic(err)
}
switch string(r.rcvBuf[:3]) {
case "udp":
var wq waiter.Queue
// 新建一个udp端
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
if err != nil {
log.Fatal(err)
}
r.ep = ep
r.Bind()
r.Connect()
r.Close()
case "tcp":
default:
return
}
}
func (r *RCV) Bind() {
if len(r.rcvBuf) < 9 { // udp ip port
log.Println("Error: too few arg")
return
}
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
r.addr = tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(r.rcvBuf[3:7]),
Port: port,
}
r.ep.Bind(r.addr, nil)
}
func (r *RCV) Connect() {
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
}
func (r *RCV) Close() {
r.ep.Close()
} }

View File

@@ -1,9 +1,16 @@
package main package main
import ( import (
"encoding/binary"
"fmt" "fmt"
"io"
"log" "log"
"net" "net"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/stack"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"runtime" "runtime"
"strings" "strings"
) )
@@ -36,3 +43,59 @@ func TCPServer(listener net.Listener, handler TCPHandler) error {
return nil return nil
} }
var transportPool = make(map[uint64]tcpip.Endpoint)
type RCV struct {
*stack.Stack
ep tcpip.Endpoint
addr tcpip.FullAddress
rcvBuf []byte
}
func (r *RCV) Handle(conn net.Conn) {
var err error
r.rcvBuf, err = io.ReadAll(conn)
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
panic(err)
}
switch string(r.rcvBuf[:3]) {
case "udp":
var wq waiter.Queue
// 新建一个udp端
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
if err != nil {
log.Fatal(err)
}
r.ep = ep
r.Bind()
r.Connect()
r.Close()
case "tcp":
default:
return
}
}
func (r *RCV) Bind() {
if len(r.rcvBuf) < 9 { // udp ip port
log.Println("Error: too few arg")
return
}
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
r.addr = tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(r.rcvBuf[3:7]),
Port: port,
}
r.ep.Bind(r.addr, nil)
}
func (r *RCV) Connect() {
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
}
func (r *RCV) Close() {
r.ep.Close()
}

View File

@@ -26,12 +26,13 @@ func main() {
} }
log.Println("TEST") log.Println("TEST")
send := []byte("hello") for i := 0; i < 3; i++ {
send := []byte("hello" + string(i))
if _, err := conn.Write(send); err != nil { if _, err := conn.Write(send); err != nil {
panic(err) panic(err)
}
log.Printf("send: %s", string(send))
} }
log.Printf("send: %s", string(send))
//recv := make([]byte, 10) //recv := make([]byte, 10)
//rn, _, err := conn.ReadFrom(recv) //rn, _, err := conn.ReadFrom(recv)

View File

@@ -202,7 +202,16 @@ func (e *endpoint) dispatch() (bool, *tcpip.Error) {
vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv
vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头 vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv) switch p {
case header.ARPProtocolNumber, header.IPv4ProtocolNumber:
log.Println("链路层收到报文")
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
case header.IPv6ProtocolNumber:
// TODO ipv6暂时不感兴趣
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
default:
log.Println("未知类型的非法报文")
}
// 将分发后的数据无效化(设置nil可以让gc回收这些内存) // 将分发后的数据无效化(设置nil可以让gc回收这些内存)
for i := 0; i < used; i++ { for i := 0; i < used; i++ {

View File

@@ -484,7 +484,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
e.rcvMu.Unlock() e.rcvMu.Unlock()
// TODO 通知用户层可以读取数据了 // TODO 通知用户层可以读取数据了
if wasEmpty { if wasEmpty {
e.waiterQueue.Notify(waiter.EventIn)
} }
} }

View File

@@ -63,6 +63,8 @@ import (
"netstack/ilist" "netstack/ilist"
) )
// TODO 看事件机制
// EventMask represents io events as used in the poll() syscall. // EventMask represents io events as used in the poll() syscall.
type EventMask uint16 type EventMask uint16