pgcli/vpn: add a proxy server to access the PG network (#16)

This commit is contained in:
rkonfj
2024-12-28 19:33:06 +08:00
parent cdbb6ebd28
commit 4ce25bb1cf
6 changed files with 845 additions and 196 deletions

View File

@@ -0,0 +1,102 @@
package rootless
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"time"
N "github.com/sigcn/pg/net"
"github.com/sigcn/pg/socks5"
"github.com/sigcn/pg/vpn/nic/gvisor"
)
type ProxyConfig struct {
Listen string
}
type ProxyServer struct {
Config ProxyConfig
GvisorCard *gvisor.GvisorCard
udpListener *N.UDPListener
}
func (s *ProxyServer) Start(ctx context.Context, wg *sync.WaitGroup) error {
tcpListener, err := net.Listen("tcp", s.Config.Listen)
if err != nil {
return err
}
udpPacketConn, err := net.ListenPacket("udp", s.Config.Listen)
if err != nil {
tcpListener.Close()
return err
}
wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()
tcpListener.Close()
udpPacketConn.Close()
}()
s.udpListener = &N.UDPListener{PacketConn: udpPacketConn}
slog.Info("[Proxy] Server started", "listen", fmt.Sprintf("tcp+udp://%s", tcpListener.Addr().String()))
go s.run(tcpListener)
return nil
}
func (s *ProxyServer) run(tcp net.Listener) {
for {
c, err := tcp.Accept()
if err != nil {
return
}
addr, cmd, err := socks5.ServerHandshake(c, nil)
if err != nil {
slog.Error("[Proxy] SOCKS5 handshake", "err", err)
continue
}
if cmd == socks5.CmdConnect {
if err := s.proxyTCP(c, addr); err != nil {
slog.Error("[Proxy] SOCKS5 tcp", "err", err)
}
continue
}
if cmd == socks5.CmdUDPAssociate {
go func() {
if err := s.proxyUDP(addr); err != nil {
slog.Error("[Proxy] SOCKS5 udp", "err", err)
}
}()
continue
}
}
}
func (s *ProxyServer) proxyTCP(rw net.Conn, addr socks5.Addr) error {
c, err := s.GvisorCard.DialContext(context.TODO(), "tcp", addr.String())
if err != nil {
rw.Close()
return err
}
go relay(rw, c)
return nil
}
func (s *ProxyServer) proxyUDP(addr socks5.Addr) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c, err := s.udpListener.AcceptContext(ctx)
if err != nil {
return err
}
c1, err := s.GvisorCard.DialContext(context.TODO(), "udp", addr.String())
if err != nil {
c.Close()
return err
}
go relay(c, c1)
return nil
}

View File

