mirror of
https://github.com/xjasonlyu/tun2socks.git
synced 2025-10-07 09:41:09 +08:00
Refactor: optimize UDP module
Symmetric NAT support for now.
This commit is contained in:
@@ -2,33 +2,15 @@ package core
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
)
|
||||
|
||||
// TCPConn implements the net.Conn interface.
|
||||
type TCPConn interface {
|
||||
net.Conn
|
||||
ID() *stack.TransportEndpointID
|
||||
}
|
||||
|
||||
type UDPPacket interface {
|
||||
// Data get the payload of UDP Packet.
|
||||
Data() []byte
|
||||
|
||||
// Drop call after packet is used, could release resources in this function.
|
||||
Drop()
|
||||
|
||||
// ID returns the transport endpoint id of packet.
|
||||
ID() *stack.TransportEndpointID
|
||||
|
||||
// LocalAddr returns the source IP/Port of packet.
|
||||
LocalAddr() net.Addr
|
||||
|
||||
// RemoteAddr returns the destination IP/Port of packet.
|
||||
RemoteAddr() net.Addr
|
||||
|
||||
// WriteBack writes the payload with source IP/Port equals addr
|
||||
// - variable source IP/Port is important to STUN
|
||||
// - if addr is not provided, WriteBack will write out UDP packet with SourceIP/Port equals to original Target.
|
||||
WriteBack([]byte, net.Addr) (int, error)
|
||||
// UDPConn implements net.Conn and net.PacketConn.
|
||||
type UDPConn interface {
|
||||
net.Conn
|
||||
net.PacketConn
|
||||
}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package core
|
||||
|
||||
// Handler is a TCP/UDP connection handler that implements
|
||||
// HandleTCPConn and HandleUDPConn methods.
|
||||
type Handler interface {
|
||||
Add(TCPConn)
|
||||
AddPacket(UDPPacket)
|
||||
HandleTCPConn(TCPConn)
|
||||
HandleUDPConn(UDPConn)
|
||||
}
|
||||
|
@@ -2,12 +2,10 @@ package stack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
@@ -36,10 +34,9 @@ func withTCPHandler() Option {
|
||||
return func(s *Stack) error {
|
||||
tcpForwarder := tcp.NewForwarder(s.Stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
id := r.ID()
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// prevent potential half-open TCP connection leak.
|
||||
// RST: prevent potential half-open TCP connection leak.
|
||||
r.Complete(true)
|
||||
return
|
||||
}
|
||||
@@ -47,11 +44,7 @@ func withTCPHandler() Option {
|
||||
|
||||
setKeepalive(ep)
|
||||
|
||||
conn := &tcpConn{
|
||||
Conn: gonet.NewTCPConn(&wq, ep),
|
||||
id: &id,
|
||||
}
|
||||
s.handler.Add(conn)
|
||||
s.handler.HandleTCPConn(gonet.NewTCPConn(&wq, ep))
|
||||
})
|
||||
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||
return nil
|
||||
@@ -72,12 +65,3 @@ func setKeepalive(ep tcpip.Endpoint) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type tcpConn struct {
|
||||
net.Conn
|
||||
id *stack.TransportEndpointID
|
||||
}
|
||||
|
||||
func (c *tcpConn) ID() *stack.TransportEndpointID {
|
||||
return c.id
|
||||
}
|
||||
|
@@ -1,192 +1,24 @@
|
||||
package stack
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// udpNoChecksum disables UDP checksum if set to true.
|
||||
udpNoChecksum = true
|
||||
"gvisor.dev/gvisor/pkg/waiter"
|
||||
)
|
||||
|
||||
func withUDPHandler() Option {
|
||||
return func(s *Stack) error {
|
||||
udpHandlePacket := func(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go HandlePacket.
|
||||
udpHdr := header.UDP(pkt.TransportHeader().View())
|
||||
if int(udpHdr.Length()) > pkt.Data().Size()+header.UDPMinimumSize {
|
||||
// Malformed packet.
|
||||
s.Stats().UDP.MalformedPacketsReceived.Increment()
|
||||
return true
|
||||
udpForwarder := udp.NewForwarder(s.Stack, func(r *udp.ForwarderRequest) {
|
||||
var wq waiter.Queue
|
||||
ep, err := r.CreateEndpoint(&wq)
|
||||
if err != nil {
|
||||
// TODO: handler errors in the future.
|
||||
return
|
||||
}
|
||||
|
||||
if !verifyChecksum(udpHdr, pkt) {
|
||||
// Checksum error.
|
||||
s.Stats().UDP.ChecksumErrors.Increment()
|
||||
return true
|
||||
}
|
||||
|
||||
s.Stats().UDP.PacketsReceived.Increment()
|
||||
|
||||
packet := &udpPacket{
|
||||
s: s,
|
||||
id: &id,
|
||||
data: pkt.Data().ExtractVV(),
|
||||
nicID: pkt.NICID,
|
||||
netHdr: pkt.Network(),
|
||||
netProto: pkt.NetworkProtocolNumber,
|
||||
}
|
||||
|
||||
s.handler.AddPacket(packet)
|
||||
return true
|
||||
}
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpHandlePacket)
|
||||
s.handler.HandleUDPConn(gonet.NewUDPConn(s.Stack, &wq, ep))
|
||||
})
|
||||
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
type udpPacket struct {
|
||||
s *Stack
|
||||
id *stack.TransportEndpointID
|
||||
data buffer.VectorisedView
|
||||
nicID tcpip.NICID
|
||||
netHdr header.Network
|
||||
netProto tcpip.NetworkProtocolNumber
|
||||
}
|
||||
|
||||
func (p *udpPacket) Data() []byte {
|
||||
return p.data.ToView()
|
||||
}
|
||||
|
||||
func (p *udpPacket) Drop() {}
|
||||
|
||||
func (p *udpPacket) ID() *stack.TransportEndpointID {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *udpPacket) LocalAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP(p.id.LocalAddress), Port: int(p.id.LocalPort)}
|
||||
}
|
||||
|
||||
func (p *udpPacket) RemoteAddr() net.Addr {
|
||||
return &net.UDPAddr{IP: net.IP(p.id.RemoteAddress), Port: int(p.id.RemotePort)}
|
||||
}
|
||||
|
||||
func (p *udpPacket) WriteBack(b []byte, addr net.Addr) (int, error) {
|
||||
v := buffer.View(b)
|
||||
if len(v) > header.UDPMaximumPacketSize {
|
||||
// Payload can't possibly fit in a packet.
|
||||
return 0, fmt.Errorf("%s", &tcpip.ErrMessageTooLong{})
|
||||
}
|
||||
|
||||
var (
|
||||
localAddress tcpip.Address
|
||||
localPort uint16
|
||||
)
|
||||
|
||||
if udpAddr, ok := addr.(*net.UDPAddr); !ok {
|
||||
localAddress = p.netHdr.DestinationAddress()
|
||||
localPort = p.id.LocalPort
|
||||
} else if ipv4 := udpAddr.IP.To4(); ipv4 != nil {
|
||||
localAddress = tcpip.Address(ipv4)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
} else {
|
||||
localAddress = tcpip.Address(udpAddr.IP)
|
||||
localPort = uint16(udpAddr.Port)
|
||||
}
|
||||
|
||||
route, err := p.s.FindRoute(p.nicID, localAddress, p.netHdr.SourceAddress(), p.netProto, false /* multicastLoop */)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%#v find route: %s", p.id, err)
|
||||
}
|
||||
defer route.Release()
|
||||
|
||||
data := v.ToVectorisedView()
|
||||
if err = sendUDP(route, data, localPort, p.id.RemotePort, udpNoChecksum); err != nil {
|
||||
return 0, fmt.Errorf("%v", err)
|
||||
}
|
||||
return data.Size(), nil
|
||||
}
|
||||
|
||||
// sendUDP sends a UDP segment via the provided network endpoint and under the
|
||||
// provided identity.
|
||||
func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, noChecksum bool) tcpip.Error {
|
||||
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
ReserveHeaderBytes: header.UDPMinimumSize + int(r.MaxHeaderLength()),
|
||||
Data: data,
|
||||
})
|
||||
defer pkt.DecRef()
|
||||
|
||||
// Initialize the UDP header.
|
||||
udpHdr := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
|
||||
pkt.TransportProtocolNumber = udp.ProtocolNumber
|
||||
|
||||
length := uint16(pkt.Size())
|
||||
udpHdr.Encode(&header.UDPFields{
|
||||
SrcPort: localPort,
|
||||
DstPort: remotePort,
|
||||
Length: length,
|
||||
})
|
||||
|
||||
// Set the checksum field unless TX checksum offload is enabled.
|
||||
// On IPv4, UDP checksum is optional, and a zero value indicates the
|
||||
// transmitter skipped the checksum generation (RFC768).
|
||||
// On IPv6, UDP checksum is not optional (RFC2460 Section 8.1).
|
||||
if r.RequiresTXTransportChecksum() &&
|
||||
(!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) {
|
||||
xsum := r.PseudoHeaderChecksum(udp.ProtocolNumber, length)
|
||||
for _, v := range data.Views() {
|
||||
xsum = header.Checksum(v, xsum)
|
||||
}
|
||||
udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
|
||||
}
|
||||
|
||||
ttl := r.DefaultTTL()
|
||||
|
||||
if err := r.WritePacket(stack.NetworkHeaderParams{
|
||||
Protocol: udp.ProtocolNumber,
|
||||
TTL: ttl,
|
||||
TOS: 0, /* default */
|
||||
}, pkt); err != nil {
|
||||
r.Stats().UDP.PacketSendErrors.Increment()
|
||||
return err
|
||||
}
|
||||
|
||||
// Track count of packets sent.
|
||||
r.Stats().UDP.PacketsSent.Increment()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ref: gVisor pkg/tcpip/transport/udp/endpoint.go verifyChecksum.
|
||||
// verifyChecksum verifies the checksum unless RX checksum offload is enabled.
|
||||
func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool {
|
||||
if pkt.RXTransportChecksumValidated {
|
||||
return true
|
||||
}
|
||||
|
||||
// On IPv4, UDP checksum is optional, and a zero value means the transmitter
|
||||
// omitted the checksum generation, as per RFC 768:
|
||||
//
|
||||
// An all zero transmitted checksum value means that the transmitter
|
||||
// generated no checksum (for debugging or for higher level protocols that
|
||||
// don't care).
|
||||
//
|
||||
// On IPv6, UDP checksum is not optional, as per RFC 2460 Section 8.1:
|
||||
//
|
||||
// Unlike IPv4, when UDP packets are originated by an IPv6 node, the UDP
|
||||
// checksum is not optional.
|
||||
if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber && hdr.Checksum() == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
netHdr := pkt.Network()
|
||||
payloadChecksum := pkt.Data().AsRange().Checksum()
|
||||
return hdr.IsChecksumValid(netHdr.SourceAddress(), netHdr.DestinationAddress(), payloadChecksum)
|
||||
}
|
||||
|
@@ -9,10 +9,10 @@ var _ core.Handler = (*fakeTunnel)(nil)
|
||||
|
||||
type fakeTunnel struct{}
|
||||
|
||||
func (*fakeTunnel) Add(conn core.TCPConn) {
|
||||
tunnel.Add(conn)
|
||||
func (*fakeTunnel) HandleTCPConn(conn core.TCPConn) {
|
||||
tunnel.TCPIn() <- conn
|
||||
}
|
||||
|
||||
func (*fakeTunnel) AddPacket(packet core.UDPPacket) {
|
||||
tunnel.AddPacket(packet)
|
||||
func (*fakeTunnel) HandleUDPConn(conn core.UDPConn) {
|
||||
tunnel.UDPIn() <- conn
|
||||
}
|
||||
|
25
tunnel/addr.go
Normal file
25
tunnel/addr.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// parseAddr parses net.Addr to IP and port.
|
||||
func parseAddr(addr net.Addr) (net.IP, uint16) {
|
||||
switch v := addr.(type) {
|
||||
case *net.TCPAddr:
|
||||
return v.IP, uint16(v.Port)
|
||||
case *net.UDPAddr:
|
||||
return v.IP, uint16(v.Port)
|
||||
default:
|
||||
return parseAddrString(addr.String())
|
||||
}
|
||||
}
|
||||
|
||||
// parseAddrString parses address string to IP and port.
|
||||
func parseAddrString(addr string) (net.IP, uint16) {
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
portInt, _ := strconv.ParseUint(port, 10, 16)
|
||||
return net.ParseIP(host), uint16(portInt)
|
||||
}
|
@@ -22,16 +22,19 @@ func newTCPTracker(conn net.Conn, metadata *M.Metadata) net.Conn {
|
||||
return statistic.NewTCPTracker(conn, metadata, statistic.DefaultManager)
|
||||
}
|
||||
|
||||
func handleTCP(localConn core.TCPConn) {
|
||||
func handleTCPConn(localConn core.TCPConn) {
|
||||
defer localConn.Close()
|
||||
|
||||
id := localConn.ID()
|
||||
var (
|
||||
srcIP, srcPort = parseAddr(localConn.RemoteAddr())
|
||||
dstIP, dstPort = parseAddr(localConn.LocalAddr())
|
||||
)
|
||||
metadata := &M.Metadata{
|
||||
Net: M.TCP,
|
||||
SrcIP: net.IP(id.RemoteAddress),
|
||||
SrcPort: id.RemotePort,
|
||||
DstIP: net.IP(id.LocalAddress),
|
||||
DstPort: id.LocalPort,
|
||||
SrcIP: srcIP,
|
||||
SrcPort: srcPort,
|
||||
DstIP: dstIP,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
targetConn, err := proxy.Dial(metadata)
|
||||
@@ -39,13 +42,7 @@ func handleTCP(localConn core.TCPConn) {
|
||||
log.Warnf("[TCP] dial %s error: %v", metadata.DestinationAddress(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if dialerAddr, ok := targetConn.LocalAddr().(*net.TCPAddr); ok {
|
||||
metadata.MidIP = dialerAddr.IP
|
||||
metadata.MidPort = uint16(dialerAddr.Port)
|
||||
} else { /* fallback */
|
||||
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr().String())
|
||||
}
|
||||
metadata.MidIP, metadata.MidPort = parseAddr(targetConn.LocalAddr())
|
||||
|
||||
targetConn = newTCPTracker(targetConn, metadata)
|
||||
defer targetConn.Close()
|
||||
|
@@ -1,55 +1,36 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
|
||||
"github.com/xjasonlyu/tun2socks/v2/core"
|
||||
"github.com/xjasonlyu/tun2socks/v2/log"
|
||||
)
|
||||
|
||||
const (
|
||||
// maxUDPQueueSize is the max number of UDP packets
|
||||
// could be buffered. if queue is full, upcoming packets
|
||||
// would be dropped util queue is ready again.
|
||||
maxUDPQueueSize = 1 << 9
|
||||
)
|
||||
|
||||
// Unbuffered TCP/UDP queues.
|
||||
var (
|
||||
_tcpQueue = make(chan core.TCPConn) /* unbuffered */
|
||||
_udpQueue = make(chan core.UDPPacket, maxUDPQueueSize)
|
||||
_numUDPWorkers = max(runtime.GOMAXPROCS(0), 4 /* at least 4 workers */)
|
||||
_tcpQueue = make(chan core.TCPConn)
|
||||
_udpQueue = make(chan core.UDPConn)
|
||||
)
|
||||
|
||||
func init() {
|
||||
go process()
|
||||
}
|
||||
|
||||
// Add adds tcpConn to tcpQueue.
|
||||
func Add(conn core.TCPConn) {
|
||||
_tcpQueue <- conn
|
||||
// TCPIn return fan-in TCP queue.
|
||||
func TCPIn() chan<- core.TCPConn {
|
||||
return _tcpQueue
|
||||
}
|
||||
|
||||
// AddPacket adds udpPacket to udpQueue.
|
||||
func AddPacket(packet core.UDPPacket) {
|
||||
select {
|
||||
case _udpQueue <- packet:
|
||||
default:
|
||||
log.Warnf("queue is currently full, packet will be dropped")
|
||||
packet.Drop()
|
||||
}
|
||||
// UDPIn return fan-in UDP queue.
|
||||
func UDPIn() chan<- core.UDPConn {
|
||||
return _udpQueue
|
||||
}
|
||||
|
||||
func process() {
|
||||
for i := 0; i < _numUDPWorkers; i++ {
|
||||
queue := _udpQueue
|
||||
go func() {
|
||||
for packet := range queue {
|
||||
handleUDP(packet)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
for conn := range _tcpQueue {
|
||||
go handleTCP(conn)
|
||||
for {
|
||||
select {
|
||||
case conn := <-_tcpQueue:
|
||||
go handleTCPConn(conn)
|
||||
case conn := <-_udpQueue:
|
||||
go handleUDPConn(conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
139
tunnel/udp.go
139
tunnel/udp.go
@@ -7,7 +7,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/xjasonlyu/tun2socks/v2/common/pool"
|
||||
"github.com/xjasonlyu/tun2socks/v2/component/nat"
|
||||
"github.com/xjasonlyu/tun2socks/v2/core"
|
||||
"github.com/xjasonlyu/tun2socks/v2/log"
|
||||
M "github.com/xjasonlyu/tun2socks/v2/metadata"
|
||||
@@ -15,15 +14,8 @@ import (
|
||||
"github.com/xjasonlyu/tun2socks/v2/tunnel/statistic"
|
||||
)
|
||||
|
||||
var (
|
||||
// _natTable uses source udp packet information
|
||||
// as key to store destination udp packetConn.
|
||||
_natTable = nat.NewTable()
|
||||
|
||||
// _udpSessionTimeout is the default timeout for
|
||||
// each UDP session.
|
||||
_udpSessionTimeout = 60 * time.Second
|
||||
)
|
||||
// _udpSessionTimeout is the default timeout for each UDP session.
|
||||
var _udpSessionTimeout = 60 * time.Second
|
||||
|
||||
func SetUDPTimeout(v int) {
|
||||
_udpSessionTimeout = time.Duration(v) * time.Second
|
||||
@@ -33,98 +25,58 @@ func newUDPTracker(conn net.PacketConn, metadata *M.Metadata) net.PacketConn {
|
||||
return statistic.NewUDPTracker(conn, metadata, statistic.DefaultManager)
|
||||
}
|
||||
|
||||
func handleUDP(packet core.UDPPacket) {
|
||||
id := packet.ID()
|
||||
func handleUDPConn(uc core.UDPConn) {
|
||||
defer uc.Close()
|
||||
|
||||
var (
|
||||
srcIP, srcPort = parseAddr(uc.RemoteAddr())
|
||||
dstIP, dstPort = parseAddr(uc.LocalAddr())
|
||||
)
|
||||
metadata := &M.Metadata{
|
||||
Net: M.UDP,
|
||||
SrcIP: net.IP(id.RemoteAddress),
|
||||
SrcPort: id.RemotePort,
|
||||
DstIP: net.IP(id.LocalAddress),
|
||||
DstPort: id.LocalPort,
|
||||
SrcIP: srcIP,
|
||||
SrcPort: srcPort,
|
||||
DstIP: dstIP,
|
||||
DstPort: dstPort,
|
||||
}
|
||||
|
||||
generateNATKey := func(m *M.Metadata) string {
|
||||
return m.SourceAddress() /* as Full Cone NAT Key */
|
||||
}
|
||||
key := generateNATKey(metadata)
|
||||
|
||||
handle := func(drop bool) bool {
|
||||
pc := _natTable.Get(key)
|
||||
if pc != nil {
|
||||
handleUDPToRemote(packet, pc, metadata /* as net.Addr */, drop)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
if handle(true /* drop */) {
|
||||
pc, err := proxy.DialUDP(metadata)
|
||||
if err != nil {
|
||||
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
|
||||
return
|
||||
}
|
||||
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr())
|
||||
|
||||
lockKey := key + "-lock"
|
||||
cond, loaded := _natTable.GetOrCreateLock(lockKey)
|
||||
go func() {
|
||||
if loaded {
|
||||
cond.L.Lock()
|
||||
cond.Wait()
|
||||
handle(true) /* drop after sending data to remote */
|
||||
cond.L.Unlock()
|
||||
return
|
||||
}
|
||||
pc = newUDPTracker(pc, metadata)
|
||||
defer pc.Close()
|
||||
|
||||
defer func() {
|
||||
_natTable.Delete(lockKey)
|
||||
cond.Broadcast()
|
||||
}()
|
||||
|
||||
pc, err := proxy.DialUDP(metadata)
|
||||
if err != nil {
|
||||
log.Warnf("[UDP] dial %s error: %v", metadata.DestinationAddress(), err)
|
||||
return
|
||||
}
|
||||
|
||||
if dialerAddr, ok := pc.LocalAddr().(*net.UDPAddr); ok {
|
||||
metadata.MidIP = dialerAddr.IP
|
||||
metadata.MidPort = uint16(dialerAddr.Port)
|
||||
} else { /* fallback */
|
||||
metadata.MidIP, metadata.MidPort = parseAddr(pc.LocalAddr().String())
|
||||
}
|
||||
|
||||
pc = newUDPTracker(pc, metadata)
|
||||
|
||||
go func() {
|
||||
defer pc.Close()
|
||||
defer packet.Drop()
|
||||
defer _natTable.Delete(key)
|
||||
|
||||
handleUDPToLocal(packet, pc)
|
||||
}()
|
||||
|
||||
_natTable.Set(key, pc)
|
||||
handle(false /* drop */)
|
||||
}()
|
||||
go handleUDPToRemote(uc, pc, metadata)
|
||||
handleUDPToLocal(uc, pc, metadata)
|
||||
}
|
||||
|
||||
func handleUDPToRemote(packet core.UDPPacket, pc net.PacketConn, remote net.Addr, drop bool) {
|
||||
defer func() {
|
||||
if drop {
|
||||
packet.Drop()
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := pc.WriteTo(packet.Data() /* data */, remote); err != nil {
|
||||
log.Warnf("[UDP] write to %s error: %v", remote, err)
|
||||
}
|
||||
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout)) /* reset timeout */
|
||||
|
||||
log.Infof("[UDP] %s --> %s", packet.RemoteAddr(), remote)
|
||||
}
|
||||
|
||||
func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
|
||||
func handleUDPToRemote(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
|
||||
buf := pool.Get(pool.MaxSegmentSize)
|
||||
defer pool.Put(buf)
|
||||
|
||||
for /* just loop */ {
|
||||
for {
|
||||
n, err := uc.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := pc.WriteTo(buf[:n], remote); err != nil {
|
||||
log.Warnf("[UDP] write to %s error: %v", remote, err)
|
||||
}
|
||||
|
||||
log.Infof("[UDP] %s --> %s", uc.RemoteAddr(), remote)
|
||||
}
|
||||
}
|
||||
|
||||
func handleUDPToLocal(uc core.UDPConn, pc net.PacketConn, remote net.Addr) {
|
||||
buf := pool.Get(pool.MaxSegmentSize)
|
||||
defer pool.Put(buf)
|
||||
|
||||
for {
|
||||
pc.SetReadDeadline(time.Now().Add(_udpSessionTimeout))
|
||||
n, from, err := pc.ReadFrom(buf)
|
||||
if err != nil {
|
||||
@@ -134,11 +86,14 @@ func handleUDPToLocal(packet core.UDPPacket, pc net.PacketConn) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, err := packet.WriteBack(buf[:n], from); err != nil {
|
||||
log.Warnf("[UDP] write back from %s error: %v", from, err)
|
||||
if from.Network() != remote.Network() || from.String() != remote.String() {
|
||||
log.Warnf("[UDP] drop unknown packet from %s", from)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("[UDP] %s <-- %s", packet.RemoteAddr(), from)
|
||||
if _, err := uc.Write(buf[:n]); err != nil {
|
||||
log.Warnf("[UDP] write back from %s error: %v", from, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@@ -1,20 +0,0 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
func max(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// parseAddr parses address to IP and port.
|
||||
func parseAddr(addr string) (net.IP, uint16) {
|
||||
host, portStr, _ := net.SplitHostPort(addr)
|
||||
portInt, _ := strconv.ParseUint(portStr, 10, 16)
|
||||
return net.ParseIP(host), uint16(portInt)
|
||||
}
|
Reference in New Issue
Block a user