Files
Archive/echo/internal/transporter/mux.go
2024-04-06 20:26:16 +02:00

139 lines
3.4 KiB
Go

// nolint: errcheck
package transporter
import (
"context"
"net"
"sync"
"time"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/xtaci/smux"
"go.uber.org/zap"
)
// smuxTransporter keeps smux sessions grouped by remote address and hands out
// streams on them through Dial. A background gc loop closes and removes
// sessions that are no longer usable.
type smuxTransporter struct {
// sessionMutex is held by Dial and gc while they read or modify sessionM.
sessionMutex sync.Mutex
// gcTicker drives the gc loop; each tick triggers one cleanup pass.
gcTicker *time.Ticker
// l is the logger used for debug and error output.
l *zap.SugaredLogger
// remote addr -> SessionWithMetrics
sessionM map[string][]*SessionWithMetrics
// initSessionF creates a new session to addr; Dial calls it when no existing
// session for that address can serve a new stream.
initSessionF func(ctx context.Context, addr string) (*smux.Session, error)
}
// SessionWithMetrics pairs a smux session with the bookkeeping used to decide
// whether it may take new streams and whether it may be closed.
type SessionWithMetrics struct {
// session is the underlying multiplexed session.
session *smux.Session
// createdTime is when this entry was created; it is compared against
// constant.SmuxMaxAliveDuration.
createdTime time.Time
// streamList holds the streams that Dial has opened on this session.
streamList []*smux.Stream
}
// CanNotServeNewStream reports whether this session must not be given another
// stream. That is the case when the session is closed, when its current stream
// count has reached constant.SmuxMaxStreamCnt, or when it has existed for
// longer than constant.SmuxMaxAliveDuration.
func (sm *SessionWithMetrics) CanNotServeNewStream() bool {
	if sm.session.IsClosed() {
		return true
	}
	if sm.session.NumStreams() >= constant.SmuxMaxStreamCnt {
		return true
	}
	age := time.Since(sm.createdTime)
	return age > constant.SmuxMaxAliveDuration
}
// streamDead reports whether s has died, without blocking.
//
// It polls the stream's die channel: a receive can only succeed once that
// channel has been closed (per smux, the channel becomes readable when the
// stream is closed), so a ready receive means the stream is dead. If nothing
// can be received the channel is still open and the stream is alive.
func streamDead(s *smux.Stream) bool {
	select {
	case <-s.GetDieCh():
		// Die channel is readable: the stream has been closed.
		return true
	default:
		// Die channel still open: the stream is alive.
		return false
	}
}
// canCloseSession reports whether every stream recorded in sm.streamList is
// dead, i.e. whether the session can be closed without cutting off a stream
// that is still in use.
//
// remoteAddr is only used to label the debug log line written through l for
// the first stream found alive.
func (sm *SessionWithMetrics) canCloseSession(remoteAddr string, l *zap.SugaredLogger) bool {
	for _, s := range sm.streamList {
		if !streamDead(s) {
			// Log the stream that keeps the session open, then stop: one
			// live stream is enough to forbid closing.
			l.Debugf("session: %s stream: %d is not dead", remoteAddr, s.ID())
			return false
		}
	}
	return true
}
// NewSmuxTransporter builds a transporter with an empty session table and a
// ticker firing every constant.SmuxGCDuration, then starts the gc goroutine
// that consumes that ticker.
//
// l is the logger used by the transporter; initSessionF is called by Dial to
// create a session when no existing one can be reused.
func NewSmuxTransporter(
	l *zap.SugaredLogger,
	initSessionF func(ctx context.Context, addr string) (*smux.Session, error),
) *smuxTransporter {
	tr := new(smuxTransporter)
	tr.l = l
	tr.initSessionF = initSessionF
	tr.sessionM = map[string][]*SessionWithMetrics{}
	tr.gcTicker = time.NewTicker(constant.SmuxGCDuration)

	// Background loop that closes idle sessions on every tick.
	go tr.gc()

	return tr
}
// gc is the background cleanup loop; it runs once per gcTicker tick.
//
// On each pass, while holding sessionMutex, it visits every remote address:
//  1. closes each session that can no longer serve new streams and whose
//     recorded streams are all dead;
//  2. rebuilds the address's session list without the closed sessions,
//     removing the address from sessionM entirely when none remain.
func (tr *smuxTransporter) gc() {
	for range tr.gcTicker.C {
		tr.sessionMutex.Lock()
		for addr, sl := range tr.sessionM {
			tr.l.Debugf("start doing gc for remote addr: %s total session count %d", addr, len(sl))
			for _, sm := range sl {
				if sm.CanNotServeNewStream() && sm.canCloseSession(addr, tr.l) {
					tr.l.Debugf("close idle session:%s stream cnt %d",
						sm.session.LocalAddr().String(), sm.session.NumStreams())
					// Best effort: the session is being discarded, so a
					// Close error is deliberately not acted upon.
					sm.session.Close()
				}
			}

			// Keep only sessions that are still open. This also drops
			// sessions closed elsewhere (e.g. by Dial after a failed
			// OpenStream).
			alive := make([]*SessionWithMetrics, 0, len(sl))
			for _, sm := range sl {
				if !sm.session.IsClosed() {
					alive = append(alive, sm)
				}
			}
			if len(alive) == 0 {
				// Remove the key so addresses that are no longer dialed do
				// not accumulate in the map. Deleting the current key while
				// ranging over a map is safe in Go, and a lookup of a
				// missing key yields a nil slice, which append accepts.
				delete(tr.sessionM, addr)
			} else {
				tr.sessionM[addr] = alive
			}
			// Report the count after cleanup, not the count we started with.
			tr.l.Debugf("finish gc for remote addr: %s total session count %d", addr, len(alive))
		}
		tr.sessionMutex.Unlock()
	}
}
// Dial returns a new stream to addr as a net.Conn.
//
// It reuses the first session recorded for addr that can still serve a new
// stream; if there is none it creates one with initSessionF and records it.
// The opened stream is added to the chosen session's streamList.
//
// An error from initSessionF is returned as is. If opening the stream fails,
// the error is logged, the session is closed, and the error is returned.
// sessionMutex is held for the whole call.
func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
	tr.sessionMutex.Lock()
	defer tr.sessionMutex.Unlock()

	// Look for a reusable session for this address.
	var picked *SessionWithMetrics
	for _, candidate := range tr.sessionM[addr] {
		if !candidate.CanNotServeNewStream() {
			tr.l.Debugf("use session: %s total stream count: %d remote addr: %s",
				candidate.session.LocalAddr().String(), candidate.session.NumStreams(), addr)
			picked = candidate
			break
		}
	}

	// Nothing reusable: create a session and register it under addr.
	if picked == nil {
		fresh, initErr := tr.initSessionF(ctx, addr)
		if initErr != nil {
			return nil, initErr
		}
		picked = &SessionWithMetrics{
			session:     fresh,
			createdTime: time.Now(),
			streamList:  []*smux.Stream{},
		}
		tr.sessionM[addr] = append(tr.sessionM[addr], picked)
	}

	stream, openErr := picked.session.OpenStream()
	if openErr != nil {
		tr.l.Errorf("open stream meet error:%s", openErr)
		picked.session.Close()
		return nil, openErr
	}

	// Remember the stream so gc can tell when the session has gone idle.
	picked.streamList = append(picked.streamList, stream)
	return stream, nil
}