mirror of
https://github.com/bolucat/Archive.git
synced 2025-10-29 02:52:47 +08:00
202 lines
4.8 KiB
Go
202 lines
4.8 KiB
Go
// nolint: errcheck
|
|
package transporter
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Ehco1996/ehco/internal/constant"
|
|
"github.com/xtaci/smux"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type smuxTransporter struct {
|
|
sessionMutex sync.Mutex
|
|
|
|
gcTicker *time.Ticker
|
|
l *zap.SugaredLogger
|
|
|
|
// remote addr -> SessionWithMetrics
|
|
sessionM map[string][]*SessionWithMetrics
|
|
|
|
initSessionF func(ctx context.Context, addr string) (*smux.Session, error)
|
|
}
|
|
|
|
type SessionWithMetrics struct {
|
|
session *smux.Session
|
|
|
|
createdTime time.Time
|
|
streamList []*smux.Stream
|
|
}
|
|
|
|
func (sm *SessionWithMetrics) CanNotServeNewStream() bool {
|
|
return sm.session.IsClosed() ||
|
|
sm.session.NumStreams() >= constant.SmuxMaxStreamCnt ||
|
|
time.Since(sm.createdTime) > constant.SmuxMaxAliveDuration
|
|
}
|
|
|
|
func streamDead(s *smux.Stream) bool {
|
|
select {
|
|
case _, ok := <-s.GetDieCh():
|
|
return !ok // 如果接收到值且通道未关闭,则 Stream 未死
|
|
default:
|
|
return true // 如果通道已经关闭,则 Stream 死了
|
|
}
|
|
}
|
|
|
|
func (sm *SessionWithMetrics) canCloseSession(remoteAddr string, l *zap.SugaredLogger) bool {
|
|
for _, s := range sm.streamList {
|
|
if !streamDead(s) {
|
|
return false
|
|
}
|
|
l.Debugf("session: %s stream: %d is not dead", remoteAddr, s.ID())
|
|
}
|
|
return true
|
|
}
|
|
|
|
func NewSmuxTransporter(
|
|
l *zap.SugaredLogger,
|
|
initSessionF func(ctx context.Context, addr string) (*smux.Session, error),
|
|
) *smuxTransporter {
|
|
tr := &smuxTransporter{
|
|
l: l,
|
|
initSessionF: initSessionF,
|
|
sessionM: make(map[string][]*SessionWithMetrics),
|
|
gcTicker: time.NewTicker(constant.SmuxGCDuration),
|
|
}
|
|
// start gc thread for close idle sessions
|
|
go tr.gc()
|
|
return tr
|
|
}
|
|
|
|
func (tr *smuxTransporter) gc() {
|
|
for range tr.gcTicker.C {
|
|
tr.sessionMutex.Lock()
|
|
for addr, sl := range tr.sessionM {
|
|
tr.l.Debugf("start doing gc for remote addr: %s total session count %d", addr, len(sl))
|
|
for idx := range sl {
|
|
sm := sl[idx]
|
|
if sm.CanNotServeNewStream() && sm.canCloseSession(addr, tr.l) {
|
|
tr.l.Debugf("close idle session:%s stream cnt %d",
|
|
sm.session.LocalAddr().String(), sm.session.NumStreams())
|
|
sm.session.Close()
|
|
}
|
|
}
|
|
newList := []*SessionWithMetrics{}
|
|
for _, s := range sl {
|
|
if !s.session.IsClosed() {
|
|
newList = append(newList, s)
|
|
}
|
|
}
|
|
tr.sessionM[addr] = newList
|
|
tr.l.Debugf("finish gc for remote addr: %s total session count %d", addr, len(sl))
|
|
}
|
|
tr.sessionMutex.Unlock()
|
|
}
|
|
}
|
|
|
|
func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
|
|
tr.sessionMutex.Lock()
|
|
defer tr.sessionMutex.Unlock()
|
|
var session *smux.Session
|
|
var curSM *SessionWithMetrics
|
|
|
|
sessionList := tr.sessionM[addr]
|
|
for _, sm := range sessionList {
|
|
if sm.CanNotServeNewStream() {
|
|
continue
|
|
} else {
|
|
tr.l.Debugf("use session: %s total stream count: %d remote addr: %s",
|
|
sm.session.LocalAddr().String(), sm.session.NumStreams(), addr)
|
|
session = sm.session
|
|
curSM = sm
|
|
break
|
|
}
|
|
}
|
|
// create new one
|
|
if session == nil {
|
|
session, err = tr.initSessionF(ctx, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sm := &SessionWithMetrics{session: session, createdTime: time.Now(), streamList: []*smux.Stream{}}
|
|
sessionList = append(sessionList, sm)
|
|
tr.sessionM[addr] = sessionList
|
|
curSM = sm
|
|
}
|
|
|
|
stream, err := session.OpenStream()
|
|
if err != nil {
|
|
tr.l.Errorf("open stream meet error:%s", err)
|
|
session.Close()
|
|
return nil, err
|
|
}
|
|
curSM.streamList = append(curSM.streamList, stream)
|
|
return stream, nil
|
|
}
|
|
|
|
// muxServer is the behaviour expected from a server that accepts smux
// sessions and exposes their streams as ordinary connections.
type muxServer interface {
	// ListenAndServe starts the server.
	// NOTE(review): no implementation is visible in this file; presumably it
	// blocks until the listener fails, confirm against the implementers.
	ListenAndServe() error
	// Accept returns the next multiplexed stream as a net.Conn, or an error.
	Accept() (net.Conn, error)
	// Close shuts the server down.
	Close() error
	// mux serves one underlying connection as an smux session.
	mux(net.Conn)
}
|
|
|
|
func newMuxServer(listenAddr string, l *zap.SugaredLogger) *muxServerImpl {
|
|
return &muxServerImpl{
|
|
errChan: make(chan error, 1),
|
|
connChan: make(chan net.Conn, 1024),
|
|
listenAddr: listenAddr,
|
|
l: l,
|
|
}
|
|
}
|
|
|
|
type muxServerImpl struct {
|
|
errChan chan error
|
|
connChan chan net.Conn
|
|
|
|
listenAddr string
|
|
l *zap.SugaredLogger
|
|
}
|
|
|
|
func (s *muxServerImpl) Accept() (net.Conn, error) {
|
|
select {
|
|
case conn := <-s.connChan:
|
|
return conn, nil
|
|
case err := <-s.errChan:
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
func (s *muxServerImpl) mux(conn net.Conn) {
|
|
defer conn.Close()
|
|
|
|
cfg := smux.DefaultConfig()
|
|
cfg.KeepAliveDisabled = true
|
|
session, err := smux.Server(conn, cfg)
|
|
if err != nil {
|
|
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err)
|
|
return
|
|
}
|
|
defer session.Close() // nolint: errcheck
|
|
|
|
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr)
|
|
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr)
|
|
|
|
for {
|
|
stream, err := session.AcceptStream()
|
|
if err != nil {
|
|
s.l.Errorf("accept stream err: %s", err)
|
|
break
|
|
}
|
|
select {
|
|
case s.connChan <- stream:
|
|
default:
|
|
stream.Close() // nolint: errcheck
|
|
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
|
|
}
|
|
}
|
|
}
|