Files
netstack/cmd/netstack/main.go

395 lines
9.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"flag"
"fmt"
"log"
"net"
"netstack/logger"
"netstack/tcpip"
"netstack/tcpip/header"
"netstack/tcpip/link/fdbased"
"netstack/tcpip/link/loopback"
"netstack/tcpip/link/tuntap"
"netstack/tcpip/network/arp"
"netstack/tcpip/network/ipv4"
"netstack/tcpip/network/ipv6"
"netstack/tcpip/stack"
"netstack/tcpip/transport/tcp"
"netstack/tcpip/transport/udp"
"netstack/waiter"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
)
// mac is the -mac command-line flag; main parses it with net.ParseMAC and uses
// the result as the link address of the fd-based link endpoint.
var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
// main sets up a TAP device and a network stack, then runs a TCP server and a
// TCP client against each other inside that stack until a signal arrives.
//
// Positional arguments: <tap-device> <local-address/mask> <ip-address> <local-port>.
// The optional -mac flag sets the link address of the fd-based endpoint.
func main() {
	flag.Parse()
	if len(flag.Args()) != 4 {
		log.Fatal("Usage: ", os.Args[0], " <tap-device> <local-address/mask> <ip-address> <local-port>")
	}

	logger.SetFlags(logger.HANDSHAKE)
	log.SetFlags(log.Lshortfile | log.LstdFlags)

	tapName := flag.Arg(0)
	cidrName := flag.Arg(1)
	addrName := flag.Arg(2)
	portName := flag.Arg(3)
	log.Printf("tap: %v, addr: %v, port: %v", tapName, addrName, portName)

	maddr, err := net.ParseMAC(*mac)
	if err != nil {
		log.Fatalf("Bad MAC address: %v", *mac)
	}

	// net.ParseIP signals failure by returning nil; it has no error result, so
	// the result itself must be checked rather than a previous err value.
	parsedAddr := net.ParseIP(addrName)
	if parsedAddr == nil {
		log.Fatalf("Bad addrress: %v", addrName)
	}

	// Accept either an IPv4 or an IPv6 literal and pick the matching protocol.
	var addr tcpip.Address
	var proto tcpip.NetworkProtocolNumber
	if parsedAddr.To4() != nil {
		addr = tcpip.Address(parsedAddr.To4())
		proto = ipv4.ProtocolNumber
	} else if parsedAddr.To16() != nil {
		addr = tcpip.Address(parsedAddr.To16())
		proto = ipv6.ProtocolNumber
	} else {
		log.Fatalf("Unknown IP type: %v", parsedAddr)
	}

	localPort, err := strconv.Atoi(portName)
	if err != nil {
		log.Fatalf("Unable to convert port %v: %v", portName, err)
	}
	// The port is later converted to uint16; reject values that would wrap.
	if localPort < 0 || localPort > 65535 {
		log.Fatalf("Port %v out of range 0-65535", localPort)
	}

	// Configuration of the virtual (TAP) network device.
	conf := &tuntap.Config{
		Name: tapName,
		Mode: tuntap.TAP,
	}
	var fd int
	// Create the virtual device.
	fd, err = tuntap.NewNetDev(conf)
	if err != nil {
		log.Fatal(err)
	}
	// Bring the TAP device up. A failure is reported but not fatal, because the
	// NIC created below is attached to the loopback link, not to this device.
	if err := tuntap.SetLinkUp(tapName); err != nil {
		log.Printf("set link up on %v: %v", tapName, err)
	}
	// Install the route for the given CIDR on the TAP device (same policy).
	if err := tuntap.SetRoute(tapName, cidrName); err != nil {
		log.Printf("set route %v on %v: %v", cidrName, tapName, err)
	}

	// File-descriptor based link endpoint on top of the TAP device. It is
	// created but currently unused: the NIC below uses the loopback endpoint.
	linkID := fdbased.New(&fdbased.Options{
		FD:                 fd,
		MTU:                1500,
		Address:            tcpip.LinkAddress(maddr),
		ResolutionRequired: true,
	})
	_ = linkID
	loopbackLinkID := loopback.New()

	// Build a stack with IPv4 + ARP at the network layer and TCP + UDP at the
	// transport layer.
	s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName},
		[]string{tcp.ProtocolName, udp.ProtocolName}, stack.Options{})

	// Create NIC 1 on the loopback link endpoint.
	if err := s.CreateNamedNIC(1, "vnic1", loopbackLinkID); err != nil {
		log.Fatal(err)
	}
	// Assign the requested network address to NIC 1.
	if err := s.AddAddress(1, proto, addr); err != nil {
		log.Fatal(err)
	}
	// Register the ARP protocol address on NIC 1.
	if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
		log.Fatal(err)
	}

	// Default route: an all-zero destination and mask of the address's length,
	// sent out through NIC 1 with no gateway.
	s.SetRouteTable([]tcpip.Route{
		{
			Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
			Mask:        tcpip.AddressMask(strings.Repeat("\x00", len(addr))),
			Gateway:     "",
			NIC:         1,
		},
	})

	done := make(chan struct{}, 2)
	//logger.SetFlags(logger.TCP)

	// Server: accept connections and print whatever each one receives.
	go func() {
		listener := tcpListen(s, proto, addr, localPort)
		// Tell main the listener exists so the client can start dialing.
		done <- struct{}{}
		for {
			conn, err := listener.Accept()
			if err != nil {
				// conn is nil here; stop accepting instead of handing a nil
				// connection to the reader goroutine.
				log.Println(err)
				return
			}
			log.Println("服务端 建立连接")
			go func() {
				// Release the endpoint once reading stops.
				defer conn.Close()
				for {
					buf := make([]byte, 1024)
					n, err := conn.Read(buf)
					if err != nil {
						log.Println(err)
						break
					}
					// Prints the whole buffer, including any bytes past n.
					fmt.Println("data: ", n, len(buf), string(buf))
					// Echoing back is currently disabled:
					// conn.Write([]byte("Server echo"))
				}
			}()
		}
	}()
	<-done

	// Client: connect to the server, send one large zero-filled buffer, linger
	// for a minute, then close.
	go func() {
		port := localPort
		conn, err := Dial(s, header.IPv4ProtocolNumber, addr, port)
		if err != nil {
			log.Fatal(err)
		}
		log.Printf("客户端 建立连接\n")
		time.Sleep(time.Second)
		log.Printf("\n\n客户端 写入数据")
		buf := make([]byte, 1<<17)
		if err := conn.Write(buf); err != nil {
			log.Println(err)
		}
		time.Sleep(1 * time.Minute)
		conn.Close()
	}()
	close(done)

	// signal.Notify never blocks when delivering, so the channel needs a buffer
	// or a signal arriving before the receive below would be dropped. os.Kill
	// is not listed because SIGKILL cannot be caught.
	c := make(chan os.Signal, 1)
	signal.Notify(c, os.Interrupt, syscall.SIGUSR1, syscall.SIGUSR2)
	<-c
}
// Dial creates a TCP endpoint on stack s for network protocol proto and
// connects it to addr:port.
//
// It blocks until the endpoint signals EventOut after a connect that reported
// tcpip.ErrConnectStarted. On any other connect error the endpoint is closed
// and the error is returned as a plain Go error.
func Dial(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, port int) (*TcpConn, error) {
	remote := tcpip.FullAddress{
		Addr: addr,
		Port: uint16(port),
	}

	var wq waiter.Queue
	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
	wq.EventRegister(&waitEntry, waiter.EventOut)
	// The same entry is handed to the returned TcpConn, whose Read registers
	// it again on every call. Take it off the queue before returning so one
	// entry is never on the queue twice.
	defer wq.EventUnregister(&waitEntry)

	// Create the TCP endpoint.
	ep, err := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
	if err != nil {
		return nil, fmt.Errorf("%s", err.String())
	}

	if err := ep.Connect(remote); err != nil {
		if err != tcpip.ErrConnectStarted {
			// Do not leak the endpoint when the connect fails outright.
			ep.Close()
			return nil, fmt.Errorf("%s", err.String())
		}
		// The connect is in progress; wait for the endpoint to notify.
		// NOTE(review): the outcome of the connect is not inspected after
		// this wake-up — confirm whether a failed handshake must be detected
		// here.
		<-notifyCh
	}

	return &TcpConn{
		ep:       ep,
		wq:       &wq,
		we:       &waitEntry,
		notifyCh: notifyCh}, nil
}
// TcpConn is one TCP connection: a stack endpoint bundled with the waiter
// queue, wait entry and notification channel used to block on it.
type TcpConn struct {
raddr tcpip.FullAddress // handed to ep.Read and used as the To address in ep.Write
ep tcpip.Endpoint // underlying transport endpoint
wq *waiter.Queue // queue the wait entry is registered on while blocking
we *waiter.Entry // entry registered on wq for the event being waited for
notifyCh chan struct{} // channel created together with we; received from while waiting
}
// Read blocks until the endpoint has data, copies it into rcv and returns the
// number of bytes copied.
//
// At most cap(rcv) bytes are copied, starting at rcv[0]. The returned count is
// the number of bytes actually placed in rcv (previously the full length of
// the received view was returned, which could exceed what the caller got).
// Any endpoint error other than tcpip.ErrWouldBlock is returned as a Go error.
//
// NOTE(review): when the received view is longer than cap(rcv) the remainder
// is not kept anywhere in this method — confirm whether that loss is
// acceptable for a stream connection.
func (conn *TcpConn) Read(rcv []byte) (int, error) {
	// Subscribe to readability for the duration of this call.
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)

	for {
		buf, _, err := conn.ep.Read(&conn.raddr)
		if err != nil {
			if err == tcpip.ErrWouldBlock {
				// Nothing to read yet; wait for a notification and retry.
				<-conn.notifyCh
				continue
			}
			return 0, fmt.Errorf("%s", err.String())
		}
		// copy stops at the shorter of the two slices, so n is exactly the
		// number of bytes the caller received.
		n := copy(rcv[:cap(rcv)], buf)
		return n, nil
	}
}
// Write hands snd to the endpoint, using conn.raddr as the destination option.
//
// When the endpoint reports tcpip.ErrNoLinkAddress, Write waits on the channel
// returned by that attempt and then tries again. Any other endpoint error is
// returned as a Go error; a nil error from the endpoint ends the call.
//
// NOTE(review): the byte count returned by the endpoint is ignored, so a
// write that accepts fewer bytes than len(snd) would go unnoticed — confirm
// against the endpoint's contract.
func (conn *TcpConn) Write(snd []byte) error {
	payload := tcpip.SlicePayload(snd)
	opts := tcpip.WriteOptions{To: &conn.raddr}
	for {
		_, resolved, werr := conn.ep.Write(payload, opts)
		switch werr {
		case nil:
			return nil
		case tcpip.ErrNoLinkAddress:
			// Wait for the channel from this attempt, then retry.
			<-resolved
		default:
			return fmt.Errorf("%s", werr.String())
		}
	}
}
// Close closes the connection's underlying endpoint.
func (conn *TcpConn) Close() {
conn.ep.Close()
}
// Listener is a listening TCP endpoint bundled with the waiter queue, wait
// entry and notification channel that Accept blocks on.
type Listener struct {
raddr tcpip.FullAddress // not referenced by the Listener code visible in this file
ep tcpip.Endpoint // listening endpoint that Accept pulls connections from
wq *waiter.Queue // queue the wait entry is registered on during Accept
we *waiter.Entry // entry registered on wq for EventIn during Accept
notifyCh chan struct{} // channel created together with we; received from while waiting
}
// Accept blocks until the listening endpoint yields a new connection and
// returns it wrapped in a TcpConn with its own fresh wait entry.
//
// While the endpoint reports tcpip.ErrWouldBlock, Accept waits on the
// listener's notification channel and retries. Any other endpoint error is
// returned as a Go error.
func (l *Listener) Accept() (*TcpConn, error) {
	// Subscribe to EventIn on the listener's queue for the duration of the call.
	l.wq.EventRegister(l.we, waiter.EventIn)
	defer l.wq.EventUnregister(l.we)

	for {
		newEP, newQueue, aerr := l.ep.Accept()
		switch {
		case aerr == nil:
			entry, ch := waiter.NewChannelEntry(nil)
			accepted := &TcpConn{
				ep:       newEP,
				wq:       newQueue,
				we:       &entry,
				notifyCh: ch,
			}
			return accepted, nil
		case aerr == tcpip.ErrWouldBlock:
			// No pending connection yet; wait to be notified and retry.
			<-l.notifyCh
		default:
			return nil, fmt.Errorf("%s", aerr.String())
		}
	}
}
// tcpListen creates a TCP endpoint on s for network protocol proto, binds it
// to localPort on NIC 1 with an empty address, starts listening with a backlog
// of 10 and returns it wrapped in a Listener.
//
// The addr parameter is not used by this function. Any failure terminates the
// program through log.Fatal.
func tcpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *Listener {
	queue := new(waiter.Queue)

	// Create the TCP endpoint.
	endpoint, newErr := s.NewEndpoint(tcp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind with an empty address, i.e. no specific local IP is requested.
	local := tcpip.FullAddress{NIC: 1, Addr: "", Port: uint16(localPort)}
	bindErr := endpoint.Bind(local, nil)
	if bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	// Start listening.
	listenErr := endpoint.Listen(10)
	if listenErr != nil {
		log.Fatal("Listen failed: ", listenErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &Listener{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}
// UdpConn is a UDP endpoint bundled with the waiter queue, wait entry and
// notification channel used to block on it.
type UdpConn struct {
raddr tcpip.FullAddress // handed to ep.Read and used as the To address in ep.Write
ep tcpip.Endpoint // underlying transport endpoint
wq *waiter.Queue // queue the wait entry is registered on while blocking
we *waiter.Entry // entry registered on wq for EventIn during Read
notifyCh chan struct{} // channel created together with we; received from while waiting
}
// Close closes the underlying endpoint.
func (conn *UdpConn) Close() {
conn.ep.Close()
}
// Read blocks until the endpoint has data, copies at most cap(rcv) bytes of it
// into rcv starting at rcv[0], and returns the number of bytes copied.
//
// conn.raddr is passed to the endpoint's Read on every attempt. While the
// endpoint reports tcpip.ErrWouldBlock, Read waits on the notification channel
// and retries; any other endpoint error is returned as a Go error.
func (conn *UdpConn) Read(rcv []byte) (int, error) {
	// Subscribe to EventIn for the duration of this call.
	conn.wq.EventRegister(conn.we, waiter.EventIn)
	defer conn.wq.EventUnregister(conn.we)

	for {
		data, _, rerr := conn.ep.Read(&conn.raddr)
		if rerr == tcpip.ErrWouldBlock {
			// Nothing available yet; wait to be notified and retry.
			<-conn.notifyCh
			continue
		}
		if rerr != nil {
			return 0, fmt.Errorf("%s", rerr.String())
		}
		// copy stops at the shorter of the two, i.e. min(len(data), cap(rcv)).
		copied := copy(rcv[:cap(rcv)], data)
		return copied, nil
	}
}
// Write hands snd to the endpoint with conn.raddr as the destination option.
//
// When the endpoint reports tcpip.ErrNoLinkAddress, Write waits on the channel
// returned by that attempt and then retries. Any other endpoint error is
// returned as a Go error.
func (conn *UdpConn) Write(snd []byte) error {
	for {
		_, pending, werr := conn.ep.Write(
			tcpip.SlicePayload(snd),
			tcpip.WriteOptions{To: &conn.raddr},
		)
		if werr == nil {
			return nil
		}
		if werr != tcpip.ErrNoLinkAddress {
			return fmt.Errorf("%s", werr.String())
		}
		// Wait for the channel from this attempt before retrying.
		<-pending
	}
}
// udpListen creates a UDP endpoint on s for network protocol proto, binds it
// to addr:localPort with NIC 0 and returns it wrapped in a UdpConn.
//
// Any failure terminates the program through log.Fatal.
func udpListen(s *stack.Stack, proto tcpip.NetworkProtocolNumber, addr tcpip.Address, localPort int) *UdpConn {
	queue := new(waiter.Queue)

	// Create the UDP endpoint.
	endpoint, newErr := s.NewEndpoint(udp.ProtocolNumber, proto, queue)
	if newErr != nil {
		log.Fatal(newErr)
	}

	// Bind to the given address and port; the NIC field is left at 0.
	local := tcpip.FullAddress{NIC: 0, Addr: addr, Port: uint16(localPort)}
	bindErr := endpoint.Bind(local, nil)
	if bindErr != nil {
		log.Fatal("Bind failed: ", bindErr)
	}

	entry, ch := waiter.NewChannelEntry(nil)
	return &UdpConn{
		ep:       endpoint,
		wq:       queue,
		we:       &entry,
		notifyCh: ch,
	}
}