mirror of
https://github.com/sigcn/pg.git
synced 2025-09-26 22:05:50 +08:00
pgcli/vpn: add a proxy server to access the PG network (#16)
This commit is contained in:
102
cmd/pgcli/vpn/rootless/proxy.go
Normal file
102
cmd/pgcli/vpn/rootless/proxy.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package rootless
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
N "github.com/sigcn/pg/net"
|
||||
"github.com/sigcn/pg/socks5"
|
||||
"github.com/sigcn/pg/vpn/nic/gvisor"
|
||||
)
|
||||
|
||||
type ProxyConfig struct {
|
||||
Listen string
|
||||
}
|
||||
|
||||
type ProxyServer struct {
|
||||
Config ProxyConfig
|
||||
GvisorCard *gvisor.GvisorCard
|
||||
|
||||
udpListener *N.UDPListener
|
||||
}
|
||||
|
||||
func (s *ProxyServer) Start(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
tcpListener, err := net.Listen("tcp", s.Config.Listen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
udpPacketConn, err := net.ListenPacket("udp", s.Config.Listen)
|
||||
if err != nil {
|
||||
tcpListener.Close()
|
||||
return err
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-ctx.Done()
|
||||
tcpListener.Close()
|
||||
udpPacketConn.Close()
|
||||
}()
|
||||
s.udpListener = &N.UDPListener{PacketConn: udpPacketConn}
|
||||
slog.Info("[Proxy] Server started", "listen", fmt.Sprintf("tcp+udp://%s", tcpListener.Addr().String()))
|
||||
go s.run(tcpListener)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServer) run(tcp net.Listener) {
|
||||
for {
|
||||
c, err := tcp.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
addr, cmd, err := socks5.ServerHandshake(c, nil)
|
||||
if err != nil {
|
||||
slog.Error("[Proxy] SOCKS5 handshake", "err", err)
|
||||
continue
|
||||
}
|
||||
if cmd == socks5.CmdConnect {
|
||||
if err := s.proxyTCP(c, addr); err != nil {
|
||||
slog.Error("[Proxy] SOCKS5 tcp", "err", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if cmd == socks5.CmdUDPAssociate {
|
||||
go func() {
|
||||
if err := s.proxyUDP(addr); err != nil {
|
||||
slog.Error("[Proxy] SOCKS5 udp", "err", err)
|
||||
}
|
||||
}()
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyServer) proxyTCP(rw net.Conn, addr socks5.Addr) error {
|
||||
c, err := s.GvisorCard.DialContext(context.TODO(), "tcp", addr.String())
|
||||
if err != nil {
|
||||
rw.Close()
|
||||
return err
|
||||
}
|
||||
go relay(rw, c)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ProxyServer) proxyUDP(addr socks5.Addr) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
c, err := s.udpListener.AcceptContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c1, err := s.GvisorCard.DialContext(context.TODO(), "udp", addr.String())
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return err
|
||||
}
|
||||
go relay(c, c1)
|
||||
return nil
|
||||
}
|
@@ -94,6 +94,7 @@ func usage(flagSet *flag.FlagSet) {
|
||||
logLevel := flagSet.Lookup("loglevel")
|
||||
mtu := flagSet.Lookup("mtu")
|
||||
peers := flagSet.Lookup("peers")
|
||||
proxyListen := flagSet.Lookup("proxy-listen")
|
||||
server := flagSet.Lookup("s")
|
||||
tun := flagSet.Lookup("tun")
|
||||
udpPort := flagSet.Lookup("udp-port")
|
||||
@@ -119,6 +120,7 @@ func usage(flagSet *flag.FlagSet) {
|
||||
fmt.Printf(" --key string\n\t%s\n", key.Usage)
|
||||
fmt.Printf(" --loglevel int\n\t%s (default %s)\n", logLevel.Usage, logLevel.DefValue)
|
||||
fmt.Printf(" --mtu int\n\t%s (default %s)\n", mtu.Usage, mtu.DefValue)
|
||||
fmt.Printf(" --proxy-listen string\n\t%s\n", proxyListen.Usage)
|
||||
fmt.Printf(" -s, --server string\n\t%s\n", server.Usage)
|
||||
fmt.Printf(" --tun string\n\t%s (default %s)\n", tun.Usage, tun.DefValue)
|
||||
fmt.Printf(" --udp-crypto\n\t%s (default %s)\n", cryptoAlgo.Usage, cryptoAlgo.DefValue)
|
||||
@@ -154,6 +156,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)
|
||||
flagSet.IntVar(&cfg.NICConfig.MTU, "mtu", 1411, "nic mtu")
|
||||
flagSet.StringVar(&cfg.NICConfig.Name, "tun", defaultTunName, "nic name")
|
||||
flagSet.Var(&forwards, "forward", "start in rootless mode and create a port forward (e.g. tcp://127.0.0.1:80)")
|
||||
flagSet.StringVar(&cfg.ProxyConfig.Listen, "proxy-listen", "", "start a proxy server to access the PG network (e.g. 127.0.0.1:4090)")
|
||||
|
||||
flagSet.StringVar(&cfg.PrivateKey, "key", "", "curve25519 private key in base58 format (default generate a new one)")
|
||||
flagSet.StringVar(&cfg.SecretFile, "secret-file", "", "")
|
||||
@@ -213,6 +216,7 @@ func createConfig(flagSet *flag.FlagSet, args []string) (cfg Config, err error)
|
||||
|
||||
type Config struct {
|
||||
NICConfig nic.Config
|
||||
ProxyConfig rootless.ProxyConfig
|
||||
DiscoPortScanOffset int
|
||||
DiscoPortScanCount int
|
||||
DiscoPortScanDuration time.Duration
|
||||
@@ -265,6 +269,13 @@ func (v *P2PVPN) Run(ctx context.Context) (err error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if v.Config.ProxyConfig.Listen != "" {
|
||||
if err := (&rootless.ProxyServer{
|
||||
GvisorCard: card.(*gvisor.GvisorCard),
|
||||
Config: v.Config.ProxyConfig}).Start(ctx, &wg); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := (&server.Server{
|
||||
Vnic: v.nic,
|
||||
PeerStore: c.PeerStore(),
|
||||
|
183
net/udp.go
Normal file
183
net/udp.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _ net.Conn = (*UDPConn)(nil)
|
||||
|
||||
type UDPConn struct {
|
||||
removeConn func()
|
||||
remoteAddr net.Addr
|
||||
c net.PacketConn
|
||||
|
||||
closeOnce sync.Once
|
||||
inbound chan []byte
|
||||
closeChan chan struct{}
|
||||
lastActiveTime atomic.Value
|
||||
}
|
||||
|
||||
func (c *UDPConn) init() {
|
||||
c.inbound = make(chan []byte, 512)
|
||||
c.closeChan = make(chan struct{})
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
ticker := time.NewTicker(6 * time.Second)
|
||||
go func() { // create a timer to trace timeout udp conn, and close it
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *UDPConn) Read(p []byte) (int, error) {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
return copy(p, b), nil
|
||||
case <-c.closeChan:
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *UDPConn) Write(p []byte) (int, error) {
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
return c.c.WriteTo(p, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *UDPConn) LocalAddr() net.Addr {
|
||||
return c.c.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *UDPConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *UDPConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closeChan)
|
||||
close(c.inbound)
|
||||
c.removeConn()
|
||||
slog.Log(context.Background(), -2, "UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *UDPConn) SetDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
func (c *UDPConn) SetReadDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
PacketConn net.PacketConn
|
||||
|
||||
buf []byte
|
||||
initOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
udpChan chan *UDPConn
|
||||
|
||||
connMap map[string]*UDPConn
|
||||
connMapMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (l *UDPListener) init() {
|
||||
l.initOnce.Do(func() {
|
||||
l.buf = make([]byte, 65535)
|
||||
l.udpChan = make(chan *UDPConn, 8)
|
||||
l.connMap = make(map[string]*UDPConn)
|
||||
go l.readUDP()
|
||||
})
|
||||
}
|
||||
|
||||
func (l *UDPListener) readUDP() {
|
||||
read := func() error {
|
||||
read:
|
||||
n, peerAddr, err := l.PacketConn.ReadFrom(l.buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
l.connMapMu.RLock()
|
||||
conn, ok := l.connMap[peerAddr.String()]
|
||||
l.connMapMu.RUnlock()
|
||||
if ok {
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
goto read
|
||||
}
|
||||
l.connMapMu.Lock()
|
||||
conn, ok = l.connMap[peerAddr.String()]
|
||||
if ok {
|
||||
l.connMapMu.Unlock()
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
goto read
|
||||
}
|
||||
defer l.connMapMu.Unlock()
|
||||
conn = &UDPConn{remoteAddr: peerAddr, c: l.PacketConn, removeConn: func() {
|
||||
l.connMapMu.Lock()
|
||||
defer l.connMapMu.Unlock()
|
||||
delete(l.connMap, peerAddr.String())
|
||||
}}
|
||||
conn.init()
|
||||
l.connMap[peerAddr.String()] = conn
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
l.udpChan <- conn
|
||||
return nil
|
||||
}
|
||||
for {
|
||||
if err := read(); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UDPListener) Accept() (net.Conn, error) {
|
||||
return l.AcceptContext(context.Background())
|
||||
}
|
||||
|
||||
func (l *UDPListener) AcceptContext(ctx context.Context) (net.Conn, error) {
|
||||
l.init()
|
||||
select {
|
||||
case c := <-l.udpChan:
|
||||
return c, nil
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UDPListener) Close() error {
|
||||
if l.PacketConn == nil {
|
||||
return nil
|
||||
}
|
||||
l.closeOnce.Do(func() {
|
||||
l.PacketConn.Close()
|
||||
l.connMapMu.Lock()
|
||||
defer l.connMapMu.Unlock()
|
||||
for _, c := range l.connMap {
|
||||
go c.Close()
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *UDPListener) Addr() net.Addr {
|
||||
l.init()
|
||||
if l.PacketConn == nil {
|
||||
return nil
|
||||
}
|
||||
return l.PacketConn.LocalAddr()
|
||||
}
|
457
socks5/socks5.go
Normal file
457
socks5/socks5.go
Normal file
@@ -0,0 +1,457 @@
|
||||
// Copyright (c) 2024 sigcn/pg
|
||||
// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3.
|
||||
// Copyright (c) 2021-2024 clash
|
||||
// Licensed under the GNU GENERAL PUBLIC LICENSE Version 3.
|
||||
package socks5
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type Authenticator interface {
|
||||
Verify(user string, pass string) bool
|
||||
}
|
||||
|
||||
// Error represents a SOCKS error
|
||||
type Error byte
|
||||
|
||||
func (err Error) Error() string {
|
||||
return "SOCKS error: " + strconv.Itoa(int(err))
|
||||
}
|
||||
|
||||
// Command is request commands as defined in RFC 1928 section 4.
|
||||
type Command = uint8
|
||||
|
||||
const Version = 5
|
||||
|
||||
// SOCKS request commands as defined in RFC 1928 section 4.
|
||||
const (
|
||||
CmdConnect Command = 1
|
||||
CmdBind Command = 2
|
||||
CmdUDPAssociate Command = 3
|
||||
)
|
||||
|
||||
// SOCKS address types as defined in RFC 1928 section 5.
|
||||
const (
|
||||
AtypIPv4 = 1
|
||||
AtypDomainName = 3
|
||||
AtypIPv6 = 4
|
||||
)
|
||||
|
||||
// MaxAddrLen is the maximum size of SOCKS address in bytes.
|
||||
const MaxAddrLen = 1 + 1 + 255 + 2
|
||||
|
||||
// MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
|
||||
const MaxAuthLen = 255
|
||||
|
||||
// Addr represents a SOCKS address as defined in RFC 1928 section 5.
|
||||
type Addr []byte
|
||||
|
||||
func (a Addr) String() string {
|
||||
var host, port string
|
||||
|
||||
switch a[0] {
|
||||
case AtypDomainName:
|
||||
hostLen := uint16(a[1])
|
||||
host = string(a[2 : 2+hostLen])
|
||||
port = strconv.Itoa((int(a[2+hostLen]) << 8) | int(a[2+hostLen+1]))
|
||||
case AtypIPv4:
|
||||
host = net.IP(a[1 : 1+net.IPv4len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv4len]) << 8) | int(a[1+net.IPv4len+1]))
|
||||
case AtypIPv6:
|
||||
host = net.IP(a[1 : 1+net.IPv6len]).String()
|
||||
port = strconv.Itoa((int(a[1+net.IPv6len]) << 8) | int(a[1+net.IPv6len+1]))
|
||||
}
|
||||
|
||||
return net.JoinHostPort(host, port)
|
||||
}
|
||||
|
||||
// UDPAddr converts a socks5.Addr to *net.UDPAddr
|
||||
func (a Addr) UDPAddr() *net.UDPAddr {
|
||||
if len(a) == 0 {
|
||||
return nil
|
||||
}
|
||||
switch a[0] {
|
||||
case AtypIPv4:
|
||||
var ip [net.IPv4len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv4len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
|
||||
case AtypIPv6:
|
||||
var ip [net.IPv6len]byte
|
||||
copy(ip[0:], a[1:1+net.IPv6len])
|
||||
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
|
||||
}
|
||||
// Other Atyp
|
||||
return nil
|
||||
}
|
||||
|
||||
// SOCKS errors as defined in RFC 1928 section 6.
|
||||
const (
|
||||
ErrGeneralFailure = Error(1)
|
||||
ErrConnectionNotAllowed = Error(2)
|
||||
ErrNetworkUnreachable = Error(3)
|
||||
ErrHostUnreachable = Error(4)
|
||||
ErrConnectionRefused = Error(5)
|
||||
ErrTTLExpired = Error(6)
|
||||
ErrCommandNotSupported = Error(7)
|
||||
ErrAddressNotSupported = Error(8)
|
||||
)
|
||||
|
||||
// Auth errors used to return a specific "Auth failed" error
|
||||
var ErrAuth = errors.New("auth failed")
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
// ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
|
||||
func ServerHandshake(rw net.Conn, authenticator Authenticator) (addr Addr, command Command, err error) {
|
||||
// Read RFC 1928 for request and reply structure and sizes.
|
||||
buf := make([]byte, MaxAddrLen)
|
||||
// read VER, NMETHODS, METHODS
|
||||
if _, err = io.ReadFull(rw, buf[:2]); err != nil {
|
||||
return
|
||||
}
|
||||
nmethods := buf[1]
|
||||
if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// write VER METHOD
|
||||
if authenticator != nil {
|
||||
if _, err = rw.Write([]byte{5, 2}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Get header
|
||||
header := make([]byte, 2)
|
||||
if _, err = io.ReadFull(rw, header); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
authBuf := make([]byte, MaxAuthLen)
|
||||
// Get username
|
||||
userLen := int(header[1])
|
||||
if userLen <= 0 {
|
||||
rw.Write([]byte{1, 1})
|
||||
err = ErrAuth
|
||||
return
|
||||
}
|
||||
if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
|
||||
return
|
||||
}
|
||||
user := string(authBuf[:userLen])
|
||||
|
||||
// Get password
|
||||
if _, err = rw.Read(header[:1]); err != nil {
|
||||
return
|
||||
}
|
||||
passLen := int(header[0])
|
||||
if passLen <= 0 {
|
||||
rw.Write([]byte{1, 1})
|
||||
err = ErrAuth
|
||||
return
|
||||
}
|
||||
if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
|
||||
return
|
||||
}
|
||||
pass := string(authBuf[:passLen])
|
||||
|
||||
// Verify
|
||||
if ok := authenticator.Verify(user, pass); !ok {
|
||||
rw.Write([]byte{1, 1})
|
||||
err = ErrAuth
|
||||
return
|
||||
}
|
||||
|
||||
// Response auth state
|
||||
if _, err = rw.Write([]byte{1, 0}); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if _, err = rw.Write([]byte{5, 0}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// read VER CMD RSV ATYP DST.ADDR DST.PORT
|
||||
if _, err = io.ReadFull(rw, buf[:3]); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
command = buf[1]
|
||||
addr, err = ReadAddr(rw, buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch command {
|
||||
case CmdConnect, CmdUDPAssociate:
|
||||
// Acquire server listened address info
|
||||
localAddr := ParseAddr(rw.LocalAddr().String())
|
||||
if localAddr == nil {
|
||||
err = ErrAddressNotSupported
|
||||
} else {
|
||||
// write VER REP RSV ATYP BND.ADDR BND.PORT
|
||||
_, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
|
||||
}
|
||||
case CmdBind:
|
||||
fallthrough
|
||||
default:
|
||||
err = ErrCommandNotSupported
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
|
||||
func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
|
||||
buf := make([]byte, MaxAddrLen)
|
||||
var err error
|
||||
|
||||
// VER, NMETHODS, METHODS
|
||||
if user != nil {
|
||||
_, err = rw.Write([]byte{5, 1, 2})
|
||||
} else {
|
||||
_, err = rw.Write([]byte{5, 1, 0})
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// VER, METHOD
|
||||
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buf[0] != 5 {
|
||||
return nil, errors.New("SOCKS version error")
|
||||
}
|
||||
|
||||
if buf[1] == 2 {
|
||||
if user == nil {
|
||||
return nil, ErrAuth
|
||||
}
|
||||
|
||||
// password protocol version
|
||||
authMsg := &bytes.Buffer{}
|
||||
authMsg.WriteByte(1)
|
||||
authMsg.WriteByte(uint8(len(user.Username)))
|
||||
authMsg.WriteString(user.Username)
|
||||
authMsg.WriteByte(uint8(len(user.Password)))
|
||||
authMsg.WriteString(user.Password)
|
||||
|
||||
if _, err := rw.Write(authMsg.Bytes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(rw, buf[:2]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if buf[1] != 0 {
|
||||
return nil, errors.New("rejected username/password")
|
||||
}
|
||||
} else if buf[1] != 0 {
|
||||
return nil, errors.New("SOCKS need auth")
|
||||
}
|
||||
|
||||
// VER, CMD, RSV, ADDR
|
||||
if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// VER, REP, RSV
|
||||
if _, err := io.ReadFull(rw, buf[:3]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ReadAddr(rw, buf)
|
||||
}
|
||||
|
||||
func ReadAddr(r io.Reader, b []byte) (Addr, error) {
|
||||
if len(b) < MaxAddrLen {
|
||||
return nil, io.ErrShortBuffer
|
||||
}
|
||||
_, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
case AtypDomainName:
|
||||
_, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domainLength := uint16(b[1])
|
||||
_, err = io.ReadFull(r, b[2:2+domainLength+2])
|
||||
return b[:1+1+domainLength+2], err
|
||||
case AtypIPv4:
|
||||
_, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
|
||||
return b[:1+net.IPv4len+2], err
|
||||
case AtypIPv6:
|
||||
_, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
|
||||
return b[:1+net.IPv6len+2], err
|
||||
}
|
||||
|
||||
return nil, ErrAddressNotSupported
|
||||
}
|
||||
|
||||
// SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
|
||||
func SplitAddr(b []byte) Addr {
|
||||
addrLen := 1
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch b[0] {
|
||||
case AtypDomainName:
|
||||
if len(b) < 2 {
|
||||
return nil
|
||||
}
|
||||
addrLen = 1 + 1 + int(b[1]) + 2
|
||||
case AtypIPv4:
|
||||
addrLen = 1 + net.IPv4len + 2
|
||||
case AtypIPv6:
|
||||
addrLen = 1 + net.IPv6len + 2
|
||||
default:
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
if len(b) < addrLen {
|
||||
return nil
|
||||
}
|
||||
|
||||
return b[:addrLen]
|
||||
}
|
||||
|
||||
// ParseAddr parses the address in string s. Returns nil if failed.
|
||||
func ParseAddr(s string) Addr {
|
||||
var addr Addr
|
||||
host, port, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
addr = make([]byte, 1+net.IPv4len+2)
|
||||
addr[0] = AtypIPv4
|
||||
copy(addr[1:], ip4)
|
||||
} else {
|
||||
addr = make([]byte, 1+net.IPv6len+2)
|
||||
addr[0] = AtypIPv6
|
||||
copy(addr[1:], ip)
|
||||
}
|
||||
} else {
|
||||
if len(host) > 255 {
|
||||
return nil
|
||||
}
|
||||
addr = make([]byte, 1+1+len(host)+2)
|
||||
addr[0] = AtypDomainName
|
||||
addr[1] = byte(len(host))
|
||||
copy(addr[2:], host)
|
||||
}
|
||||
|
||||
portnum, err := strconv.ParseUint(port, 10, 16)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
|
||||
|
||||
return addr
|
||||
}
|
||||
|
||||
// ParseAddrToSocksAddr parse a socks addr from net.addr
|
||||
// This is a fast path of ParseAddr(addr.String())
|
||||
func ParseAddrToSocksAddr(addr net.Addr) Addr {
|
||||
var hostip net.IP
|
||||
var port int
|
||||
if udpaddr, ok := addr.(*net.UDPAddr); ok {
|
||||
hostip = udpaddr.IP
|
||||
port = udpaddr.Port
|
||||
} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
|
||||
hostip = tcpaddr.IP
|
||||
port = tcpaddr.Port
|
||||
}
|
||||
|
||||
// fallback parse
|
||||
if hostip == nil {
|
||||
return ParseAddr(addr.String())
|
||||
}
|
||||
|
||||
var parsed Addr
|
||||
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
|
||||
parsed = make([]byte, 1+net.IPv4len+2)
|
||||
parsed[0] = AtypIPv4
|
||||
copy(parsed[1:], ip4)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
|
||||
|
||||
} else {
|
||||
parsed = make([]byte, 1+net.IPv6len+2)
|
||||
parsed[0] = AtypIPv6
|
||||
copy(parsed[1:], hostip)
|
||||
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
|
||||
}
|
||||
return parsed
|
||||
}
|
||||
|
||||
func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
|
||||
addr := addrPort.Addr()
|
||||
if addr.Is4() {
|
||||
ip4 := addr.As4()
|
||||
return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())}
|
||||
}
|
||||
|
||||
buf := make([]byte, 1+net.IPv6len+2)
|
||||
buf[0] = AtypIPv6
|
||||
copy(buf[1:], addr.AsSlice())
|
||||
buf[1+net.IPv6len] = byte(addrPort.Port() >> 8)
|
||||
buf[1+net.IPv6len+1] = byte(addrPort.Port())
|
||||
return buf
|
||||
}
|
||||
|
||||
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
|
||||
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
|
||||
if len(packet) < 5 {
|
||||
err = errors.New("insufficient length of packet")
|
||||
return
|
||||
}
|
||||
|
||||
// packet[0] and packet[1] are reserved
|
||||
if !bytes.Equal(packet[:2], []byte{0, 0}) {
|
||||
err = errors.New("reserved fields should be zero")
|
||||
return
|
||||
}
|
||||
|
||||
if packet[2] != 0 /* fragments */ {
|
||||
err = errors.New("discarding fragmented payload")
|
||||
return
|
||||
}
|
||||
|
||||
addr = SplitAddr(packet[3:])
|
||||
if addr == nil {
|
||||
err = errors.New("failed to read UDP header")
|
||||
}
|
||||
|
||||
payload = packet[3+len(addr):]
|
||||
return
|
||||
}
|
||||
|
||||
func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
|
||||
if addr == nil {
|
||||
err = errors.New("address is invalid")
|
||||
return
|
||||
}
|
||||
packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
|
||||
return
|
||||
}
|
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
N "github.com/sigcn/pg/net"
|
||||
"github.com/sigcn/pg/vpn/nic"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -17,6 +18,8 @@ import (
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -122,6 +125,75 @@ func (g *GvisorCard) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GvisorCard) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
g.init()
|
||||
if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") {
|
||||
return nil, errors.New("only tcp/udp is supported")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(network, "tcp") {
|
||||
tcpAddr, err := net.ResolveTCPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var add tcpip.Address
|
||||
var protocol tcpip.NetworkProtocolNumber
|
||||
if tcpAddr.IP.To4() != nil {
|
||||
add = tcpip.AddrFrom4(tcpAddr.AddrPort().Addr().As4())
|
||||
protocol = ipv4.ProtocolNumber
|
||||
} else {
|
||||
add = tcpip.AddrFrom16(tcpAddr.AddrPort().Addr().As16())
|
||||
protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
addr := tcpip.FullAddress{
|
||||
NIC: g.nicID,
|
||||
Addr: add,
|
||||
Port: tcpAddr.AddrPort().Port()}
|
||||
return gonet.DialContextTCP(ctx, g.Stack, addr, protocol)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(network, "udp") {
|
||||
udpAddr, err := net.ResolveUDPAddr(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var add tcpip.Address
|
||||
var protocol tcpip.NetworkProtocolNumber
|
||||
if udpAddr.IP.To4() != nil {
|
||||
add = tcpip.AddrFrom4(udpAddr.AddrPort().Addr().As4())
|
||||
protocol = ipv4.ProtocolNumber
|
||||
} else {
|
||||
add = tcpip.AddrFrom16(udpAddr.AddrPort().Addr().As16())
|
||||
protocol = ipv6.ProtocolNumber
|
||||
}
|
||||
addr := &tcpip.FullAddress{
|
||||
NIC: g.nicID,
|
||||
Addr: add,
|
||||
Port: udpAddr.AddrPort().Port()}
|
||||
return gonet.DialUDP(g.Stack, nil, addr, protocol)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (g *GvisorCard) listenUDP(addr tcpip.FullAddress) (net.PacketConn, error) {
|
||||
var wq waiter.Queue
|
||||
var ep tcpip.Endpoint
|
||||
var err tcpip.Error
|
||||
if net.IP(addr.Addr.AsSlice()).To4() != nil {
|
||||
ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
|
||||
} else {
|
||||
ep, err = g.Stack.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, errors.New(err.String())
|
||||
}
|
||||
err = ep.Bind(addr)
|
||||
if err != nil {
|
||||
return nil, errors.New(err.String())
|
||||
}
|
||||
return gonet.NewUDPConn(&wq, ep), nil
|
||||
}
|
||||
|
||||
func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l net.Listener, err error) {
|
||||
g.init()
|
||||
if !strings.HasPrefix(network, "tcp") && !strings.HasPrefix(network, "udp") {
|
||||
@@ -140,12 +212,20 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l
|
||||
|
||||
if network == "udp4" {
|
||||
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port}
|
||||
return &udpListener{s: g.Stack, addr: addr}, nil
|
||||
pc, err := g.listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &N.UDPListener{PacketConn: pc}, nil
|
||||
}
|
||||
|
||||
if network == "udp6" {
|
||||
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port}
|
||||
return &udpListener{s: g.Stack, addr: addr}, nil
|
||||
pc, err := g.listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &N.UDPListener{PacketConn: pc}, nil
|
||||
}
|
||||
|
||||
var listeners []net.Listener
|
||||
@@ -160,11 +240,19 @@ func (g *GvisorCard) Listen(ctx context.Context, network string, port uint16) (l
|
||||
if network == "udp" {
|
||||
if g.addr4.Len() > 0 {
|
||||
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr4, Port: port}
|
||||
listeners = append(listeners, &udpListener{s: g.Stack, addr: addr})
|
||||
pc, err := g.listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listeners = append(listeners, &N.UDPListener{PacketConn: pc})
|
||||
}
|
||||
if g.addr6.Len() > 0 {
|
||||
addr := tcpip.FullAddress{NIC: g.nicID, Addr: g.addr6, Port: port}
|
||||
listeners = append(listeners, &udpListener{s: g.Stack, addr: addr})
|
||||
pc, err := g.listenUDP(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listeners = append(listeners, &N.UDPListener{PacketConn: pc})
|
||||
}
|
||||
return &combinedListeners{listeners: listeners}, nil
|
||||
}
|
||||
|
@@ -1,192 +0,0 @@
|
||||
package gvisor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
var _ net.Listener = (*udpListener)(nil)
|
||||
var _ net.Conn = (*udpConn)(nil)
|
||||
|
||||
type udpConn struct {
|
||||
removeConn func()
|
||||
remoteAddr net.Addr
|
||||
c *gonet.UDPConn
|
||||
|
||||
closeOnce sync.Once
|
||||
inbound chan []byte
|
||||
closeChan chan struct{}
|
||||
lastActiveTime atomic.Value
|
||||
}
|
||||
|
||||
func (c *udpConn) init() {
|
||||
c.inbound = make(chan []byte, 512)
|
||||
c.closeChan = make(chan struct{})
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
ticker := time.NewTicker(6 * time.Second)
|
||||
go func() { // create a timer to trace timeout udp conn, and close it
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if time.Since(c.lastActiveTime.Load().(time.Time)) > 10*time.Second {
|
||||
c.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *udpConn) Read(p []byte) (int, error) {
|
||||
select {
|
||||
case b := <-c.inbound:
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
return copy(p, b), nil
|
||||
case <-c.closeChan:
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
func (c *udpConn) Write(p []byte) (int, error) {
|
||||
c.lastActiveTime.Store(time.Now())
|
||||
return c.c.WriteTo(p, c.remoteAddr)
|
||||
}
|
||||
|
||||
func (c *udpConn) LocalAddr() net.Addr {
|
||||
return c.c.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *udpConn) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *udpConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
close(c.closeChan)
|
||||
close(c.inbound)
|
||||
c.removeConn()
|
||||
slog.Log(context.Background(), -2, "[gVisor] UDPConn closed", "local_addr", c.LocalAddr(), "remote_addr", c.remoteAddr)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *udpConn) SetDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
func (c *udpConn) SetReadDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
func (c *udpConn) SetWriteDeadline(t time.Time) error {
|
||||
return errors.ErrUnsupported
|
||||
}
|
||||
|
||||
type udpListener struct {
|
||||
addr tcpip.FullAddress
|
||||
s *stack.Stack
|
||||
|
||||
buf []byte
|
||||
c *gonet.UDPConn
|
||||
initErr error
|
||||
initOnce sync.Once
|
||||
closeOnce sync.Once
|
||||
|
||||
connMap map[net.Addr]*udpConn
|
||||
connMapMu sync.RWMutex
|
||||
}
|
||||
|
||||
func (l *udpListener) init() {
|
||||
l.initOnce.Do(func() {
|
||||
var wq waiter.Queue
|
||||
var ep tcpip.Endpoint
|
||||
var err tcpip.Error
|
||||
if net.IP(l.addr.Addr.AsSlice()).To4() != nil {
|
||||
ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
|
||||
} else {
|
||||
ep, err = l.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq)
|
||||
}
|
||||
if err != nil {
|
||||
l.initErr = errors.New(err.String())
|
||||
return
|
||||
}
|
||||
err = ep.Bind(l.addr)
|
||||
if err != nil {
|
||||
l.initErr = errors.New(err.String())
|
||||
return
|
||||
}
|
||||
l.buf = make([]byte, 65535)
|
||||
l.connMap = make(map[net.Addr]*udpConn)
|
||||
l.c = gonet.NewUDPConn(&wq, ep)
|
||||
})
|
||||
}
|
||||
|
||||
func (l *udpListener) Accept() (net.Conn, error) {
|
||||
l.init()
|
||||
if l.initErr != nil {
|
||||
return nil, l.initErr
|
||||
}
|
||||
read:
|
||||
n, peerAddr, err := l.c.ReadFrom(l.buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l.connMapMu.RLock()
|
||||
conn, ok := l.connMap[peerAddr]
|
||||
l.connMapMu.RUnlock()
|
||||
if ok {
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
goto read
|
||||
}
|
||||
l.connMapMu.Lock()
|
||||
conn, ok = l.connMap[peerAddr]
|
||||
if ok {
|
||||
l.connMapMu.Unlock()
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
goto read
|
||||
}
|
||||
defer l.connMapMu.Unlock()
|
||||
conn = &udpConn{remoteAddr: peerAddr, c: l.c, removeConn: func() {
|
||||
l.connMapMu.Lock()
|
||||
defer l.connMapMu.Unlock()
|
||||
delete(l.connMap, peerAddr)
|
||||
}}
|
||||
conn.init()
|
||||
l.connMap[peerAddr] = conn
|
||||
conn.inbound <- append([]byte(nil), l.buf[:n]...)
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *udpListener) Close() error {
|
||||
if l.c == nil {
|
||||
return nil
|
||||
}
|
||||
l.closeOnce.Do(func() {
|
||||
l.c.Close()
|
||||
l.connMapMu.Lock()
|
||||
defer l.connMapMu.Unlock()
|
||||
for _, c := range l.connMap {
|
||||
go c.Close()
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *udpListener) Addr() net.Addr {
|
||||
l.init()
|
||||
if l.c == nil {
|
||||
return nil
|
||||
}
|
||||
return l.c.LocalAddr()
|
||||
}
|
Reference in New Issue
Block a user