mirror of
https://github.com/impact-eintr/netstack.git
synced 2025-10-17 02:10:37 +08:00
封装对udp链接的读操作 将其对外暴露为一个阻塞操作
This commit is contained in:
@@ -1,13 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"flag"
|
"flag"
|
||||||
"io"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"netstack/tcpip"
|
"netstack/tcpip"
|
||||||
"netstack/tcpip/header"
|
|
||||||
"netstack/tcpip/link/fdbased"
|
"netstack/tcpip/link/fdbased"
|
||||||
"netstack/tcpip/link/tuntap"
|
"netstack/tcpip/link/tuntap"
|
||||||
"netstack/tcpip/network/arp"
|
"netstack/tcpip/network/arp"
|
||||||
@@ -17,7 +15,6 @@ import (
|
|||||||
"netstack/waiter"
|
"netstack/waiter"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -111,33 +108,59 @@ func main() {
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// 监听udp localPort端口
|
// 监听udp localPort端口
|
||||||
udpEp := udpListen(s, proto, 9999)
|
conn := udpListen(s, proto, 9999)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf, _, err := udpEp.Read(nil)
|
buf := make([]byte, 1024)
|
||||||
|
n, err := conn.Read(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == tcpip.ErrWouldBlock {
|
log.Println(err)
|
||||||
time.Sleep(100 * time.Millisecond)
|
break
|
||||||
log.Println("阻塞中")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
log.Println(buf)
|
log.Println("接收到数据", buf[:n])
|
||||||
break
|
|
||||||
}
|
}
|
||||||
// 关闭监听服务,此时会释放端口
|
// 关闭监听服务,此时会释放端口
|
||||||
udpEp.Close()
|
conn.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
select {}
|
||||||
rcv := &RCV{
|
//conn, _ := net.Listen("tcp", "0.0.0.0:9999")
|
||||||
Stack: s,
|
//rcv := &RCV{
|
||||||
addr: tcpip.FullAddress{},
|
// Stack: s,
|
||||||
}
|
// addr: tcpip.FullAddress{},
|
||||||
TCPServer(conn, rcv)
|
//}
|
||||||
|
//TCPServer(conn, rcv)
|
||||||
}
|
}
|
||||||
|
|
||||||
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) tcpip.Endpoint {
|
type UdpConn struct {
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
wq *waiter.Queue
|
||||||
|
we *waiter.Entry
|
||||||
|
notifyCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *UdpConn) Close() {
|
||||||
|
conn.ep.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (conn *UdpConn) Read(rcv []byte) (int, error) {
|
||||||
|
conn.wq.EventRegister(conn.we, waiter.EventIn)
|
||||||
|
defer conn.wq.EventUnregister(conn.we)
|
||||||
|
for {
|
||||||
|
buf, _, err := conn.ep.Read(nil)
|
||||||
|
if err != nil {
|
||||||
|
if err == tcpip.ErrWouldBlock {
|
||||||
|
<-conn.notifyCh
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("%s", err.String())
|
||||||
|
}
|
||||||
|
rcv = append(rcv[:0], buf...)
|
||||||
|
return len(rcv), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int) *UdpConn {
|
||||||
var wq waiter.Queue
|
var wq waiter.Queue
|
||||||
// 新建一个udp端
|
// 新建一个udp端
|
||||||
ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
|
ep, err := s.NewEndpoint(udp.ProtocolNumber, proto, &wq)
|
||||||
@@ -152,62 +175,6 @@ func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, localPort int)
|
|||||||
log.Fatal("Bind failed: ", err)
|
log.Fatal("Bind failed: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 注意UDP是无连接的,它不需要Listen
|
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
|
||||||
return ep
|
return &UdpConn{ep, &wq, &waitEntry, notifyCh}
|
||||||
}
|
|
||||||
|
|
||||||
type RCV struct {
|
|
||||||
*stack.Stack
|
|
||||||
ep tcpip.Endpoint
|
|
||||||
addr tcpip.FullAddress
|
|
||||||
rcvBuf []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var transportPool = make(map[uint64]tcpip.Endpoint)
|
|
||||||
|
|
||||||
func (r *RCV) Handle(conn net.Conn) {
|
|
||||||
var err error
|
|
||||||
r.rcvBuf, err = io.ReadAll(conn)
|
|
||||||
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch string(r.rcvBuf[:3]) {
|
|
||||||
case "udp":
|
|
||||||
var wq waiter.Queue
|
|
||||||
// 新建一个udp端
|
|
||||||
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
r.ep = ep
|
|
||||||
r.Bind()
|
|
||||||
r.Connect()
|
|
||||||
r.Close()
|
|
||||||
case "tcp":
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RCV) Bind() {
|
|
||||||
if len(r.rcvBuf) < 9 { // udp ip port
|
|
||||||
log.Println("Error: too few arg")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
|
||||||
r.addr = tcpip.FullAddress{
|
|
||||||
NIC: 1,
|
|
||||||
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
|
||||||
Port: port,
|
|
||||||
}
|
|
||||||
r.ep.Bind(r.addr, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RCV) Connect() {
|
|
||||||
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RCV) Close() {
|
|
||||||
r.ep.Close()
|
|
||||||
}
|
}
|
||||||
|
@@ -1,9 +1,16 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
|
"netstack/tcpip"
|
||||||
|
"netstack/tcpip/header"
|
||||||
|
"netstack/tcpip/stack"
|
||||||
|
"netstack/tcpip/transport/udp"
|
||||||
|
"netstack/waiter"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -36,3 +43,59 @@ func TCPServer(listener net.Listener, handler TCPHandler) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var transportPool = make(map[uint64]tcpip.Endpoint)
|
||||||
|
|
||||||
|
type RCV struct {
|
||||||
|
*stack.Stack
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
addr tcpip.FullAddress
|
||||||
|
rcvBuf []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RCV) Handle(conn net.Conn) {
|
||||||
|
var err error
|
||||||
|
r.rcvBuf, err = io.ReadAll(conn)
|
||||||
|
if err != nil && len(r.rcvBuf) < 9 { // proto + ip + port
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch string(r.rcvBuf[:3]) {
|
||||||
|
case "udp":
|
||||||
|
var wq waiter.Queue
|
||||||
|
// 新建一个udp端
|
||||||
|
ep, err := r.NewEndpoint(udp.ProtocolNumber, header.IPv4ProtocolNumber, &wq)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
r.ep = ep
|
||||||
|
r.Bind()
|
||||||
|
r.Connect()
|
||||||
|
r.Close()
|
||||||
|
case "tcp":
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RCV) Bind() {
|
||||||
|
if len(r.rcvBuf) < 9 { // udp ip port
|
||||||
|
log.Println("Error: too few arg")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
port := binary.BigEndian.Uint16(r.rcvBuf[7:9])
|
||||||
|
r.addr = tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.Address(r.rcvBuf[3:7]),
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
r.ep.Bind(r.addr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RCV) Connect() {
|
||||||
|
r.ep.Connect(tcpip.FullAddress{NIC: 1, Addr: "\xc0\xa8\x01\x02", Port: 8888})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *RCV) Close() {
|
||||||
|
r.ep.Close()
|
||||||
|
}
|
||||||
|
@@ -26,12 +26,13 @@ func main() {
|
|||||||
}
|
}
|
||||||
log.Println("TEST")
|
log.Println("TEST")
|
||||||
|
|
||||||
send := []byte("hello")
|
for i := 0; i < 3; i++ {
|
||||||
|
send := []byte("hello" + string(i))
|
||||||
if _, err := conn.Write(send); err != nil {
|
if _, err := conn.Write(send); err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
}
|
||||||
|
log.Printf("send: %s", string(send))
|
||||||
}
|
}
|
||||||
log.Printf("send: %s", string(send))
|
|
||||||
|
|
||||||
//recv := make([]byte, 10)
|
//recv := make([]byte, 10)
|
||||||
//rn, _, err := conn.ReadFrom(recv)
|
//rn, _, err := conn.ReadFrom(recv)
|
||||||
|
@@ -202,7 +202,16 @@ func (e *endpoint) dispatch() (bool, *tcpip.Error) {
|
|||||||
vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv
|
vv := buffer.NewVectorisedView(n, e.views[:used]) // 用这些有效的内容构建vv
|
||||||
vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头
|
vv.TrimFront(e.hdrSize) // 将数据内容删除以太网头部信息 将网络层作为数据头
|
||||||
|
|
||||||
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
|
switch p {
|
||||||
|
case header.ARPProtocolNumber, header.IPv4ProtocolNumber:
|
||||||
|
log.Println("链路层收到报文")
|
||||||
|
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
|
||||||
|
case header.IPv6ProtocolNumber:
|
||||||
|
// TODO ipv6暂时不感兴趣
|
||||||
|
e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, localLinkAddr, p, vv)
|
||||||
|
default:
|
||||||
|
log.Println("未知类型的非法报文")
|
||||||
|
}
|
||||||
|
|
||||||
// 将分发后的数据无效化(设置nil可以让gc回收这些内存)
|
// 将分发后的数据无效化(设置nil可以让gc回收这些内存)
|
||||||
for i := 0; i < used; i++ {
|
for i := 0; i < used; i++ {
|
||||||
|
@@ -484,7 +484,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv
|
|||||||
e.rcvMu.Unlock()
|
e.rcvMu.Unlock()
|
||||||
// TODO 通知用户层可以读取数据了
|
// TODO 通知用户层可以读取数据了
|
||||||
if wasEmpty {
|
if wasEmpty {
|
||||||
|
e.waiterQueue.Notify(waiter.EventIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -63,6 +63,8 @@ import (
|
|||||||
"netstack/ilist"
|
"netstack/ilist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO 看事件机制
|
||||||
|
|
||||||
// EventMask represents io events as used in the poll() syscall.
|
// EventMask represents io events as used in the poll() syscall.
|
||||||
type EventMask uint16
|
type EventMask uint16
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user