mirror of
https://github.com/impact-eintr/netstack.git
synced 2025-10-17 02:10:37 +08:00
封装对udp链接的读操作 将其对外暴露为一个阻塞操作
This commit is contained in:
@@ -1,13 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"flag"
|
||||
"io"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"netstack/tcpip"
|
||||
"netstack/tcpip/header"
|
||||
"netstack/tcpip/link/fdbased"
|
||||
"netstack/tcpip/link/tuntap"
|
||||
"netstack/tcpip/network/arp"
|
||||
@@ -17,7 +15,6 @@ import (
|
||||
"netstack/waiter"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -111,33 +108,59 @@ func main() {
|
||||
|
||||
go func() {
|
||||
// 监听udp localPort端口
|
||||
udpEp := udpListen(s, proto, 9999)
|
||||
conn := udpListen(s, proto, 9999)
|
||||
|
||||
for {
|
||||
buf, _, err := udpEp.Read(nil)
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
log.Println("阻塞中")
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Println(buf)
|
||||
log.Println(err)
|
||||
break
|
||||
}
|
||||
log.Println("接收到数据", buf[:n])
|
||||
}
|
||||
// 关闭监听服务,此时会释放端口
|
||||
udpEp.Close()
|
||||
conn.Close()
|
||||
}()
|
||||
|
||||
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
||||
rcv := &RCV{
|
||||
Stack: s,
|
||||
addr: tcpip.FullAddress{},
|
||||
}
|
||||
TCPServer(conn, rcv)
|
||||
select {}
|
||||
//conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
||||
//rcv := &RCV{
|
||||
// Stack: s,
|
||||
// addr: tcpip.FullAddress{},
|
||||
//}
|
||||
//TCPServer(conn, rcv)
|
||||
}
|
||||
|
||||
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
|
||||
type UdpConn struct {
|
||||
ep tcpip.Endpoint
|
||||
wq *waiter.Queue
|
||||
we *waiter.Entry
|
||||
notifyCh chan struct{}
|
||||
}
|
||||
|
||||
func (conn *UdpConn) Close() {
|
||||
conn.ep.Close()
|
||||
}
|
||||
|
||||
func (conn *UdpConn) Read(rcv []byte) (int, error) {
|
||||
conn.wq.EventRegister(conn.we, waiter.EventIn)
|
||||
defer conn.wq.EventUnregister(conn.we)
|
||||
for {
|
||||
buf, _, err := conn.ep.Read(nil)
|
||||
if err != nil {
|
||||
if err == tcpip.ErrWouldBlock {
|
||||
<-conn.notifyCh
|
||||
continue
|
||||
}
|
||||
return 0, fmt.Errorf("%s", err.String())
|
||||
}
|
||||
rcv = append(rcv[:0], buf...)
|
||||
return len(rcv), nil
|
||||
}
|
||||
}
|
||||
|
||||
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) *UdpConn {
|
||||
var wq waiter.Queue
|
||||
// 新建一个udp端
|
||||
ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
|
||||
@@ -152,62 +175,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
|
||||
log.Fatal("Bind failed: ", err)
|
||||
}
|
||||
|
||||
// 注意UDP是无连接的,它不需要Listen
|
||||
return ep
|
||||
}
|
||||
|
||||
type RCV struct {
|
||||
*stack.Stack
|
||||
ep tcpip.Endpoint
|
||||
addr tcpip.FullAddress
|
||||
rcvBuf []byte
|
||||
}
|
||||
|
||||
var transportPool = make(map[uint64]tcpip.Endpoint)
|
||||
|
||||
func (r *RCV) Handle(conn net.Conn) {
|
||||
var err error
|
||||
r.rcvBuf, err = io.ReadAll(conn)
|
||||
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
|
||||
panic(err)
|
||||
}
|
||||
|
||||
switch string(r.rcvBuf[:3]) {
|
||||
case "udp":
|
||||
var wq waiter.Queue
|
||||
// 新建一个udp端
|
||||
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
r.ep = ep
|
||||
r.Bind()
|
||||
r.Connect()
|
||||
r.Close()
|
||||
case "tcp":
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RCV) Bind() {
|
||||
if len(r.rcvBuf) < 9 { // udp ip port
|
||||
log.Println("Error: too few arg")
|
||||
return
|
||||
}
|
||||
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
||||
r.addr = tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
||||
Port: port,
|
||||
}
|
||||
r.ep.Bind(r.addr, nil)
|
||||
}
|
||||
|
||||
func (r *RCV) Connect() {
|
||||
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
|
||||
}
|
||||
|
||||
func (r *RCV) Close() {
|
||||
r.ep.Close()
|
||||
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
||||
return &UdpConn{ep, &wq, &waitEntry, notifyCh}
|
||||
}
|
||||
|
@@ -1,9 +1,16 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"netstack/tcpip"
|
||||
"netstack/tcpip/header"
|
||||
"netstack/tcpip/stack"
|
||||
"netstack/tcpip/transport/udp"
|
||||
"netstack/waiter"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
@@ -36,3 +43,59 @@ func TCPServer(listener net.Listener, handler TCPHandler) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var transportPool = make(map[uint64]tcpip.Endpoint)
|
||||
|
||||
type RCV struct {
|
||||
*stack.Stack
|
||||
ep tcpip.Endpoint
|
||||
addr tcpip.FullAddress
|
||||
rcvBuf []byte
|
||||
}
|
||||
|
||||
func (r *RCV) Handle(conn net.Conn) {
|
||||
var err error
|
||||
r.rcvBuf, err = io.ReadAll(conn)
|
||||
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
|
||||
panic(err)
|
||||
}
|
||||
|
||||
switch string(r.rcvBuf[:3]) {
|
||||
case "udp":
|
||||
var wq waiter.Queue
|
||||
// 新建一个udp端
|
||||
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
r.ep = ep
|
||||
r.Bind()
|
||||
r.Connect()
|
||||
r.Close()
|
||||
case "tcp":
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RCV) Bind() {
|
||||
if len(r.rcvBuf) < 9 { // udp ip port
|
||||
log.Println("Error: too few arg")
|
||||
return
|
||||
}
|
||||
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
||||
r.addr = tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
||||
Port: port,
|
||||
}
|
||||
r.ep.Bind(r.addr, nil)
|
||||
}
|
||||
|
||||
func (r *RCV) Connect() {
|
||||
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
|
||||
}
|
||||
|
||||
func (r *RCV) Close() {
|
||||
r.ep.Close()
|
||||
}
|
||||
|
@@ -26,12 +26,13 @@ func main() {
|
||||
}
|
||||
log.Println("TEST")
|
||||
|
||||
send := []byte("hello")
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
send := []byte("hello" + string(i))
|
||||
if _, err := conn.Write(send); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
log.Printf("send: %s", string(send))
|
||||
}
|
||||
|
||||
//recv := make([]byte, 10)
|
||||
//rn, _, err := conn.ReadFrom(recv)
|
||||
|
@@ -202,7 +202,16 @@ func (e *endpoint) dispatch() (bool, *tcpip.Error) {
|
||||
vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv
|
||||
vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头
|
||||
|
||||
switch p {
|
||||
case header.ARPProtocolNumber, header.IPv4ProtocolNumber:
|
||||
log.Println("链路层收到报文")
|
||||
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
|
||||
case header.IPv6ProtocolNumber:
|
||||
// TODO ipv6暂时不感兴趣
|
||||
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
|
||||
default:
|
||||
log.Println("未知类型的非法报文")
|
||||
}
|
||||
|
||||
// 将分发后的数据无效化(设置nil可以让gc回收这些内存)
|
||||
for i := 0; i < used; i++ {
|
||||
|
@@ -484,7 +484,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
|
||||
e.rcvMu.Unlock()
|
||||
// TODO 通知用户层可以读取数据了
|
||||
if wasEmpty {
|
||||
|
||||
e.waiterQueue.Notify(waiter.EventIn)
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -63,6 +63,8 @@ import (
|
||||
"netstack/ilist"
|
||||
)
|
||||
|
||||
// TODO 看事件机制
|
||||
|
||||
// EventMask represents io events as used in the poll() syscall.
|
||||
type EventMask uint16
|
||||
|
||||
|
Reference in New Issue
Block a user