Files
opennotr/opennotrd/plugin/udpproxy/udpproxy.go
2021-06-05 16:36:28 +08:00

194 lines
3.9 KiB
Go

package udpproxy
import (
"context"
"encoding/json"
"net"
"sync"
"time"
"github.com/ICKelin/opennotr/internal/logs"
"github.com/ICKelin/opennotr/opennotrd/plugin"
)
// default timeout for udp session
var defaultTimeout = 30
func init() {
plugin.Register("udp", &UDPProxy{})
}
type config struct {
// session timeout(second)
SessionTimeout int `json:"sessionTimeout"`
}
type UDPProxy struct {
cfg config
}
func (p *UDPProxy) Setup(rawMessage json.RawMessage) error {
var cfg config
err := json.Unmarshal(rawMessage, &cfg)
if err != nil {
return err
}
if cfg.SessionTimeout <= 0 {
cfg.SessionTimeout = defaultTimeout
}
p.cfg = cfg
return nil
}
func (p *UDPProxy) StopProxy(item *plugin.PluginMeta) {
select {
case item.RecycleSignal <- struct{}{}:
default:
}
}
func (p *UDPProxy) RunProxy(item *plugin.PluginMeta) (*plugin.ProxyTuple, error) {
from := item.From
laddr, err := net.ResolveUDPAddr("udp", from)
if err != nil {
return nil, err
}
lis, err := net.ListenUDP("udp", laddr)
if err != nil {
return nil, err
}
go p.doProxy(lis, item)
_, fromPort, _ := net.SplitHostPort(lis.LocalAddr().String())
_, toPort, _ := net.SplitHostPort(item.To)
return &plugin.ProxyTuple{
Protocol: item.Protocol,
FromPort: fromPort,
ToPort: toPort,
}, nil
}
func (p *UDPProxy) doProxy(lis *net.UDPConn, item *plugin.PluginMeta) {
defer lis.Close()
from := item.From
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// receive this signal and close the listener
// the listener close will force lis.ReadFromUDP loop break
// then close all the client socket and end udpCopy
go func() {
select {
case <-item.RecycleSignal:
logs.Info("receive recycle signal for %s", from)
lis.Close()
case <-ctx.Done():
return
}
}()
// sess store all backend connection
// key: client address
// value: *net.UDPConn
sess := sync.Map{}
// sessionTimeout store all session key active time
// the purpose of this is to avoid too session without expired
sessionTimeout := sync.Map{}
// close all backend sockets
// this action may end udpCopy
defer func() {
sess.Range(func(k, v interface{}) bool {
if conn, ok := v.(*net.UDPConn); ok {
conn.Close()
}
return true
})
}()
go func() {
timeout := p.cfg.SessionTimeout
interval := timeout / 2
if interval <= 0 {
interval = timeout
}
tick := time.NewTicker(time.Second * time.Duration(interval))
for range tick.C {
sessionTimeout.Range(func(k, v interface{}) bool {
lastActiveAt, ok := v.(time.Time)
if !ok {
return true
}
if time.Now().Sub(lastActiveAt).Seconds() > float64(timeout) {
sess.Delete(k)
}
return true
})
}
}()
var buf = make([]byte, 64*1024)
for {
nr, raddr, err := lis.ReadFromUDP(buf)
if err != nil {
logs.Error("read from udp fail: %v", err)
break
}
key := raddr.String()
val, ok := sess.Load(key)
if !ok {
backendAddr, err := net.ResolveUDPAddr("udp", item.To)
if err != nil {
logs.Error("resolve udp fail: %v", err)
break
}
backendConn, err := net.DialUDP("udp", nil, backendAddr)
if err != nil {
logs.Error("dial udp fail: %v", err)
break
}
sess.Store(key, backendConn)
sessionTimeout.Store(key, time.Now())
// read from $to address and write to $from address
go p.udpCopy(lis, backendConn, raddr)
}
val, ok = sess.Load(key)
if !ok {
continue
}
sessionTimeout.Store(key, time.Now())
// read from $from address and write to $to address
val.(*net.UDPConn).Write(buf[:nr])
}
}
func (p *UDPProxy) udpCopy(dst, src *net.UDPConn, toaddr *net.UDPAddr) {
defer src.Close()
buf := make([]byte, 64*1024)
for {
nr, _, err := src.ReadFromUDP(buf)
if err != nil {
logs.Error("read from udp fail: %v", err)
break
}
_, err = dst.WriteToUDP(buf[:nr], toaddr)
if err != nil {
logs.Error("write to udp fail: %v", err)
break
}
}
}