@@ -94,6 +94,7 @@ func usage(flagSet *flag.FlagSet) {
logLevel := flagSet.Lookup("loglevel") logLevel := flagSet.Lookup("loglevel")
mtu := flagSet.Lookup("mtu") mtu := flagSet.Lookup("mtu")
peers := flagSet.Lookup("peers") peers := flagSet.Lookup("peers")
proxyListen := flagSet.Lookup("proxy-listen")
server := flagSet.Lookup("s") server := flagSet.Lookup("s")
tun := flagSet.Lookup("tun") tun := flagSet.Lookup("tun")
udpPort := flagSet.Lookup("udp-port") udpPort := flagSet.Lookup("udp-port")
@@ -119,6 +120,7 @@ func usage(flagSet *flag.FlagSet) {
fmt.Printf(" --key string\n\t%s\n", key.Usage) fmt.Printf(" --key string\n\t%s\n", key.Usage)
fmt.Printf(" --loglevel int\n\t%s (default %s)\n", logLevel.Usage, logLevel.DefValue) fmt.Printf(" --loglevel int\n\t%s (default %s)\n", logLevel.Usage, logLevel.DefValue)
fmt.Printf(" --mtu int\n\t%s (default %s)\n", mtu.Usage, mtu.DefValue) fmt.Printf(" --mtu int\n\t%s (default %s)\n", mtu.Usage, mtu.DefValue)
fmt.Printf(" --proxy-listen string\n\t%s\n", proxyListen.Usage)
fmt.Printf(" -s, --server string\n\t%s\n", server.Usage) fmt.Printf(" -s, --server string\n\t%s\n", server.Usage)
fmt.Printf(" --tun string\n\t%s (default %s)\n", tun.Usage, tun.DefValue) fmt.Printf(" --tun string\n\t%s (default %s)\n", tun.Usage, tun.DefValue)
fmt.Printf(" --udp-crypto\n\t%s (default %s)\n", cryptoAlgo.Usage, cryptoAlgo.DefValue) fmt.Printf(" --udp-crypto\n\t%s (default %s)\n", cryptoAlgo.Usage, cryptoAlgo.DefValue)
@@ -154,6 +156,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)
flagSet.IntVar(&cfg.NICConfig.MTU, "mtu", 1411, "nic mtu") flagSet.IntVar(&cfg.NICConfig.MTU, "mtu", 1411, "nic mtu")
flagSet.StringVar(&cfg.NICConfig.Name, "tun", defaultTunName, "nic name") flagSet.StringVar(&cfg.NICConfig.Name, "tun", defaultTunName, "nic name")
flagSet.Var(&forwards, "forward", "start in rootless mode and create a port forward (e.g. tcp://127.0.0.1:80)") flagSet.Var(&forwards, "forward", "start in rootless mode and create a port forward (e.g. tcp://127.0.0.1:80)")
flagSet.StringVar(&cfg.ProxyConfig.Listen, "proxy-listen", "", "start a proxy server to access the PG network (e.g. 127.0.0.1:4090)")
flagSet.StringVar(&cfg.PrivateKey, "key", "", "curve25519 private key in base58 format (default generate a new one)") flagSet.StringVar(&cfg.PrivateKey, "key", "", "curve25519 private key in base58 format (default generate a new one)")
flagSet.StringVar(&cfg.SecretFile, "secret-file", "", "") flagSet.StringVar(&cfg.SecretFile, "secret-file", "", "")
@@ -213,6 +216,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)
type Config struct { type Config struct {
NICConfig nic.Config NICConfig nic.Config
ProxyConfig rootless.ProxyConfig
DiscoPortScanOffset int DiscoPortScanOffset int
DiscoPortScanCount int DiscoPortScanCount int
DiscoPortScanDuration time.Duration DiscoPortScanDuration time.Duration
@@ -265,6 +269,13 @@ func (v *P2PVPN) Run(ctx context.Context) (err error) {
return err return err
} }
} }
if v.Config.ProxyConfig.Listen != "" {
if err := (&rootless.ProxyServer{
GvisorCard: card.(*gvisor.GvisorCard),
Config: v.Config.ProxyConfig}).Start(ctx, &wg); err != nil {
return err
}
}
if err := (&server.Server{ if err := (&server.Server{
Vnic: v.nic, Vnic: v.nic,
PeerStore: c.PeerStore(), PeerStore: c.PeerStore(),

183
net/udp.go Normal file
View File

@@ -0,0 +1,183 @@
package net
import (
"context"
"errors"
"log/slog"
"net"
"sync"
"sync/atomic"
"time"
)
var _ net.Conn = (*UDPConn)(nil)
type UDPConn struct {
removeConn func()
remoteAddr net.Addr
c net.PacketConn
closeOnce sync.Once
inbound chan []byte
closeChan chan struct{}
lastActiveTime atomic.Value
}
func (c *UDPConn) init() {
c.inbound = make(chan []byte, 512)
c.closeChan = make(chan struct{})
c.lastActiveTime.Store(time.Now())
ticker := time.NewTicker(6 * time.Second)
go func() { // create a timer to trace timeout udp conn, and close it
defer ticker.Stop()
for range ticker.C {
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
c.Close()
break
}
}
}()
}
func (c *UDPConn) Read(p []byte) (int, error) {
select {
case b := <-c.inbound:
c.lastActiveTime.Store(time.Now())
return copy(p, b), nil
case <-c.closeChan:
return 0, net.ErrClosed
}
}
func (c *UDPConn) Write(p []byte) (int, error) {
c.lastActiveTime.Store(time.Now())
return c.c.WriteTo(p, c.remoteAddr)
}
func (c *UDPConn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}
func (c *UDPConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *UDPConn) Close() error {
c.closeOnce.Do(func() {
close(c.closeChan)
close(c.inbound)
c.removeConn()
slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
})
return nil
}
func (c *UDPConn) SetDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *UDPConn) SetReadDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
return errors.ErrUnsupported
}
type UDPListener struct {
PacketConn net.PacketConn
buf []byte
initOnce sync.Once
closeOnce sync.Once
udpChan chan *UDPConn
connMap map[string]*UDPConn
connMapMu sync.RWMutex
}
func (l *UDPListener) init() {
l.initOnce.Do(func() {
l.buf = make([]byte, 65535)
l.udpChan = make(chan *UDPConn, 8)
l.connMap = make(map[string]*UDPConn)
go l.readUDP()
})
}
func (l *UDPListener) readUDP() {
read := func() error {
read:
n, peerAddr, err := l.PacketConn.ReadFrom(l.buf)
if err != nil {
return err
}
l.connMapMu.RLock()
conn, ok := l.connMap[peerAddr.String()]
l.connMapMu.RUnlock()
if ok {
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
l.connMapMu.Lock()
conn, ok = l.connMap[peerAddr.String()]
if ok {
l.connMapMu.Unlock()
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
defer l.connMapMu.Unlock()
conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() {
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
delete(l.connMap, peerAddr.String())
}}
conn.init()
l.connMap[peerAddr.String()] = conn
conn.inbound <- append([]byte(nil), l.buf[:n]...)
l.udpChan <- conn
return nil
}
for {
if err := read(); err != nil {
return
}
}
}
func (l *UDPListener) Accept() (net.Conn, error) {
return l.AcceptContext(context.Background())
}
func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) {
l.init()
select {
case c := <-l.udpChan:
return c, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func (l *UDPListener) Close() error {
if l.PacketConn == nil {
return nil
}
l.closeOnce.Do(func() {
l.PacketConn.Close()
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
for _, c := range l.connMap {
go c.Close()
}
})
return nil
}
func (l *UDPListener) Addr() net.Addr {
l.init()
if l.PacketConn == nil {
return nil
}
return l.PacketConn.LocalAddr()
}

457
socks5/socks5.go Normal file
View File

@@ -0,0 +1,457 @@
// Copyright (c) 2024 sigcn/pg
// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3.
// Copyright (c) 2021-2024 clash
// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3.
package socks5
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"strconv"
)
type Authenticator interface {
Verify(user string, pass string) bool
}
// Error represents a SOCKS error
type Error byte
func (err Error) Error() string {
return "SOCKS error: " + strconv.Itoa(int(err))
}
// Command is request commands as defined in RFC 1928 section 4.
type Command = uint8
const Version = 5
// SOCKS request commands as defined in RFC 1928 section 4.
const (
CmdConnect Command = 1
CmdBind Command = 2
CmdUDPAssociate Command = 3
)
// SOCKS address types as defined in RFC 1928 section 5.
const (
AtypIPv4 = 1
AtypDomainName = 3
AtypIPv6 = 4
)
// MaxAddrLen is the maximum size of SOCKS address in bytes.
const MaxAddrLen = 1 + 1 + 255 + 2
// MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
const MaxAuthLen = 255
// Addr represents a SOCKS address as defined in RFC 1928 section 5.
type Addr []byte
func (a Addr) String() string {
var host, port string
switch a[0] {
case AtypDomainName:
hostLen := uint16(a[1])
host = string(a[2 : 2+hostLen])
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
case AtypIPv4:
host = net.IP(a[1 : 1+net.IPv4len]).String()
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
case AtypIPv6:
host = net.IP(a[1 : 1+net.IPv6len]).String()
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
}
return net.JoinHostPort(host, port)
}
// UDPAddr converts a socks5.Addr to *net.UDPAddr
func (a Addr) UDPAddr() *net.UDPAddr {
if len(a) == 0 {
return nil
}
switch a[0] {
case AtypIPv4:
var ip [net.IPv4len]byte
copy(ip[0:], a[1:1+net.IPv4len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
case AtypIPv6:
var ip [net.IPv6len]byte
copy(ip[0:], a[1:1+net.IPv6len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
}
// Other Atyp
return nil
}
// SOCKS errors as defined in RFC 1928 section 6.
const (
ErrGeneralFailure = Error(1)
ErrConnectionNotAllowed = Error(2)
ErrNetworkUnreachable = Error(3)
ErrHostUnreachable = Error(4)
ErrConnectionRefused = Error(5)
ErrTTLExpired = Error(6)
ErrCommandNotSupported = Error(7)
ErrAddressNotSupported = Error(8)
)
// Auth errors used to return a specific "Auth failed" error
var ErrAuth = errors.New("auth failed")
type User struct {
Username string
Password string
}
// ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
func ServerHandshake(rw net.Conn, authenticator Authenticator) (addr Addr, command Command, err error) {
// Read RFC 1928 for request and reply structure and sizes.
buf := make([]byte, MaxAddrLen)
// read VER, NMETHODS, METHODS
if _, err = io.ReadFull(rw, buf[:2]); err != nil {
return
}
nmethods := buf[1]
if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
return
}
// write VER METHOD
if authenticator != nil {
if _, err = rw.Write([]byte{5, 2}); err != nil {
return
}
// Get header
header := make([]byte, 2)
if _, err = io.ReadFull(rw, header); err != nil {
return
}
authBuf := make([]byte, MaxAuthLen)
// Get username
userLen := int(header[1])
if userLen <= 0 {
rw.Write([]byte{1, 1})
err = ErrAuth
return
}
if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
return
}
user := string(authBuf[:userLen])
// Get password
if _, err = rw.Read(header[:1]); err != nil {
return
}
passLen := int(header[0])
if passLen <= 0 {
rw.Write([]byte{1, 1})
err = ErrAuth
return
}
if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
return
}
pass := string(authBuf[:passLen])
// Verify
if ok := authenticator.Verify(user, pass); !ok {
rw.Write([]byte{1, 1})
err = ErrAuth
return
}
// Response auth state
if _, err = rw.Write([]byte{1, 0}); err != nil {
return
}
} else {
if _, err = rw.Write([]byte{5, 0}); err != nil {
return
}
}
// read VER CMD RSV ATYP DST.ADDR DST.PORT
if _, err = io.ReadFull(rw, buf[:3]); err != nil {
return
}
command = buf[1]
addr, err = ReadAddr(rw, buf)
if err != nil {
return
}
switch command {
case CmdConnect, CmdUDPAssociate:
// Acquire server listened address info
localAddr := ParseAddr(rw.LocalAddr().String())
if localAddr == nil {
err = ErrAddressNotSupported
} else {
// write VER REP RSV ATYP BND.ADDR BND.PORT
_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
}
case CmdBind:
fallthrough
default:
err = ErrCommandNotSupported
}
return
}
// ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
buf := make([]byte, MaxAddrLen)
var err error
// VER, NMETHODS, METHODS
if user != nil {
_, err = rw.Write([]byte{5, 1, 2})
} else {
_, err = rw.Write([]byte{5, 1, 0})
}
if err != nil {
return nil, err
}
// VER, METHOD
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
return nil, err
}
if buf[0] != 5 {
return nil, errors.New("SOCKS version error")
}
if buf[1] == 2 {
if user == nil {
return nil, ErrAuth
}
// password protocol version
authMsg := &bytes.Buffer{}
authMsg.WriteByte(1)
authMsg.WriteByte(uint8(len(user.Username)))
authMsg.WriteString(user.Username)
authMsg.WriteByte(uint8(len(user.Password)))
authMsg.WriteString(user.Password)
if _, err := rw.Write(authMsg.Bytes()); err != nil {
return nil, err
}
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
return nil, err
}
if buf[1] != 0 {
return nil, errors.New("rejected username/password")
}
} else if buf[1] != 0 {
return nil, errors.New("SOCKS need auth")
}
// VER, CMD, RSV, ADDR
if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
return nil, err
}
// VER, REP, RSV
if _, err := io.ReadFull(rw, buf[:3]); err != nil {
return nil, err
}
return ReadAddr(rw, buf)
}
func ReadAddr(r io.Reader, b []byte) (Addr, error) {
if len(b) < MaxAddrLen {
return nil, io.ErrShortBuffer
}
_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
if err != nil {
return nil, err
}
switch b[0] {
case AtypDomainName:
_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
if err != nil {
return nil, err
}
domainLength := uint16(b[1])
_, err = io.ReadFull(r, b[2:2+domainLength+2])
return b[:1+1+domainLength+2], err
case AtypIPv4:
_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
return b[:1+net.IPv4len+2], err
case AtypIPv6:
_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
return b[:1+net.IPv6len+2], err
}
return nil, ErrAddressNotSupported
}
// SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
func SplitAddr(b []byte) Addr {
addrLen := 1
if len(b) < addrLen {
return nil
}
switch b[0] {
case AtypDomainName:
if len(b) < 2 {
return nil
}
addrLen = 1 + 1 + int(b[1]) + 2
case AtypIPv4:
addrLen = 1 + net.IPv4len + 2
case AtypIPv6:
addrLen = 1 + net.IPv6len + 2
default:
return nil
}
if len(b) < addrLen {
return nil
}
return b[:addrLen]
}
// ParseAddr parses the address in string s. Returns nil if failed.
func ParseAddr(s string) Addr {
var addr Addr
host, port, err := net.SplitHostPort(s)
if err != nil {
return nil
}
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
addr = make([]byte, 1+net.IPv4len+2)
addr[0] = AtypIPv4
copy(addr[1:], ip4)
} else {
addr = make([]byte, 1+net.IPv6len+2)
addr[0] = AtypIPv6
copy(addr[1:], ip)
}
} else {
if len(host) > 255 {
return nil
}
addr = make([]byte, 1+1+len(host)+2)
addr[0] = AtypDomainName
addr[1] = byte(len(host))
copy(addr[2:], host)
}
portnum, err := strconv.ParseUint(port, 10, 16)
if err != nil {
return nil
}
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
return addr
}
// ParseAddrToSocksAddr parse a socks addr from net.addr
// This is a fast path of ParseAddr(addr.String())
func ParseAddrToSocksAddr(addr net.Addr) Addr {
var hostip net.IP
var port int
if udpaddr, ok := addr.(*net.UDPAddr); ok {
hostip = udpaddr.IP
port = udpaddr.Port
} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
hostip = tcpaddr.IP
port = tcpaddr.Port
}
// fallback parse
if hostip == nil {
return ParseAddr(addr.String())
}
var parsed Addr
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
parsed = make([]byte, 1+net.IPv4len+2)
parsed[0] = AtypIPv4
copy(parsed[1:], ip4)
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
} else {
parsed = make([]byte, 1+net.IPv6len+2)
parsed[0] = AtypIPv6
copy(parsed[1:], hostip)
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
}
return parsed
}
func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
addr := addrPort.Addr()
if addr.Is4() {
ip4 := addr.As4()
return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())}
}
buf := make([]byte, 1+net.IPv6len+2)
buf[0] = AtypIPv6
copy(buf[1:], addr.AsSlice())
buf[1+net.IPv6len] = byte(addrPort.Port() >> 8)
buf[1+net.IPv6len+1] = byte(addrPort.Port())
return buf
}
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
if len(packet) < 5 {
err = errors.New("insufficient length of packet")
return
}
// packet[0] and packet[1] are reserved
if !bytes.Equal(packet[:2], []byte{0, 0}) {
err = errors.New("reserved fields should be zero")
return
}
if packet[2] != 0 /* fragments */ {
err = errors.New("discarding fragmented payload")
return
}
addr = SplitAddr(packet[3:])
if addr == nil {
err = errors.New("failed to read UDP header")
}
payload = packet[3+len(addr):]
return
}
func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
if addr == nil {
err = errors.New("address is invalid")
return
}
packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
return
}

