Files
Archive/echo/internal/transporter/base.go
2024-09-06 20:35:09 +02:00

182 lines
5.0 KiB
Go

package transporter
import (
"context"
"fmt"
"net"
"github.com/sagernet/sing-box/common/sniff"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/cmgr"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
)
var _ RelayServer = &BaseRelayServer{}
type BaseRelayServer struct {
cmgr cmgr.Cmgr
cfg *conf.Config
l *zap.SugaredLogger
remotes lb.RoundRobin
relayer RelayClient
}
func newBaseRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (*BaseRelayServer, error) {
relayer, err := newRelayClient(cfg)
if err != nil {
return nil, err
}
return &BaseRelayServer{
relayer: relayer,
cfg: cfg,
cmgr: cmgr,
remotes: cfg.ToRemotesLB(),
l: zap.S().Named(cfg.GetLoggerName()),
}, nil
}
func (b *BaseRelayServer) RelayTCPConn(ctx context.Context, c net.Conn, remote *lb.Node) error {
labels := []string{b.cfg.Label, metrics.METRIC_CONN_TYPE_TCP, remote.Address}
metrics.CurConnectionCount.WithLabelValues(labels...).Inc()
defer metrics.CurConnectionCount.WithLabelValues(labels...).Dec()
if err := b.checkConnectionLimit(); err != nil {
return err
}
var err error
c, err = b.sniffAndBlockProtocol(c)
if err != nil {
return err
}
c = b.applyRateLimit(c)
rc, err := b.relayer.HandShake(ctx, remote, true)
if err != nil {
return fmt.Errorf("handshake error: %w", err)
}
defer rc.Close()
b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address)
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_TCP)
}
func (b *BaseRelayServer) RelayUDPConn(ctx context.Context, c net.Conn, remote *lb.Node) error {
labels := []string{b.cfg.Label, metrics.METRIC_CONN_TYPE_UDP, remote.Address}
metrics.CurConnectionCount.WithLabelValues(labels...).Inc()
defer metrics.CurConnectionCount.WithLabelValues(labels...).Dec()
rc, err := b.relayer.HandShake(ctx, remote, false)
if err != nil {
return fmt.Errorf("handshake error: %w", err)
}
defer rc.Close()
b.l.Infof("RelayUDPConn from %s to %s", c.LocalAddr(), remote.Address)
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_UDP)
}
func (b *BaseRelayServer) checkConnectionLimit() error {
if b.cmgr == nil {
return nil
}
if b.cfg.Options.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.Options.MaxConnection {
return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.Options.MaxConnection)
}
return nil
}
func (b *BaseRelayServer) sniffAndBlockProtocol(c net.Conn) (net.Conn, error) {
if len(b.cfg.Options.BlockedProtocols) == 0 {
return c, nil
}
buffer := buf.NewPacket()
ctx, cancel := context.WithTimeout(context.Background(), b.cfg.Options.SniffTimeout)
defer cancel()
sniffMetadata, err := sniff.PeekStream(ctx, c, buffer, b.cfg.Options.SniffTimeout, sniff.TLSClientHello, sniff.HTTPHost)
if err != nil {
b.l.Debugf("sniff error: %s", err)
}
if sniffMetadata != nil {
b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol)
for _, p := range b.cfg.Options.BlockedProtocols {
if sniffMetadata.Protocol == p {
return c, fmt.Errorf("relay:%s blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol)
}
}
}
if !buffer.IsEmpty() {
return bufio.NewCachedConn(c, buffer), nil
} else {
buffer.Release()
}
return c, nil
}
func (b *BaseRelayServer) applyRateLimit(c net.Conn) net.Conn {
if b.cfg.Options.MaxReadRateKbps > 0 {
return conn.NewRateLimitedConn(c, b.cfg.Options.MaxReadRateKbps)
}
return c
}
func (b *BaseRelayServer) handleRelayConn(c, rc net.Conn, remote *lb.Node, connType string) error {
opts := []conn.RelayConnOption{
conn.WithLogger(b.l),
conn.WithRemote(remote),
conn.WithConnType(connType),
conn.WithRelayLabel(b.cfg.Label),
conn.WithRelayOptions(b.cfg.Options),
}
relayConn := conn.NewRelayConn(c, rc, opts...)
if b.cmgr != nil {
b.cmgr.AddConnection(relayConn)
defer b.cmgr.RemoveConnection(relayConn)
}
return relayConn.Transport()
}
func (b *BaseRelayServer) HealthCheck(ctx context.Context) (int64, error) {
remote := b.remotes.Next().Clone()
// us tcp handshake to check health
_, err := b.relayer.HandShake(ctx, remote, true)
return int64(remote.HandShakeDuration.Milliseconds()), err
}
func (b *BaseRelayServer) Close() error {
return fmt.Errorf("not implemented")
}
func (b *BaseRelayServer) ListenAndServe(ctx context.Context) error {
return fmt.Errorf("not implemented")
}
func NewNetDialer(cfg *conf.Config) *net.Dialer {
dialer := &net.Dialer{Timeout: constant.DefaultDialTimeOut}
dialer.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
return dialer
}
func NewTCPListener(ctx context.Context, cfg *conf.Config) (net.Listener, error) {
addr, err := net.ResolveTCPAddr("tcp", cfg.Listen)
if err != nil {
return nil, err
}
lcfg := net.ListenConfig{}
lcfg.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
return lcfg.Listen(ctx, "tcp", addr.String())
}