diff --git a/main.go b/main.go index 337bb24..2726079 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "fmt" + "io" "log" "net" "net/http" @@ -25,6 +26,7 @@ var ( test bool target string comm string + udp bool ) func init() { @@ -32,8 +34,9 @@ func init() { flag.StringVar(&localAddr, "l", "", "local addr") flag.StringVar(&port, "p", "8086", "port") flag.StringVar(&target, "d", "", "forward to target host") - flag.BoolVar(&test, "t", false, "test server") + flag.BoolVar(&test, "t", false, "test server (only tcp)") flag.StringVar(&comm, "e", "", "run script for mapped address") + flag.BoolVar(&udp, "u", false, "udp") flag.Parse() } @@ -67,7 +70,7 @@ func main() { log.Println(err) } } - }) + }, udp, test) if err != nil { log.Println(err) } @@ -75,12 +78,19 @@ func main() { } } -func openPort(ctx context.Context, target, localAddr string, portu uint16, stun string, finish func(netip.AddrPort)) error { +func openPort(ctx context.Context, target, localAddr string, portu uint16, + stun string, finish func(netip.AddrPort), udp bool, testserver bool) error { ctx, cancel := context.WithCancel(ctx) defer cancel() if target != "" { - l, err := natmap.Forward(ctx, portu, target, func(s string) { + var forward func(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) + if udp { + forward = natmap.ForwardUdp + } else { + forward = natmap.Forward + } + l, err := forward(ctx, portu, target, func(s string) { log.Println(s) }) if err != nil { @@ -88,7 +98,7 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun } defer l.Close() } - if test { + if testserver { l, err := testServer(ctx, portu) if err != nil { return fmt.Errorf("openPort: %w", err) @@ -96,10 +106,17 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun defer l.Close() } errCh := make(chan error, 1) - m, s, err := natmap.NatMap(ctx, stun, localAddr, uint16(portu), func(s error) { + var nmap func(ctx context.Context, stunAddr string, host string, port uint16, log func(error)) (*natmap.Map, netip.AddrPort, error) + if udp { + nmap = natmap.NatMapUdp + } else { + nmap = natmap.NatMap + } + + m, s, err := nmap(ctx, stun, localAddr, uint16(portu), func(s error) { cancel() select { - case errCh <- ErrNatMap{err: s}: + case errCh <- s: default: } }) @@ -117,18 +134,6 @@ func openPort(ctx context.Context, target, localAddr string, portu uint16, stun return nil } -type ErrNatMap struct { - err error -} - -func (e ErrNatMap) Error() string { - return e.err.Error() -} - -func (e ErrNatMap) Unwrap() error { - return e.err -} - func testServer(ctx context.Context, port uint16) (net.Listener, error) { s := http.Server{ ReadTimeout: 5 * time.Second, diff --git a/natmap/forward.go b/natmap/forward.go new file mode 100644 index 0000000..5f403fb --- /dev/null +++ b/natmap/forward.go @@ -0,0 +1,395 @@ +// Package forward contains a UDP packet forwarder. +// https://github.com/gwangyi/udp-forward +package natmap + +import ( + "errors" + "log" + "net" + "sync" + "time" +) + +type connection struct { + available chan struct{} + udp *net.UDPConn + lastActive time.Time +} + +type Logger interface { + Println(v ...any) +} + +// Forwarder represents a UDP packet forwarder. +type Forwarder struct { + src *net.UDPAddr + router Router + client *net.UDPAddr + listenerConn *net.UDPConn + + connections map[string]*connection + connectionsMutex *sync.RWMutex + + connectCallback func(addr string) + disconnectCallback func(addr string) + + timeout time.Duration + + closed bool + + bufferSize int + + logger Logger +} + +// Router represents a router that gives the destination address. +type Router interface { + Route(*net.UDPAddr) *net.UDPAddr +} + +type staticRouter struct { + *net.UDPAddr +} + +func (r staticRouter) Route(*net.UDPAddr) *net.UDPAddr { + return r.UDPAddr +} + +type funcRouter func(*net.UDPAddr) *net.UDPAddr + +func (r funcRouter) Route(incoming *net.UDPAddr) *net.UDPAddr { + return r(incoming) +} + +// config represents the configuration of Forwarder. +type config struct { + listenerFactory func() (*net.UDPConn, error) + router Router + timeout time.Duration + bufferSize int + logger Logger +} + +// Option gives the way to customize the forwarder. +type Option func(*config) error + +// WithAddr lets the new forwarder listen from given address. +func WithAddr(src string) Option { + return func(c *config) error { + srcAddr, err := net.ResolveUDPAddr("udp", src) + if err != nil { + return err + } + c.listenerFactory = func() (*net.UDPConn, error) { + return net.ListenUDP("udp", srcAddr) + } + return nil + } +} + +// WithConn lets the new forwarder to use given conn instead of new one. +func WithConn(conn *net.UDPConn) Option { + return func(c *config) error { + c.listenerFactory = func() (*net.UDPConn, error) { + return conn, nil + } + return nil + } +} + +// WithDestination lets the new forwarder forward packets to the given address. +func WithDestination(dest string) Option { + return func(c *config) error { + destAddr, err := net.ResolveUDPAddr("udp", dest) + if err != nil { + return err + } + c.router = staticRouter{UDPAddr: destAddr} + return nil + } +} + +// WithRouter lets the new forwarder forward packets according to the given router. +func WithRouter(router Router) Option { + return func(c *config) error { + c.router = router + return nil + } +} + +// WithRouterFunc does the same as WithRouter, but with a function. +func WithRouterFunc(router func(*net.UDPAddr) *net.UDPAddr) Option { + return WithRouter(funcRouter(router)) +} + +// WithTimeout sets the timeout. +// No interaction more than the timeout will remove the connection from the NAT +// table. +func WithTimeout(timeout time.Duration) Option { + return func(c *config) error { + c.timeout = timeout + return nil + } +} + +// WithBufferSize sets the buffer size that is used by forwarding. +// Larger packet can be discarded. +func WithBufferSize(size int) Option { + return func(c *config) error { + c.bufferSize = size + return nil + } +} + +// WithLogger sets a logger. +func WithLogger(logger Logger) Option { + return func(c *config) error { + c.logger = logger + return nil + } +} + +type emptyLogger struct{} + +func (emptyLogger) Println(v ...any) {} + +// WithoutLogger lets forwarder not log anything. +func WithoutLogger() Option { + return WithLogger(emptyLogger{}) +} + +// DefaultTimeout is the default timeout period of inactivity for convenience +// sake. It is equivelant to 5 minutes. +const DefaultTimeout = time.Minute * 5 + +// Forward forwards UDP packets from the src address to the dst address, with a +// timeout to "disconnect" clients after the timeout period of inactivity. It +// implements a reverse NAT and thus supports multiple seperate users. Forward +// is also asynchronous. +func forward(options ...Option) (*Forwarder, error) { + config := &config{ + timeout: DefaultTimeout, + bufferSize: 4096, + logger: log.Default(), + } + + options = append([]Option{WithAddr(":")}, options...) + + for _, opt := range options { + if err := opt(config); err != nil { + return nil, err + } + } + + forwarder := new(Forwarder) + forwarder.connectCallback = func(addr string) {} + forwarder.disconnectCallback = func(addr string) {} + forwarder.connectionsMutex = new(sync.RWMutex) + forwarder.connections = make(map[string]*connection) + forwarder.timeout = config.timeout + forwarder.router = config.router + forwarder.bufferSize = config.bufferSize + forwarder.logger = config.logger + + var err error + forwarder.listenerConn, err = config.listenerFactory() + if err != nil { + return nil, err + } + forwarder.src, _ = forwarder.listenerConn.LocalAddr().(*net.UDPAddr) + + go forwarder.janitor() + go forwarder.run() + + return forwarder, nil +} + +func (f *Forwarder) run() { + for { + buf := make([]byte, f.bufferSize) + oob := make([]byte, f.bufferSize) + n, _, _, addr, err := f.listenerConn.ReadMsgUDP(buf, oob) + if err != nil { + f.logger.Println("forward: failed to read, terminating:", err) + return + } + go f.handle(buf[:n], addr) + } +} + +func (f *Forwarder) janitor() { + for !f.closed { + time.Sleep(f.timeout) + var keysToDelete []string + + f.connectionsMutex.RLock() + for k, conn := range f.connections { + if conn.lastActive.Before(time.Now().Add(-f.timeout)) { + keysToDelete = append(keysToDelete, k) + } + } + f.connectionsMutex.RUnlock() + + f.connectionsMutex.Lock() + for _, k := range keysToDelete { + f.connections[k].udp.Close() + delete(f.connections, k) + } + f.connectionsMutex.Unlock() + + for _, k := range keysToDelete { + f.disconnectCallback(k) + } + } +} + +func (f *Forwarder) handle(data []byte, addr *net.UDPAddr) { + f.connectionsMutex.Lock() + conn, found := f.connections[addr.String()] + if !found { + f.connections[addr.String()] = &connection{ + available: make(chan struct{}), + udp: nil, + lastActive: time.Now(), + } + } + f.connectionsMutex.Unlock() + + if !found { + var udpConn *net.UDPConn + var err error + dst := f.router.Route(addr) + if dst == nil { + f.connectionsMutex.Lock() + delete(f.connections, addr.String()) + f.connectionsMutex.Unlock() + return + } + if dst.IP.To4()[0] == 127 { + // log.Println("using local listener") + laddr, _ := net.ResolveUDPAddr("udp", "127.0.0.1:") + udpConn, err = net.DialUDP("udp", laddr, dst) + } else { + udpConn, err = net.DialUDP("udp", nil, dst) + } + if err != nil { + f.logger.Println("udp-forward: failed to dial:", err) + delete(f.connections, addr.String()) + return + } + + f.connectionsMutex.Lock() + f.connections[addr.String()].udp = udpConn + f.connections[addr.String()].lastActive = time.Now() + close(f.connections[addr.String()].available) + f.connectionsMutex.Unlock() + + f.connectCallback(addr.String()) + + _, _, err = udpConn.WriteMsgUDP(data, nil, nil) + if err != nil { + f.logger.Println("udp-forward: error sending initial packet to client", err) + } + + for { + // log.Println("in loop to read from NAT connection to servers") + buf := make([]byte, f.bufferSize) + oob := make([]byte, f.bufferSize) + n, _, _, _, err := udpConn.ReadMsgUDP(buf, oob) + if err != nil { + f.connectionsMutex.Lock() + udpConn.Close() + delete(f.connections, addr.String()) + f.connectionsMutex.Unlock() + f.disconnectCallback(addr.String()) + f.logger.Println("udp-forward: abnormal read, closing:", err) + return + } + + // log.Println("sent packet to client") + _, _, err = f.listenerConn.WriteMsgUDP(buf[:n], nil, addr) + if err != nil { + f.logger.Println("udp-forward: error sending packet to client:", err) + } + } + + // unreachable + } + + <-conn.available + + // log.Println("sent packet to server", conn.udp.RemoteAddr()) + _, _, err := conn.udp.WriteMsgUDP(data, nil, nil) + if err != nil { + f.logger.Println("udp-forward: error sending packet to server:", err) + } + + shouldChangeTime := false + f.connectionsMutex.RLock() + if _, found := f.connections[addr.String()]; found { + if f.connections[addr.String()].lastActive.Before( + time.Now().Add(f.timeout / 4)) { + shouldChangeTime = true + } + } + f.connectionsMutex.RUnlock() + + if shouldChangeTime { + f.connectionsMutex.Lock() + // Make sure it still exists + if _, found := f.connections[addr.String()]; found { + connWrapper := f.connections[addr.String()] + connWrapper.lastActive = time.Now() + f.connections[addr.String()] = connWrapper + } + f.connectionsMutex.Unlock() + } +} + +// Close stops the forwarder. +func (f *Forwarder) Close() error { + var errs error + f.connectionsMutex.Lock() + f.closed = true + for _, conn := range f.connections { + err := conn.udp.Close() + if err != nil { + errs = errors.Join(errs, err) + } + } + err := f.listenerConn.Close() + if err != nil { + errs = errors.Join(errs, err) + } + f.connectionsMutex.Unlock() + return errs +} + +// OnConnect can be called with a callback function to be called whenever a +// new client connects. +func (f *Forwarder) OnConnect(callback func(addr string)) { + f.connectCallback = callback +} + +// OnDisconnect can be called with a callback function to be called whenever a +// new client disconnects (after 5 minutes of inactivity). +func (f *Forwarder) OnDisconnect(callback func(addr string)) { + f.disconnectCallback = callback +} + +// Connected returns the list of connected clients in IP:port form. +func (f *Forwarder) Connected() []string { + f.connectionsMutex.Lock() + defer f.connectionsMutex.Unlock() + results := make([]string, 0, len(f.connections)) + for key := range f.connections { + results = append(results, key) + } + return results +} + +// LocalAddr returns LocalAddr of listening connection. +func (f *Forwarder) LocalAddr() *net.UDPAddr { + addr, _ := f.listenerConn.LocalAddr().(*net.UDPAddr) + return addr +} diff --git a/natmap/natmap.go b/natmap/natmap.go index 45a676a..127b507 100644 --- a/natmap/natmap.go +++ b/natmap/natmap.go @@ -106,7 +106,7 @@ func GetLocalAddr() (net.Addr, error) { return l.LocalAddr(), nil } -func Forward(ctx context.Context, port uint16, target string, log func(string)) (net.Listener, error) { +func Forward(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) { l, err := reuse.Listen(ctx, "tcp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10)) if err != nil { return nil, fmt.Errorf("Forward: %w", err) diff --git a/natmap/udp.go b/natmap/udp.go index 007bc1d..c46bdc0 100644 --- a/natmap/udp.go +++ b/natmap/udp.go @@ -3,9 +3,11 @@ package natmap import ( "context" "fmt" + "io" "net" "net/netip" "strconv" + "strings" "time" "github.com/xmdhs/natupnp/reuse" @@ -54,3 +56,26 @@ func keepaliveUDP(ctx context.Context, port uint16, log func(error)) { } } } + +type logger struct { + log func(string) +} + +func (l logger) Println(v ...any) { + build := &strings.Builder{} + fmt.Fprint(build, v...) + l.log(build.String()) +} + +func ForwardUdp(ctx context.Context, port uint16, target string, log func(string)) (io.Closer, error) { + lc, err := reuse.ListenPacket(ctx, "udp", "0.0.0.0:"+strconv.FormatUint(uint64(port), 10)) + if err != nil { + return nil, err + } + + f, err := forward(WithLogger(logger{log}), WithConn(lc.(*net.UDPConn)), WithDestination(target)) + if err != nil { + return nil, fmt.Errorf("ForwardUdp: %w", err) + } + return f, nil +} diff --git a/reuse/reuse.go b/reuse/reuse.go index a33abfa..efa18ee 100644 --- a/reuse/reuse.go +++ b/reuse/reuse.go @@ -27,3 +27,7 @@ var listenConfig = net.ListenConfig{ func Listen(ctx context.Context, network, address string) (net.Listener, error) { return listenConfig.Listen(ctx, network, address) } + +func ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) { + return listenConfig.ListenPacket(ctx, network, address) +}