View File

@@ -9,6 +9,7 @@ import (
"strings" "strings"
"sync" "sync"
N "github.com/sigcn/pg/net"
"github.com/sigcn/pg/vpn/nic" "github.com/sigcn/pg/vpn/nic"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
@@ -17,6 +18,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
) )
var ( var (
@@ -122,6 +125,75 @@ func (g *GvisorCard) Close() error {
return nil return nil
} }
func (g *GvisorCard) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
g.init()
if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") {
return nil, errors.New("only tcp/udp is supported")
}
if strings.HasPrefix(network, "tcp") {
tcpAddr, err := net.ResolveTCPAddr(network, address)
if err != nil {
return nil, err
}
var add tcpip.Address
var protocol tcpip.NetworkProtocolNumber
if tcpAddr.IP.To4() != nil {
add = tcpip.AddrFrom4(tcpAddr.AddrPort().Addr().As4())
protocol = ipv4.ProtocolNumber
} else {
add = tcpip.AddrFrom16(tcpAddr.AddrPort().Addr().As16())
protocol = ipv6.ProtocolNumber
}
addr := tcpip.FullAddress{
NIC: g.nicID,
Addr: add,
Port: tcpAddr.AddrPort().Port()}
return gonet.DialContextTCP(ctx, g.Stack, addr, protocol)
}
if strings.HasPrefix(network, "udp") {
udpAddr, err := net.ResolveUDPAddr(network, address)
if err != nil {
return nil, err
}
var add tcpip.Address
var protocol tcpip.NetworkProtocolNumber
if udpAddr.IP.To4() != nil {
add = tcpip.AddrFrom4(udpAddr.AddrPort().Addr().As4())
protocol = ipv4.ProtocolNumber
} else {
add = tcpip.AddrFrom16(udpAddr.AddrPort().Addr().As16())
protocol = ipv6.ProtocolNumber
}
addr := &tcpip.FullAddress{
NIC: g.nicID,
Addr: add,
Port: udpAddr.AddrPort().Port()}
return gonet.DialUDP(g.Stack, nil, addr, protocol)
}
return nil, nil
}
func (g *GvisorCard) listenUDP(addr tcpip.FullAddress) (net.PacketConn, error) {
var wq waiter.Queue
var ep tcpip.Endpoint
var err tcpip.Error
if net.IP(addr.Addr.AsSlice()).To4() != nil {
ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
} else {
ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
}
if err != nil {
return nil, errors.New(err.String())
}
err = ep.Bind(addr)
if err != nil {
return nil, errors.New(err.String())
}
return gonet.NewUDPConn(&wq, ep), nil
}
func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l net.Listener, err error) { func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l net.Listener, err error) {
g.init() g.init()
if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") { if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") {
@@ -140,12 +212,20 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l
if network == "udp4" { if network == "udp4" {
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port} addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port}
return &udpListener{s: g.Stack, addr: addr}, nil pc, err := g.listenUDP(addr)
if err != nil {
return nil, err
}
return &N.UDPListener{PacketConn: pc}, nil
} }
if network == "udp6" { if network == "udp6" {
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port} addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port}
return &udpListener{s: g.Stack, addr: addr}, nil pc, err := g.listenUDP(addr)
if err != nil {
return nil, err
}
return &N.UDPListener{PacketConn: pc}, nil
} }
var listeners []net.Listener var listeners []net.Listener
@@ -160,11 +240,19 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l
if network == "udp" { if network == "udp" {
if g.addr4.Len() > 0 { if g.addr4.Len() > 0 {
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port} addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port}
listeners = append(listeners, &udpListener{s: g.Stack, addr: addr}) pc, err := g.listenUDP(addr)
if err != nil {
return nil, err
}
listeners = append(listeners, &N.UDPListener{PacketConn: pc})
} }
if g.addr6.Len() > 0 { if g.addr6.Len() > 0 {
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port} addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port}
listeners = append(listeners, &udpListener{s: g.Stack, addr: addr}) pc, err := g.listenUDP(addr)
if err != nil {
return nil, err
}
listeners = append(listeners, &N.UDPListener{PacketConn: pc})
} }
return &combinedListeners{listeners: listeners}, nil return &combinedListeners{listeners: listeners}, nil
} }

