mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-07 17:51:16 +08:00
Refactor: optimize UDP module
Symmetric NAT support for now.
This commit is contained in:
@@ -2,33 +2,15 @@ package core
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TCPConn implements the net.Conn interface.
|
||||||
type TCPConn interface {
|
type TCPConn interface {
|
||||||
net.Conn
|
net.Conn
|
||||||
ID() *stack.TransportEndpointID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type UDPPacket interface {
|
// UDPConn implements net.Conn and net.PacketConn.
|
||||||
// Data get the payload of UDP Packet.
|
type UDPConn interface {
|
||||||
Data() []byte
|
net.Conn
|
||||||
|
net.PacketConn
|
||||||
// Drop call after packet is used, could release resources in this function.
|
|
||||||
Drop()
|
|
||||||
|
|
||||||
// ID returns the transport endpoint id of packet.
|
|
||||||
ID() *stack.TransportEndpointID
|
|
||||||
|
|
||||||
// LocalAddr returns the source IP/Port of packet.
|
|
||||||
LocalAddr() net.Addr
|
|
||||||
|
|
||||||
// RemoteAddr returns the destination IP/Port of packet.
|
|
||||||
RemoteAddr() net.Addr
|
|
||||||
|
|
||||||
// WriteBack writes the payload with source IP/Port equals addr
|
|
||||||
// - variable source IP/Port is important to STUN
|
|
||||||
// - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target.
|
|
||||||
WriteBack([]byte, net.Addr) (int, error)
|
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
|
// Handler is a TCP/UDP connection handler that implements
|
||||||
|
// HandleTCPConn and HandleUDPConn methods.
|
||||||
type Handler interface {
|
type Handler interface {
|
||||||
Add(TCPConn)
|
HandleTCPConn(TCPConn)
|
||||||
AddPacket(UDPPacket)
|
HandleUDPConn(UDPConn)
|
||||||
}
|
}
|
||||||
|
@@ -2,12 +2,10 @@ package stack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
)
|
)
|
||||||
@@ -36,10 +34,9 @@ func withTCPHandler() Option {
|
|||||||
return func(s *Stack) error {
|
return func(s *Stack) error {
|
||||||
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||||
var wq waiter.Queue
|
var wq waiter.Queue
|
||||||
id := r.ID()
|
|
||||||
ep, err := r.CreateEndpoint(&wq)
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// prevent potential half-open TCP connection leak.
|
// RST: prevent potential half-open TCP connection leak.
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -47,11 +44,7 @@ func withTCPHandler() Option {
|
|||||||
|
|
||||||
setKeepalive(ep)
|
setKeepalive(ep)
|
||||||
|
|
||||||
conn := &tcpConn{
|
s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep))
|
||||||
Conn: gonet.NewTCPConn(&wq, ep),
|
|
||||||
id: &id,
|
|
||||||
}
|
|
||||||
s.handler.Add(conn)
|
|
||||||
})
|
})
|
||||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
return nil
|
return nil
|
||||||
@@ -72,12 +65,3 @@ func setKeepalive(ep tcpip.Endpoint) error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpConn struct {
|
|
||||||
net.Conn
|
|
||||||
id *stack.TransportEndpointID
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *tcpConn) ID() *stack.TransportEndpointID {
|
|
||||||
return c.id
|
|
||||||
}
|
|
||||||
|
@@ -1,192 +1,24 @@
|
|||||||
package stack
|
package stack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"net"
|
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
)
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
const (
|
|
||||||
// udpNoChecksum disables UDP checksum if set to true.
|
|
||||||
udpNoChecksum = true
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func withUDPHandler() Option {
|
func withUDPHandler() Option {
|
||||||
return func(s *Stack) error {
|
return func(s *Stack) error {
|
||||||
udpHandlePacket := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) {
|
||||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket.
|
var wq waiter.Queue
|
||||||
udpHdr := header.UDP(pkt.TransportHeader().View())
|
ep, err := r.CreateEndpoint(&wq)
|
||||||
if int(udpHdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
|
|
||||||
// Malformed packet.
|
|
||||||
s.Stats().UDP.MalformedPacketsReceived.Increment()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if !verifyChecksum(udpHdr, pkt) {
|
|
||||||
// Checksum error.
|
|
||||||
s.Stats().UDP.ChecksumErrors.Increment()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Stats().UDP.PacketsReceived.Increment()
|
|
||||||
|
|
||||||
packet := &udpPacket{
|
|
||||||
s: s,
|
|
||||||
id: &id,
|
|
||||||
data: pkt.Data().ExtractVV(),
|
|
||||||
nicID: pkt.NICID,
|
|
||||||
netHdr: pkt.Network(),
|
|
||||||
netProto: pkt.NetworkProtocolNumber,
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handler.AddPacket(packet)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpHandlePacket)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpPacket struct {
|
|
||||||
s *Stack
|
|
||||||
id *stack.TransportEndpointID
|
|
||||||
data buffer.VectorisedView
|
|
||||||
nicID tcpip.NICID
|
|
||||||
netHdr header.Network
|
|
||||||
netProto tcpip.NetworkProtocolNumber
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *udpPacket) Data() []byte {
|
|
||||||
return p.data.ToView()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *udpPacket) Drop() {}
|
|
||||||
|
|
||||||
func (p *udpPacket) ID() *stack.TransportEndpointID {
|
|
||||||
return p.id
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *udpPacket) LocalAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: net.IP(p.id.LocalAddress), Port: int(p.id.LocalPort)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *udpPacket) RemoteAddr() net.Addr {
|
|
||||||
return &net.UDPAddr{IP: net.IP(p.id.RemoteAddress), Port: int(p.id.RemotePort)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
|
|
||||||
v := buffer.View(b)
|
|
||||||
if len(v) > header.UDPMaximumPacketSize {
|
|
||||||
// Payload can't possibly fit in a packet.
|
|
||||||
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
localAddress tcpip.Address
|
|
||||||
localPort uint16
|
|
||||||
)
|
|
||||||
|
|
||||||
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
|
|
||||||
localAddress = p.netHdr.DestinationAddress()
|
|
||||||
localPort = p.id.LocalPort
|
|
||||||
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
|
||||||
localAddress = tcpip.Address(ipv4)
|
|
||||||
localPort = uint16(udpAddr.Port)
|
|
||||||
} else {
|
|
||||||
localAddress = tcpip.Address(udpAddr.IP)
|
|
||||||
localPort = uint16(udpAddr.Port)
|
|
||||||
}
|
|
||||||
|
|
||||||
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
// TODO: handler errors in the future.
|
||||||
|
return
|
||||||
}
|
}
|
||||||
defer route.Release()
|
|
||||||
|
|
||||||
data := v.ToVectorisedView()
|
s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep))
|
||||||
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
|
||||||
return 0, fmt.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
return data.Size(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// sendUDP sends a UDP segment via the provided network endpoint and under the
|
|
||||||
// provided identity.
|
|
||||||
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error {
|
|
||||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
|
||||||
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
|
|
||||||
Data: data,
|
|
||||||
})
|
})
|
||||||
defer pkt.DecRef()
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
// Initialize the UDP header.
|
|
||||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
|
||||||
pkt.TransportProtocolNumber = udp.ProtocolNumber
|
|
||||||
|
|
||||||
length := uint16(pkt.Size())
|
|
||||||
udpHdr.Encode(&header.UDPFields{
|
|
||||||
SrcPort: localPort,
|
|
||||||
DstPort: remotePort,
|
|
||||||
Length: length,
|
|
||||||
})
|
|
||||||
|
|
||||||
// Set the checksum field unless TX checksum offload is enabled.
|
|
||||||
// On IPv4, UDP checksum is optional, and a zero value indicates the
|
|
||||||
// transmitter skipped the checksum generation (RFC768).
|
|
||||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
|
||||||
if r.RequiresTXTransportChecksum() &&
|
|
||||||
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
|
|
||||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
|
||||||
for _, v := range data.Views() {
|
|
||||||
xsum = header.Checksum(v, xsum)
|
|
||||||
}
|
|
||||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
|
||||||
}
|
|
||||||
|
|
||||||
ttl := r.DefaultTTL()
|
|
||||||
|
|
||||||
if err := r.WritePacket(stack.NetworkHeaderParams{
|
|
||||||
Protocol: udp.ProtocolNumber,
|
|
||||||
TTL: ttl,
|
|
||||||
TOS: 0, /* default */
|
|
||||||
}, pkt); err != nil {
|
|
||||||
r.Stats().UDP.PacketSendErrors.Increment()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track count of packets sent.
|
|
||||||
r.Stats().UDP.PacketsSent.Increment()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go verifyChecksum.
|
|
||||||
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
|
|
||||||
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
|
|
||||||
if pkt.RXTransportChecksumValidated {
|
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
|
|
||||||
// omitted the checksum generation, as per RFC 768:
|
|
||||||
//
|
|
||||||
// An all zero transmitted checksum value means that the transmitter
|
|
||||||
// generated no checksum (for debugging or for higher level protocols that
|
|
||||||
// don't care).
|
|
||||||
//
|
|
||||||
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
|
|
||||||
//
|
|
||||||
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
|
|
||||||
// checksum is not optional.
|
|
||||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
netHdr := pkt.Network()
|
|
||||||
payloadChecksum := pkt.Data().AsRange().Checksum()
|
|
||||||
return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
|
|
||||||
}
|
}
|
||||||
|
@@ -9,10 +9,10 @@ var _ core.Handler = (*fakeTunnel)(nil)
|
|||||||
|
|
||||||
type fakeTunnel struct{}
|
type fakeTunnel struct{}
|
||||||
|
|
||||||
func (*fakeTunnel) Add(conn core.TCPConn) {
|
func (*fakeTunnel) HandleTCPConn(conn core.TCPConn) {
|
||||||
tunnel.Add(conn)
|
tunnel.TCPIn() <- conn
|
||||||
}
|
}
|
||||||
|
|
||||||
func (*fakeTunnel) AddPacket(packet core.UDPPacket) {
|
func (*fakeTunnel) HandleUDPConn(conn core.UDPConn) {
|
||||||
tunnel.AddPacket(packet)
|
tunnel.UDPIn() <- conn
|
||||||
}
|
}
|
||||||
|
25
tunnel/addr.go
Normal file
25
tunnel/addr.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// parseAddr parses net.Addr to IP and port.
|
||||||
|
func parseAddr(addr net.Addr) (net.IP, uint16) {
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.TCPAddr:
|
||||||
|
return v.IP, uint16(v.Port)
|
||||||
|
case *net.UDPAddr:
|
||||||
|
return v.IP, uint16(v.Port)
|
||||||
|
default:
|
||||||
|
return parseAddrString(addr.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAddrString parses address string to IP and port.
|
||||||
|
func parseAddrString(addr string) (net.IP, uint16) {
|
||||||
|
host, port, _ := net.SplitHostPort(addr)
|
||||||
|
portInt, _ := strconv.ParseUint(port, 10, 16)
|
||||||
|
return net.ParseIP(host), uint16(portInt)
|
||||||
|
}
|
@@ -22,16 +22,19 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
|
|||||||
return statistic.NewTCPTracker(conn, metadata, statistic.DefaultManager)
|
return statistic.NewTCPTracker(conn, metadata, statistic.DefaultManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleTCP(localConn core.TCPConn) {
|
func handleTCPConn(localConn core.TCPConn) {
|
||||||
defer localConn.Close()
|
defer localConn.Close()
|
||||||
|
|
||||||
id := localConn.ID()
|
var (
|
||||||
|
srcIP, srcPort = parseAddr(localConn.RemoteAddr())
|
||||||
|
dstIP, dstPort = parseAddr(localConn.LocalAddr())
|
||||||
|
)
|
||||||
metadata := &M.Metadata{
|
metadata := &M.Metadata{
|
||||||
Net: M.TCP,
|
Net: M.TCP,
|
||||||
SrcIP: net.IP(id.RemoteAddress),
|
SrcIP: srcIP,
|
||||||
SrcPort: id.RemotePort,
|
SrcPort: srcPort,
|
||||||
DstIP: net.IP(id.LocalAddress),
|
DstIP: dstIP,
|
||||||
DstPort: id.LocalPort,
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
targetConn, err := proxy.Dial(metadata)
|
targetConn, err := proxy.Dial(metadata)
|
||||||
@@ -39,13 +42,7 @@ func handleTCP(localConn core.TCPConn) {
|
|||||||
log.Warnf("[TCP] dial %s error: %v", metadata.DestinationAddress(), err)
|
log.Warnf("[TCP] dial %s error: %v", metadata.DestinationAddress(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr())
|
||||||
if dialerAddr, ok := targetConn.LocalAddr().(*net.TCPAddr); ok {
|
|
||||||
metadata.MidIP = dialerAddr.IP
|
|
||||||
metadata.MidPort = uint16(dialerAddr.Port)
|
|
||||||
} else { /* fallback */
|
|
||||||
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
targetConn = newTCPTracker(targetConn, metadata)
|
targetConn = newTCPTracker(targetConn, metadata)
|
||||||
defer targetConn.Close()
|
defer targetConn.Close()
|
||||||
|
@@ -1,55 +1,36 @@
|
|||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"github.com/xjasonlyu/tun2socks/v2/core"
|
"github.com/xjasonlyu/tun2socks/v2/core"
|
||||||
"github.com/xjasonlyu/tun2socks/v2/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// maxUDPQueueSize is the max number of UDP packets
|
|
||||||
// could be buffered. if queue is full, upcoming packets
|
|
||||||
// would be dropped util queue is ready again.
|
|
||||||
maxUDPQueueSize = 1 << 9
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Unbuffered TCP/UDP queues.
|
||||||
var (
|
var (
|
||||||
_tcpQueue = make(chan core.TCPConn) /* unbuffered */
|
_tcpQueue = make(chan core.TCPConn)
|
||||||
_udpQueue = make(chan core.UDPPacket, maxUDPQueueSize)
|
_udpQueue = make(chan core.UDPConn)
|
||||||
_numUDPWorkers = max(runtime.GOMAXPROCS(0), 4 /* at least 4 workers */)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
go process()
|
go process()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add adds tcpConn to tcpQueue.
|
// TCPIn return fan-in TCP queue.
|
||||||
func Add(conn core.TCPConn) {
|
func TCPIn() chan<- core.TCPConn {
|
||||||
_tcpQueue <- conn
|
return _tcpQueue
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPacket adds udpPacket to udpQueue.
|
// UDPIn return fan-in UDP queue.
|
||||||
func AddPacket(packet core.UDPPacket) {
|
func UDPIn() chan<- core.UDPConn {
|
||||||
select {
|
return _udpQueue
|
||||||
case _udpQueue <- packet:
|
|
||||||
default:
|
|
||||||
log.Warnf("queue is currently full, packet will be dropped")
|
|
||||||
packet.Drop()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func process() {
|
func process() {
|
||||||
for i := 0; i < _numUDPWorkers; i++ {
|
for {
|
||||||
queue := _udpQueue
|
select {
|
||||||
go func() {
|
case conn := <-_tcpQueue:
|
||||||
for packet := range queue {
|
go handleTCPConn(conn)
|
||||||
handleUDP(packet)
|
case conn := <-_udpQueue:
|
||||||
|
go handleUDPConn(conn)
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
for conn := range _tcpQueue {
|
|
||||||
go handleTCP(conn)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
129
tunnel/udp.go
129
tunnel/udp.go
@@ -7,7 +7,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/xjasonlyu/tun2socks/v2/common/pool"
|
"github.com/xjasonlyu/tun2socks/v2/common/pool"
|
||||||
"github.com/xjasonlyu/tun2socks/v2/component/nat"
|
|
||||||
"github.com/xjasonlyu/tun2socks/v2/core"
|
"github.com/xjasonlyu/tun2socks/v2/core"
|
||||||
"github.com/xjasonlyu/tun2socks/v2/log"
|
"github.com/xjasonlyu/tun2socks/v2/log"
|
||||||
M "github.com/xjasonlyu/tun2socks/v2/metadata"
|
M "github.com/xjasonlyu/tun2socks/v2/metadata"
|
||||||
@@ -15,15 +14,8 @@ import (
|
|||||||
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
|
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// _udpSessionTimeout is the default timeout for each UDP session.
|
||||||
// _natTable uses source udp packet information
|
var _udpSessionTimeout = 60 * time.Second
|
||||||
// as key to store destination udp packetConn.
|
|
||||||
_natTable = nat.NewTable()
|
|
||||||
|
|
||||||
// _udpSessionTimeout is the default timeout for
|
|
||||||
// each UDP session.
|
|
||||||
_udpSessionTimeout = 60 * time.Second
|
|
||||||
)
|
|
||||||
|
|
||||||
func SetUDPTimeout(v int) {
|
func SetUDPTimeout(v int) {
|
||||||
_udpSessionTimeout = time.Duration(v) * time.Second
|
_udpSessionTimeout = time.Duration(v) * time.Second
|
||||||
@@ -33,98 +25,58 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
|
|||||||
return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager)
|
return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUDP(packet core.UDPPacket) {
|
func handleUDPConn(uc core.UDPConn) {
|
||||||
id := packet.ID()
|
defer uc.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
srcIP, srcPort = parseAddr(uc.RemoteAddr())
|
||||||
|
dstIP, dstPort = parseAddr(uc.LocalAddr())
|
||||||
|
)
|
||||||
metadata := &M.Metadata{
|
metadata := &M.Metadata{
|
||||||
Net: M.UDP,
|
Net: M.UDP,
|
||||||
SrcIP: net.IP(id.RemoteAddress),
|
SrcIP: srcIP,
|
||||||
SrcPort: id.RemotePort,
|
SrcPort: srcPort,
|
||||||
DstIP: net.IP(id.LocalAddress),
|
DstIP: dstIP,
|
||||||
DstPort: id.LocalPort,
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
generateNATKey := func(m *M.Metadata) string {
|
|
||||||
return m.SourceAddress() /* as Full Cone NAT Key */
|
|
||||||
}
|
|
||||||
key := generateNATKey(metadata)
|
|
||||||
|
|
||||||
handle := func(drop bool) bool {
|
|
||||||
pc := _natTable.Get(key)
|
|
||||||
if pc != nil {
|
|
||||||
handleUDPToRemote(packet, pc, metadata /* as net.Addr */, drop)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if handle(true /* drop */) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
lockKey := key + "-lock"
|
|
||||||
cond, loaded := _natTable.GetOrCreateLock(lockKey)
|
|
||||||
go func() {
|
|
||||||
if loaded {
|
|
||||||
cond.L.Lock()
|
|
||||||
cond.Wait()
|
|
||||||
handle(true) /* drop after sending data to remote */
|
|
||||||
cond.L.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_natTable.Delete(lockKey)
|
|
||||||
cond.Broadcast()
|
|
||||||
}()
|
|
||||||
|
|
||||||
pc, err := proxy.DialUDP(metadata)
|
pc, err := proxy.DialUDP(metadata)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
|
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())
|
||||||
if dialerAddr, ok := pc.LocalAddr().(*net.UDPAddr); ok {
|
|
||||||
metadata.MidIP = dialerAddr.IP
|
|
||||||
metadata.MidPort = uint16(dialerAddr.Port)
|
|
||||||
} else { /* fallback */
|
|
||||||
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr().String())
|
|
||||||
}
|
|
||||||
|
|
||||||
pc = newUDPTracker(pc, metadata)
|
pc = newUDPTracker(pc, metadata)
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer pc.Close()
|
defer pc.Close()
|
||||||
defer packet.Drop()
|
|
||||||
defer _natTable.Delete(key)
|
|
||||||
|
|
||||||
handleUDPToLocal(packet, pc)
|
go handleUDPToRemote(uc, pc, metadata)
|
||||||
}()
|
handleUDPToLocal(uc, pc, metadata)
|
||||||
|
|
||||||
_natTable.Set(key, pc)
|
|
||||||
handle(false /* drop */)
|
|
||||||
}()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr, drop bool) {
|
func handleUDPToRemote(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
|
||||||
defer func() {
|
|
||||||
if drop {
|
|
||||||
packet.Drop()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil {
|
|
||||||
log.Warnf("[UDP] write to %s error: %v", remote, err)
|
|
||||||
}
|
|
||||||
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */
|
|
||||||
|
|
||||||
log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
|
|
||||||
buf := pool.Get(pool.MaxSegmentSize)
|
buf := pool.Get(pool.MaxSegmentSize)
|
||||||
defer pool.Put(buf)
|
defer pool.Put(buf)
|
||||||
|
|
||||||
for /* just loop */ {
|
for {
|
||||||
|
n, err := uc.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pc.WriteTo(buf[:n], remote); err != nil {
|
||||||
|
log.Warnf("[UDP] write to %s error: %v", remote, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("[UDP] %s --> %s", uc.RemoteAddr(), remote)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleUDPToLocal(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
|
||||||
|
buf := pool.Get(pool.MaxSegmentSize)
|
||||||
|
defer pool.Put(buf)
|
||||||
|
|
||||||
|
for {
|
||||||
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout))
|
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout))
|
||||||
n, from, err := pc.ReadFrom(buf)
|
n, from, err := pc.ReadFrom(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -134,11 +86,14 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := packet.WriteBack(buf[:n], from); err != nil {
|
if from.Network() != remote.Network() || from.String() != remote.String() {
|
||||||
log.Warnf("[UDP] write back from %s error: %v", from, err)
|
log.Warnf("[UDP] drop unknown packet from %s", from)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("[UDP] %s <-- %s", packet.RemoteAddr(), from)
|
if _, err := uc.Write(buf[:n]); err != nil {
|
||||||
|
log.Warnf("[UDP] write back from %s error: %v", from, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -1,20 +0,0 @@
|
|||||||
package tunnel
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"strconv"
|
|
||||||
)
|
|
||||||
|
|
||||||
func max(a, b int) int {
|
|
||||||
if a > b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseAddr parses address to IP and port.
|
|
||||||
func parseAddr(addr string) (net.IP, uint16) {
|
|
||||||
host, portStr, _ := net.SplitHostPort(addr)
|
|
||||||
portInt, _ := strconv.ParseUint(portStr, 10, 16)
|
|
||||||
return net.ParseIP(host), uint16(portInt)
|
|
||||||
}
|
|
Reference in New Issue
Block a user