支持转发 udp

This commit is contained in:
xmdhs
2024-01-23 23:03:06 +08:00
parent 599d0a3dd6
commit fcc45b4ff0
5 changed files with 449 additions and 20 deletions

43
main.go
View File

@@ -5,6 +5,7 @@ import (
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"net/http"
@@ -25,6 +26,7 @@ var (
test bool
target string
comm string
udp bool
)
func init() {
@@ -32,8 +34,9 @@ func init() {
flag.StringVar(&localAddr, "l", "", "local addr")
flag.StringVar(&port, "p", "8086", "port")
flag.StringVar(&target, "d", "", "forward to target host")
flag.BoolVar(&test, "t", false, "test server")
flag.BoolVar(&test, "t", false, "test server (only tcp)")
flag.StringVar(&comm, "e", "", "run script for mapped address")
flag.BoolVar(&udp, "u", false, "udp")
flag.Parse()
}
@@ -67,7 +70,7 @@ func main() {
log.Println(err)
}
}
})
}, udp, test)
if err != nil {
log.Println(err)
}
@@ -75,12 +78,19 @@ func main() {
}
}
func openPort(ctx context.Context, target, localAddr string, portu uint16, stun string, finish func(netip.AddrPort)) error {
func openPort(ctx context.Context, target, localAddr string, portu uint16,
stun string, finish func(netip.AddrPort), udp bool, testserver bool) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
if target != "" {
l, err := natmap.Forward(ctx, portu, target, func(s string) {
var forward func(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error)
if udp {
forward = natmap.ForwardUdp
} else {
forward = natmap.Forward
}
l, err := forward(ctx, portu, target, func(s string) {
log.Println(s)
})
if err != nil {
@@ -88,7 +98,7 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun
}
defer l.Close()
}
if test {
if testserver {
l, err := testServer(ctx, portu)
if err != nil {
return fmt.Errorf("openPort: %w", err)
@@ -96,10 +106,17 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun
defer l.Close()
}
errCh := make(chan error, 1)
m, s, err := natmap.NatMap(ctx, stun, localAddr, uint16(portu), func(s error) {
var nmap func(ctx context.Context, stunAddr string, host string, port uint16, log func(error)) (*natmap.Map, netip.AddrPort, error)
if udp {
nmap = natmap.NatMapUdp
} else {
nmap = natmap.NatMap
}
m, s, err := nmap(ctx, stun, localAddr, uint16(portu), func(s error) {
cancel()
select {
case errCh <- ErrNatMap{err: s}:
case errCh <- s:
default:
}
})
@@ -117,18 +134,6 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun
return nil
}
type ErrNatMap struct {
err error
}
func (e ErrNatMap) Error() string {
return e.err.Error()
}
func (e ErrNatMap) Unwrap() error {
return e.err
}
func testServer(ctx context.Context, port uint16) (net.Listener, error) {
s := http.Server{
ReadTimeout: 5 * time.Second,

395
natmap/forward.go Normal file
View File

@@ -0,0 +1,395 @@
// Package forward contains a UDP packet forwarder.
// https://github.com/gwangyi/udp-forward
package natmap
import (
"errors"
"log"
"net"
"sync"
"time"
)
type connection struct {
available chan struct{}
udp *net.UDPConn
lastActive time.Time
}
type Logger interface {
Println(v ...any)
}
// Forwarder represents a UDP packet forwarder.
type Forwarder struct {
src *net.UDPAddr
router Router
client *net.UDPAddr
listenerConn *net.UDPConn
connections map[string]*connection
connectionsMutex *sync.RWMutex
connectCallback func(addr string)
disconnectCallback func(addr string)
timeout time.Duration
closed bool
bufferSize int
logger Logger
}
// Router represents a router that gives the destination address.
type Router interface {
Route(*net.UDPAddr) *net.UDPAddr
}
type staticRouter struct {
*net.UDPAddr
}
func (r staticRouter) Route(*net.UDPAddr) *net.UDPAddr {
return r.UDPAddr
}
type funcRouter func(*net.UDPAddr) *net.UDPAddr
func (r funcRouter) Route(incoming *net.UDPAddr) *net.UDPAddr {
return r(incoming)
}
// config represents the configuration of Forwarder.
type config struct {
listenerFactory func() (*net.UDPConn, error)
router Router
timeout time.Duration
bufferSize int
logger Logger
}
// Option gives the way to customize the forwarder.
type Option func(*config) error
// WithAddr lets the new forwarder listen from given address.
func WithAddr(src string) Option {
return func(c *config) error {
srcAddr, err := net.ResolveUDPAddr("udp", src)
if err != nil {
return err
}
c.listenerFactory = func() (*net.UDPConn, error) {
return net.ListenUDP("udp", srcAddr)
}
return nil
}
}
// WithConn lets the new forwarder to use given conn instead of new one.
func WithConn(conn *net.UDPConn) Option {
return func(c *config) error {
c.listenerFactory = func() (*net.UDPConn, error) {
return conn, nil
}
return nil
}
}
// WithDestination lets the new forwarder forward packets to the given address.
func WithDestination(dest string) Option {
return func(c *config) error {
destAddr, err := net.ResolveUDPAddr("udp", dest)
if err != nil {
return err
}
c.router = staticRouter{UDPAddr: destAddr}
return nil
}
}
// WithRouter lets the new forwarder forward packets according to the given router.
func WithRouter(router Router) Option {
return func(c *config) error {
c.router = router
return nil
}
}
// WithRouterFunc does the same as WithRouter, but with a function.
func WithRouterFunc(router func(*net.UDPAddr) *net.UDPAddr) Option {
return WithRouter(funcRouter(router))
}
// WithTimeout sets the timeout.
// No interaction more than the timeout will remove the connection from the NAT
// table.
func WithTimeout(timeout time.Duration) Option {
return func(c *config) error {
c.timeout = timeout
return nil
}
}
// WithBufferSize sets the buffer size that is used by forwarding.
// Larger packet can be discarded.
func WithBufferSize(size int) Option {
return func(c *config) error {
c.bufferSize = size
return nil
}
}
// WithLogger sets a logger.
func WithLogger(logger Logger) Option {
return func(c *config) error {
c.logger = logger
return nil
}
}
type emptyLogger struct{}
func (emptyLogger) Println(v ...any) {}
// WithoutLogger lets forwarder not log anything.
func WithoutLogger() Option {
return WithLogger(emptyLogger{})
}
// DefaultTimeout is the default timeout period of inactivity for convenience
// sake. It is equivelant to 5 minutes.
const DefaultTimeout = time.Minute * 5
// Forward forwards UDP packets from the src address to the dst address, with a
// timeout to "disconnect" clients after the timeout period of inactivity. It
// implements a reverse NAT and thus supports multiple seperate users. Forward
// is also asynchronous.
func forward(options ...Option) (*Forwarder, error) {
config := &config{
timeout: DefaultTimeout,
bufferSize: 4096,
logger: log.Default(),
}
options = append([]Option{WithAddr(":")}, options...)
for _, opt := range options {
if err := opt(config); err != nil {
return nil, err
}
}
forwarder := new(Forwarder)
forwarder.connectCallback = func(addr string) {}
forwarder.disconnectCallback = func(addr string) {}
forwarder.connectionsMutex = new(sync.RWMutex)
forwarder.connections = make(map[string]*connection)
forwarder.timeout = config.timeout
forwarder.router = config.router
forwarder.bufferSize = config.bufferSize
forwarder.logger = config.logger
var err error
forwarder.listenerConn, err = config.listenerFactory()
if err != nil {
return nil, err
}
forwarder.src, _ = forwarder.listenerConn.LocalAddr().(*net.UDPAddr)
go forwarder.janitor()
go forwarder.run()
return forwarder, nil
}
func (f *Forwarder) run() {
for {
buf := make([]byte, f.bufferSize)
oob := make([]byte, f.bufferSize)
n, _, _, addr, err := f.listenerConn.ReadMsgUDP(buf, oob)
if err != nil {
f.logger.Println("forward: failed to read, terminating:", err)
return
}
go f.handle(buf[:n], addr)
}
}
func (f *Forwarder) janitor() {
for !f.closed {
time.Sleep(f.timeout)
var keysToDelete []string
f.connectionsMutex.RLock()
for k, conn := range f.connections {
if conn.lastActive.Before(time.Now().Add(-f.timeout)) {
keysToDelete = append(keysToDelete, k)
}
}
f.connectionsMutex.RUnlock()
f.connectionsMutex.Lock()
for _, k := range keysToDelete {
f.connections[k].udp.Close()
delete(f.connections, k)
}
f.connectionsMutex.Unlock()
for _, k := range keysToDelete {
f.disconnectCallback(k)
}
}
}
func (f *Forwarder) handle(data []byte, addr *net.UDPAddr) {
f.connectionsMutex.Lock()
conn, found := f.connections[addr.String()]
if !found {
f.connections[addr.String()] = &connection{
available: make(chan struct{}),
udp: nil,
lastActive: time.Now(),
}
}
f.connectionsMutex.Unlock()
if !found {
var udpConn *net.UDPConn
var err error
dst := f.router.Route(addr)
if dst == nil {
f.connectionsMutex.Lock()
delete(f.connections, addr.String())
f.connectionsMutex.Unlock()
return
}
if dst.IP.To4()[0] == 127 {
// log.Println("using local listener")
laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:")
udpConn, err = net.DialUDP("udp", laddr, dst)
} else {
udpConn, err = net.DialUDP("udp", nil, dst)
}
if err != nil {
f.logger.Println("udp-forward: failed to dial:", err)
delete(f.connections, addr.String())
return
}
f.connectionsMutex.Lock()
f.connections[addr.String()].udp = udpConn
f.connections[addr.String()].lastActive = time.Now()
close(f.connections[addr.String()].available)
f.connectionsMutex.Unlock()
f.connectCallback(addr.String())
_, _, err = udpConn.WriteMsgUDP(data, nil, nil)
if err != nil {
f.logger.Println("udp-forward: error sending initial packet to client", err)
}
for {
// log.Println("in loop to read from NAT connection to servers")
buf := make([]byte, f.bufferSize)
oob := make([]byte, f.bufferSize)
n, _, _, _, err := udpConn.ReadMsgUDP(buf, oob)
if err != nil {
f.connectionsMutex.Lock()
udpConn.Close()
delete(f.connections, addr.String())
f.connectionsMutex.Unlock()
f.disconnectCallback(addr.String())
f.logger.Println("udp-forward: abnormal read, closing:", err)
return
}
// log.Println("sent packet to client")
_, _, err = f.listenerConn.WriteMsgUDP(buf[:n], nil, addr)
if err != nil {
f.logger.Println("udp-forward: error sending packet to client:", err)
}
}
// unreachable
}
<-conn.available
// log.Println("sent packet to server", conn.udp.RemoteAddr())
_, _, err := conn.udp.WriteMsgUDP(data, nil, nil)
if err != nil {
f.logger.Println("udp-forward: error sending packet to server:", err)
}
shouldChangeTime := false
f.connectionsMutex.RLock()
if _, found := f.connections[addr.String()]; found {
if f.connections[addr.String()].lastActive.Before(
time.Now().Add(f.timeout / 4)) {
shouldChangeTime = true
}
}
f.connectionsMutex.RUnlock()
if shouldChangeTime {
f.connectionsMutex.Lock()
// Make sure it still exists
if _, found := f.connections[addr.String()]; found {
connWrapper := f.connections[addr.String()]
connWrapper.lastActive = time.Now()
f.connections[addr.String()] = connWrapper
}
f.connectionsMutex.Unlock()
}
}
// Close stops the forwarder.
func (f *Forwarder) Close() error {
var errs error
f.connectionsMutex.Lock()
f.closed = true
for _, conn := range f.connections {
err := conn.udp.Close()
if err != nil {
errs = errors.Join(errs, err)
}
}
err := f.listenerConn.Close()
if err != nil {
errs = errors.Join(errs, err)
}
f.connectionsMutex.Unlock()
return errs
}
// OnConnect can be called with a callback function to be called whenever a
// new client connects.
func (f *Forwarder) OnConnect(callback func(addr string)) {
f.connectCallback = callback
}
// OnDisconnect can be called with a callback function to be called whenever a
// new client disconnects (after 5 minutes of inactivity).
func (f *Forwarder) OnDisconnect(callback func(addr string)) {
f.disconnectCallback = callback
}
// Connected returns the list of connected clients in IP:port form.
func (f *Forwarder) Connected() []string {
f.connectionsMutex.Lock()
defer f.connectionsMutex.Unlock()
results := make([]string, 0, len(f.connections))
for key := range f.connections {
results = append(results, key)
}
return results
}
// LocalAddr returns LocalAddr of listening connection.
func (f *Forwarder) LocalAddr() *net.UDPAddr {
addr, _ := f.listenerConn.LocalAddr().(*net.UDPAddr)
return addr
}

View File

@@ -106,7 +106,7 @@ func GetLocalAddr() (net.Addr, error) {
return l.LocalAddr(), nil
}
func Forward(ctx context.Context, port uint16, target string, log func(string)) (net.Listener, error) {
func Forward(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) {
l, err := reuse.Listen(ctx, "tcp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10))
if err != nil {
return nil, fmt.Errorf("Forward: %w", err)

View File

@@ -3,9 +3,11 @@ package natmap
import (
"context"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"time"
"github.com/xmdhs/natupnp/reuse"
@@ -54,3 +56,26 @@ func keepaliveUDP(ctx context.Context, port uint16, log func(error)) {
}
}
}
type logger struct {
log func(string)
}
func (l logger) Println(v ...any) {
build := &strings.Builder{}
fmt.Fprint(build, v...)
l.log(build.String())
}
func ForwardUdp(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) {
lc, err := reuse.ListenPacket(ctx, "udp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10))
if err != nil {
return nil, err
}
f, err := forward(WithLogger(logger{log}), WithConn(lc.(*net.UDPConn)), WithDestination(target))
if err != nil {
return nil, fmt.Errorf("ForwardUdp: %w", err)
}
return f, nil
}

View File

@@ -27,3 +27,7 @@ var listenConfig = net.ListenConfig{
func Listen(ctx context.Context, network, address string) (net.Listener, error) {
return listenConfig.Listen(ctx, network, address)
}
func ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) {
return listenConfig.ListenPacket(ctx, network, address)
}