mirror of
https://github.com/ICKelin/opennotr.git
synced 2025-09-27 04:05:57 +08:00
194 lines
3.9 KiB
Go
194 lines
3.9 KiB
Go
package udpproxy
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ICKelin/opennotr/internal/logs"
|
|
"github.com/ICKelin/opennotr/opennotrd/plugin"
|
|
)
|
|
|
|
var (
	// defaultTimeout is the UDP session idle timeout, in seconds, applied by
	// Setup when the configuration does not provide a positive sessionTimeout.
	defaultTimeout = 30
)
|
|
|
|
func init() {
|
|
plugin.Register("udp", &UDPProxy{})
|
|
}
|
|
|
|
// config holds the tunable settings of the UDP proxy plugin; Setup decodes it
// from the plugin's JSON configuration.
type config struct {
	// SessionTimeout is how long, in seconds, a client session may stay
	// without client traffic before it is expired.
	SessionTimeout int `json:"sessionTimeout"`
}

// UDPProxy is the "udp" plugin: it relays datagrams received on a local
// listener to a backend address, keeping one backend socket per client.
type UDPProxy struct {
	cfg config // populated by Setup
}
|
|
|
|
func (p *UDPProxy) Setup(rawMessage json.RawMessage) error {
|
|
var cfg config
|
|
err := json.Unmarshal(rawMessage, &cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if cfg.SessionTimeout <= 0 {
|
|
cfg.SessionTimeout = defaultTimeout
|
|
}
|
|
p.cfg = cfg
|
|
return nil
|
|
}
|
|
|
|
func (p *UDPProxy) StopProxy(item *plugin.PluginMeta) {
|
|
select {
|
|
case item.RecycleSignal <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (p *UDPProxy) RunProxy(item *plugin.PluginMeta) (*plugin.ProxyTuple, error) {
|
|
from := item.From
|
|
laddr, err := net.ResolveUDPAddr("udp", from)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
lis, err := net.ListenUDP("udp", laddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
go p.doProxy(lis, item)
|
|
|
|
_, fromPort, _ := net.SplitHostPort(lis.LocalAddr().String())
|
|
_, toPort, _ := net.SplitHostPort(item.To)
|
|
|
|
return &plugin.ProxyTuple{
|
|
Protocol: item.Protocol,
|
|
FromPort: fromPort,
|
|
ToPort: toPort,
|
|
}, nil
|
|
}
|
|
|
|
func (p *UDPProxy) doProxy(lis *net.UDPConn, item *plugin.PluginMeta) {
|
|
defer lis.Close()
|
|
|
|
from := item.From
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// receive this signal and close the listener
|
|
// the listener close will force lis.ReadFromUDP loop break
|
|
// then close all the client socket and end udpCopy
|
|
go func() {
|
|
select {
|
|
case <-item.RecycleSignal:
|
|
logs.Info("receive recycle signal for %s", from)
|
|
lis.Close()
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}()
|
|
|
|
// sess store all backend connection
|
|
// key: client address
|
|
// value: *net.UDPConn
|
|
sess := sync.Map{}
|
|
|
|
// sessionTimeout store all session key active time
|
|
// the purpose of this is to avoid too session without expired
|
|
sessionTimeout := sync.Map{}
|
|
|
|
// close all backend sockets
|
|
// this action may end udpCopy
|
|
defer func() {
|
|
sess.Range(func(k, v interface{}) bool {
|
|
if conn, ok := v.(*net.UDPConn); ok {
|
|
conn.Close()
|
|
}
|
|
return true
|
|
})
|
|
}()
|
|
|
|
go func() {
|
|
timeout := p.cfg.SessionTimeout
|
|
interval := timeout / 2
|
|
if interval <= 0 {
|
|
interval = timeout
|
|
}
|
|
tick := time.NewTicker(time.Second * time.Duration(interval))
|
|
for range tick.C {
|
|
sessionTimeout.Range(func(k, v interface{}) bool {
|
|
lastActiveAt, ok := v.(time.Time)
|
|
if !ok {
|
|
return true
|
|
}
|
|
|
|
if time.Now().Sub(lastActiveAt).Seconds() > float64(timeout) {
|
|
sess.Delete(k)
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
}()
|
|
|
|
var buf = make([]byte, 64*1024)
|
|
for {
|
|
nr, raddr, err := lis.ReadFromUDP(buf)
|
|
if err != nil {
|
|
logs.Error("read from udp fail: %v", err)
|
|
break
|
|
}
|
|
|
|
key := raddr.String()
|
|
val, ok := sess.Load(key)
|
|
if !ok {
|
|
backendAddr, err := net.ResolveUDPAddr("udp", item.To)
|
|
if err != nil {
|
|
logs.Error("resolve udp fail: %v", err)
|
|
break
|
|
}
|
|
|
|
backendConn, err := net.DialUDP("udp", nil, backendAddr)
|
|
if err != nil {
|
|
logs.Error("dial udp fail: %v", err)
|
|
break
|
|
}
|
|
sess.Store(key, backendConn)
|
|
sessionTimeout.Store(key, time.Now())
|
|
|
|
// read from $to address and write to $from address
|
|
go p.udpCopy(lis, backendConn, raddr)
|
|
}
|
|
|
|
val, ok = sess.Load(key)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
sessionTimeout.Store(key, time.Now())
|
|
// read from $from address and write to $to address
|
|
val.(*net.UDPConn).Write(buf[:nr])
|
|
}
|
|
}
|
|
|
|
func (p *UDPProxy) udpCopy(dst, src *net.UDPConn, toaddr *net.UDPAddr) {
|
|
defer src.Close()
|
|
buf := make([]byte, 64*1024)
|
|
for {
|
|
nr, _, err := src.ReadFromUDP(buf)
|
|
if err != nil {
|
|
logs.Error("read from udp fail: %v", err)
|
|
break
|
|
}
|
|
|
|
_, err = dst.WriteToUDP(buf[:nr], toaddr)
|
|
if err != nil {
|
|
logs.Error("write to udp fail: %v", err)
|
|
break
|
|
}
|
|
}
|
|
}
|