optimize: wrap udp session

This commit is contained in:
ICKelin
2021-05-09 10:08:55 +08:00
parent 51bcd4dd44
commit afad386497

View File

@@ -23,18 +23,60 @@ var (
defaultUDPSessionTimeout = 30
)
// udpSession defines each client forward stream
// the purpose of udpSession is to reuse stream
// tha lastActive members will reset for easch packet in/out
type udpSession struct {
stream *yamux.Stream
lastActive time.Time
}
type udpSessionManager struct {
sessLock sync.Mutex
udpsessions map[string]*udpSession
}
func newUDPSessionManager() *udpSessionManager {
return &udpSessionManager{
udpsessions: make(map[string]*udpSession),
}
}
func (mgr *udpSessionManager) Add(key string, val *udpSession) {
mgr.sessLock.Lock()
defer mgr.sessLock.Unlock()
mgr.udpsessions[key] = val
}
func (mgr *udpSessionManager) Get(key string) *udpSession {
mgr.sessLock.Lock()
defer mgr.sessLock.Unlock()
return mgr.udpsessions[key]
}
func (mgr *udpSessionManager) Delete(key string) {
mgr.sessLock.Lock()
defer mgr.sessLock.Unlock()
delete(mgr.udpsessions, key)
}
func (mgr *udpSessionManager) ResetActive(key string, tm time.Time) {
mgr.sessLock.Lock()
defer mgr.sessLock.Unlock()
sess, ok := mgr.udpsessions[key]
if !ok {
return
}
sess.lastActive = tm
}
type UDPForward struct {
listenAddr string
sessionTimeout int
readTimeout time.Duration
writeTimeout time.Duration
sessMgr *SessionManager
udpSessions sync.Map
udpSessionMgr *udpSessionManager
}
func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
@@ -59,6 +101,7 @@ func NewUDPForward(cfg UDPForwardConfig) *UDPForward {
writeTimeout: time.Duration(writeTimeout) * time.Second,
sessionTimeout: sessionTimeout,
sessMgr: GetSessionManager(),
udpSessionMgr: newUDPSessionManager(),
}
}
@@ -128,8 +171,8 @@ func (f *UDPForward) ListenAndServe() error {
sip, sport, _ := net.SplitHostPort(raddr.String())
key := fmt.Sprintf("%s:%s:%s:%s", sip, sport, dip, dport)
val, ok := f.udpSessions.Load(key)
if !ok {
udpsess := f.udpSessionMgr.Get(key)
if udpsess != nil {
sess := f.sessMgr.GetSession(dip)
if sess == nil {
logs.Error("no route to host: %s", dip)
@@ -141,8 +184,9 @@ func (f *UDPForward) ListenAndServe() error {
logs.Error("open stream fail: %v", err)
continue
}
f.udpSessions.Store(key, &udpSession{stream, time.Now()})
udpsess = &udpSession{stream, time.Now()}
f.udpSessionMgr.Add(key, udpsess)
targetIP := "127.0.0.1"
bytes := encodeProxyProtocol("udp", sip, sport, targetIP, dport)
stream.SetWriteDeadline(time.Now().Add(f.writeTimeout))
@@ -154,20 +198,9 @@ func (f *UDPForward) ListenAndServe() error {
}
go f.forwardUDP(stream, key, rawfd, origindst, raddr)
val, ok = f.udpSessions.Load(key)
if !ok {
logs.Error("get stream for %s fail", key)
continue
}
}
udpsess, ok := val.(*udpSession)
if !ok {
continue
}
// update active time to avoid session recycle
udpsess.lastActive = time.Now()
f.udpSessionMgr.ResetActive(key, time.Now())
stream := udpsess.stream
bytes := encode(buf[:nr])
@@ -184,7 +217,7 @@ func (f *UDPForward) ListenAndServe() error {
// forwardUDP reads from stream and write to tofd via rawsocket
func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd int, fromaddr, toaddr *net.UDPAddr) {
defer stream.Close()
defer f.udpSessions.Delete(sessionKey)
defer f.udpSessionMgr.Delete(sessionKey)
hdr := make([]byte, 2)
for {
nr, err := stream.Read(hdr)
@@ -213,29 +246,23 @@ func (f *UDPForward) forwardUDP(stream *yamux.Stream, sessionKey string, tofd in
if err != nil {
logs.Error("send via raw socket fail: %v", err)
}
f.udpSessionMgr.ResetActive(sessionKey, time.Now())
}
}
func (f *UDPForward) recyeleSession() {
tick := time.NewTicker(time.Second * 5)
for range tick.C {
total, timeout := 0, 0
f.udpSessions.Range(func(k, v interface{}) bool {
total += 1
s, ok := v.(*udpSession)
if !ok {
return true
}
total, timeout := len(f.udpSessionMgr.udpsessions), 0
f.udpSessionMgr.sessLock.Lock()
for k, s := range f.udpSessionMgr.udpsessions {
if time.Now().Sub(s.lastActive).Seconds() > float64(f.sessionTimeout) {
logs.Warn("remove udp %v session, lastActive: %v", k, s.lastActive)
f.udpSessions.Delete(k)
s.stream.Close()
timeout += 1
delete(f.udpSessionMgr.udpsessions, k)
}
return true
})
}
f.udpSessionMgr.sessLock.Unlock()
logs.Debug("total %d, timeout %d, left: %d", total, timeout, total-timeout)
}
}