mirror of
https://github.com/ICKelin/opennotr.git
synced 2025-09-26 20:01:13 +08:00
optimize: wrap udp session
This commit is contained in:
@@ -23,18 +23,60 @@ var (
|
||||
defaultUDPSessionTimeout = 30
|
||||
)
|
||||
|
||||
// udpSession defines each client forward stream
|
||||
// the purpose of udpSession is to reuse stream
|
||||
// tha lastActive members will reset for easch packet in/out
|
||||
type udpSession struct {
|
||||
stream *yamux.Stream
|
||||
lastActive time.Time
|
||||
}
|
||||
|
||||
type udpSessionManager struct {
|
||||
sessLock sync.Mutex
|
||||
udpsessions map[string]*udpSession
|
||||
}
|
||||
|
||||
func newUDPSessionManager() *udpSessionManager {
|
||||
return &udpSessionManager{
|
||||
udpsessions: make(map[string]*udpSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (mgr *udpSessionManager) Add(key string, val *udpSession) {
|
||||
mgr.sessLock.Lock()
|
||||
defer mgr.sessLock.Unlock()
|
||||
mgr.udpsessions[key] = val
|
||||
}
|
||||
|
||||
func (mgr *udpSessionManager) Get(key string) *udpSession {
|
||||
mgr.sessLock.Lock()
|
||||
defer mgr.sessLock.Unlock()
|
||||
return mgr.udpsessions[key]
|
||||
}
|
||||
|
||||
func (mgr *udpSessionManager) Delete(key string) {
|
||||
mgr.sessLock.Lock()
|
||||
defer mgr.sessLock.Unlock()
|
||||
delete(mgr.udpsessions, key)
|
||||
}
|
||||
|
||||
func (mgr *udpSessionManager) ResetActive(key string, tm time.Time) {
|
||||
mgr.sessLock.Lock()
|
||||
defer mgr.sessLock.Unlock()
|
||||
sess, ok := mgr.udpsessions[key]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
sess.lastActive = tm
|
||||
}
|
||||
|
||||
type UDPForward struct {
|
||||
listenAddr string
|
||||
sessionTimeout int
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
sessMgr *SessionManager
|
||||
udpSessions sync.Map
|
||||
udpSessionMgr *udpSessionManager
|
||||
}
|
||||
|
||||
func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
|
||||
@@ -59,6 +101,7 @@ func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
|
||||
writeTimeout: time.Duration(writeTimeout) * time.Second,
|
||||
sessionTimeout: sessionTimeout,
|
||||
sessMgr: GetSessionManager(),
|
||||
udpSessionMgr: newUDPSessionManager(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,8 +171,8 @@ func (f *UDPForward) ListenAndServe() error {
|
||||
sip, sport, _ := net.SplitHostPort(raddr.String())
|
||||
|
||||
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
|
||||
val, ok := f.udpSessions.Load(key)
|
||||
if !ok {
|
||||
udpsess := f.udpSessionMgr.Get(key)
|
||||
if udpsess != nil {
|
||||
sess := f.sessMgr.GetSession(dip)
|
||||
if sess == nil {
|
||||
logs.Error("no route to host: %s", dip)
|
||||
@@ -141,8 +184,9 @@ func (f *UDPForward) ListenAndServe() error {
|
||||
logs.Error("open stream fail: %v", err)
|
||||
continue
|
||||
}
|
||||
f.udpSessions.Store(key, &udpSession{stream, time.Now()})
|
||||
|
||||
udpsess = &udpSession{stream, time.Now()}
|
||||
f.udpSessionMgr.Add(key, udpsess)
|
||||
targetIP := "127.0.0.1"
|
||||
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
|
||||
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
|
||||
@@ -154,20 +198,9 @@ func (f *UDPForward) ListenAndServe() error {
|
||||
}
|
||||
|
||||
go f.forwardUDP(stream, key, rawfd, origindst, raddr)
|
||||
val, ok = f.udpSessions.Load(key)
|
||||
if !ok {
|
||||
logs.Error("get stream for %s fail", key)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
udpsess, ok := val.(*udpSession)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// update active time to avoid session recycle
|
||||
udpsess.lastActive = time.Now()
|
||||
f.udpSessionMgr.ResetActive(key, time.Now())
|
||||
stream := udpsess.stream
|
||||
|
||||
bytes := encode(buf[:nr])
|
||||
@@ -184,7 +217,7 @@ func (f *UDPForward) ListenAndServe() error {
|
||||
// forwardUDP reads from stream and write to tofd via rawsocket
|
||||
func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd int, fromaddr, toaddr *net.UDPAddr) {
|
||||
defer stream.Close()
|
||||
defer f.udpSessions.Delete(sessionKey)
|
||||
defer f.udpSessionMgr.Delete(sessionKey)
|
||||
hdr := make([]byte, 2)
|
||||
for {
|
||||
nr, err := stream.Read(hdr)
|
||||
@@ -213,29 +246,23 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd in
|
||||
if err != nil {
|
||||
logs.Error("send via raw socket fail: %v", err)
|
||||
}
|
||||
f.udpSessionMgr.ResetActive(sessionKey, time.Now())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *UDPForward) recyeleSession() {
|
||||
tick := time.NewTicker(time.Second * 5)
|
||||
for range tick.C {
|
||||
total, timeout := 0, 0
|
||||
f.udpSessions.Range(func(k, v interface{}) bool {
|
||||
total += 1
|
||||
s, ok := v.(*udpSession)
|
||||
if !ok {
|
||||
return true
|
||||
}
|
||||
|
||||
total, timeout := len(f.udpSessionMgr.udpsessions), 0
|
||||
f.udpSessionMgr.sessLock.Lock()
|
||||
for k, s := range f.udpSessionMgr.udpsessions {
|
||||
if time.Now().Sub(s.lastActive).Seconds() > float64(f.sessionTimeout) {
|
||||
logs.Warn("remove udp %v session, lastActive: %v", k, s.lastActive)
|
||||
f.udpSessions.Delete(k)
|
||||
s.stream.Close()
|
||||
timeout += 1
|
||||
delete(f.udpSessionMgr.udpsessions, k)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
}
|
||||
f.udpSessionMgr.sessLock.Unlock()
|
||||
logs.Debug("total %d, timeout %d, left: %d", total, timeout, total-timeout)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user