View File

@@ -1,192 +0,0 @@
package gvisor
import (
"context"
"errors"
"log/slog"
"net"
"sync"
"sync/atomic"
"time"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/waiter"
)
var _ net.Listener = (*udpListener)(nil)
var _ net.Conn = (*udpConn)(nil)
type udpConn struct {
removeConn func()
remoteAddr net.Addr
c *gonet.UDPConn
closeOnce sync.Once
inbound chan []byte
closeChan chan struct{}
lastActiveTime atomic.Value
}
func (c *udpConn) init() {
c.inbound = make(chan []byte, 512)
c.closeChan = make(chan struct{})
c.lastActiveTime.Store(time.Now())
ticker := time.NewTicker(6 * time.Second)
go func() { // create a timer to trace timeout udp conn, and close it
defer ticker.Stop()
for range ticker.C {
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
c.Close()
break
}
}
}()
}
func (c *udpConn) Read(p []byte) (int, error) {
select {
case b := <-c.inbound:
c.lastActiveTime.Store(time.Now())
return copy(p, b), nil
case <-c.closeChan:
return 0, net.ErrClosed
}
}
func (c *udpConn) Write(p []byte) (int, error) {
c.lastActiveTime.Store(time.Now())
return c.c.WriteTo(p, c.remoteAddr)
}
func (c *udpConn) LocalAddr() net.Addr {
return c.c.LocalAddr()
}
func (c *udpConn) RemoteAddr() net.Addr {
return c.remoteAddr
}
func (c *udpConn) Close() error {
c.closeOnce.Do(func() {
close(c.closeChan)
close(c.inbound)
c.removeConn()
slog.Log(context.Background(), -2, "[gVisor] UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
})
return nil
}
func (c *udpConn) SetDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *udpConn) SetReadDeadline(t time.Time) error {
return errors.ErrUnsupported
}
func (c *udpConn) SetWriteDeadline(t time.Time) error {
return errors.ErrUnsupported
}
type udpListener struct {
addr tcpip.FullAddress
s *stack.Stack
buf []byte
c *gonet.UDPConn
initErr error
initOnce sync.Once
closeOnce sync.Once
connMap map[net.Addr]*udpConn
connMapMu sync.RWMutex
}
func (l *udpListener) init() {
l.initOnce.Do(func() {
var wq waiter.Queue
var ep tcpip.Endpoint
var err tcpip.Error
if net.IP(l.addr.Addr.AsSlice()).To4() != nil {
ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
} else {
ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
}
if err != nil {
l.initErr = errors.New(err.String())
return
}
err = ep.Bind(l.addr)
if err != nil {
l.initErr = errors.New(err.String())
return
}
l.buf = make([]byte, 65535)
l.connMap = make(map[net.Addr]*udpConn)
l.c = gonet.NewUDPConn(&wq, ep)
})
}
func (l *udpListener) Accept() (net.Conn, error) {
l.init()
if l.initErr != nil {
return nil, l.initErr
}
read:
n, peerAddr, err := l.c.ReadFrom(l.buf)
if err != nil {
return nil, err
}
l.connMapMu.RLock()
conn, ok := l.connMap[peerAddr]
l.connMapMu.RUnlock()
if ok {
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
l.connMapMu.Lock()
conn, ok = l.connMap[peerAddr]
if ok {
l.connMapMu.Unlock()
conn.inbound <- append([]byte(nil), l.buf[:n]...)
goto read
}
defer l.connMapMu.Unlock()
conn = &udpConn{remoteAddr: peerAddr, c: l.c, removeConn: func() {
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
delete(l.connMap, peerAddr)
}}
conn.init()
l.connMap[peerAddr] = conn
conn.inbound <- append([]byte(nil), l.buf[:n]...)
return conn, nil
}
func (l *udpListener) Close() error {
if l.c == nil {
return nil
}
l.closeOnce.Do(func() {
l.c.Close()
l.connMapMu.Lock()
defer l.connMapMu.Unlock()
for _, c := range l.connMap {
go c.Close()
}
})
return nil
}
func (l *udpListener) Addr() net.Addr {
l.init()
if l.c == nil {
return nil
}
return l.c.LocalAddr()
}