Update On Sun Apr 7 20:25:36 CEST 2024

This commit is contained in:
github-action[bot]
2024-04-07 20:25:37 +02:00
parent e2eecadc74
commit 258ae1804d
80 changed files with 1034 additions and 1053 deletions

View File

@@ -10,7 +10,7 @@ import (
cli "github.com/urfave/cli/v2"
)
var cliLogger = log.MustNewLogger("info").Sugar().Named("cli-app")
var cliLogger = log.MustNewLogger("info").Sugar().Named("cli")
func startAction(ctx *cli.Context) error {
cfg, err := InitConfigAndComponents()

View File

@@ -49,10 +49,10 @@ func loadConfig() (cfg *config.Config, err error) {
}
}
// init tls
// init tls when need
for _, cfg := range cfg.RelayConfigs {
if cfg.ListenType == constant.Listen_WSS || cfg.ListenType == constant.Listen_MWSS ||
cfg.TransportType == constant.Transport_WSS || cfg.TransportType == constant.Transport_MWSS {
if cfg.ListenType == constant.RelayTypeWSS || cfg.ListenType == constant.RelayTypeMWSS ||
cfg.TransportType == constant.RelayTypeWSS || cfg.TransportType == constant.RelayTypeMWSS {
if err := tls.InitTlsCfg(); err != nil {
return nil, err
}

View File

@@ -20,19 +20,16 @@ const (
SmuxMaxAliveDuration = 10 * time.Minute
SmuxMaxStreamCnt = 5
Listen_RAW = "raw"
Listen_WS = "ws"
Listen_WSS = "wss"
Listen_MWSS = "mwss"
Listen_MTCP = "mtcp"
Transport_RAW = "raw"
Transport_WS = "ws"
Transport_WSS = "wss"
Transport_MWSS = "mwss"
Transport_MTCP = "mtcp"
// todo add udp buffer size
BUFFER_POOL_SIZE = 1024 // support 512 connections
BUFFER_SIZE = 20 * 1024 // 20KB the maximum packet size of shadowsocks is about 16 KiB
)
// relay type
const (
RelayTypeRaw = "raw"
RelayTypeWS = "ws"
RelayTypeWSS = "wss"
RelayTypeMWSS = "mwss"
RelayTypeMTCP = "mtcp"
)

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/pkg/lb"
"go.uber.org/zap"
)
@@ -22,19 +24,19 @@ func (r *Config) Validate() error {
if r.Adjust() != nil {
return errors.New("adjust config failed")
}
if r.ListenType != constant.Listen_RAW &&
r.ListenType != constant.Listen_WS &&
r.ListenType != constant.Listen_WSS &&
r.ListenType != constant.Listen_MTCP &&
r.ListenType != constant.Listen_MWSS {
if r.ListenType != constant.RelayTypeRaw &&
r.ListenType != constant.RelayTypeWS &&
r.ListenType != constant.RelayTypeWSS &&
r.ListenType != constant.RelayTypeMTCP &&
r.ListenType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid listen type:%s", r.ListenType)
}
if r.TransportType != constant.Transport_RAW &&
r.TransportType != constant.Transport_WS &&
r.TransportType != constant.Transport_WSS &&
r.TransportType != constant.Transport_MTCP &&
r.TransportType != constant.Transport_MWSS {
if r.TransportType != constant.RelayTypeRaw &&
r.TransportType != constant.RelayTypeWS &&
r.TransportType != constant.RelayTypeWSS &&
r.TransportType != constant.RelayTypeMTCP &&
r.TransportType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid transport type:%s", r.ListenType)
}
@@ -106,16 +108,31 @@ func (r *Config) Different(new *Config) bool {
}
// todo make this shorter and more readable
func (r *Config) defaultLabel() string {
defaultLabel := fmt.Sprintf("<At=%s Over=%s TCP-To=%s UDP-To=%s Through=%s>",
r.Listen, r.ListenType, r.TCPRemotes, r.UDPRemotes, r.TransportType)
func (r *Config) DefaultLabel() string {
defaultLabel := fmt.Sprintf("<At=%s TCP-To=%s TP=%s>",
r.Listen, r.TCPRemotes, r.TransportType)
return defaultLabel
}
func (r *Config) Adjust() error {
if r.Label == "" {
r.Label = r.defaultLabel()
r.Label = r.DefaultLabel()
zap.S().Debugf("label is empty, set default label:%s", r.Label)
}
return nil
}
func (r *Config) ToTCPRemotes() lb.RoundRobin {
tcpNodeList := make([]*lb.Node, len(r.TCPRemotes))
for idx, addr := range r.TCPRemotes {
tcpNodeList[idx] = &lb.Node{
Address: addr,
Label: fmt.Sprintf("%s-%s", r.Label, addr),
}
}
return lb.NewRoundRobin(tcpNodeList)
}
func (r *Config) GetLoggerName() string {
return fmt.Sprintf("%s(%s<->%s)", r.Label, r.ListenType, r.TransportType)
}

View File

@@ -1,138 +1,50 @@
package relay
import (
"net"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/cmgr"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/relay/conf"
"github.com/Ehco1996/ehco/internal/transporter"
)
type Relay struct {
Name string // unique name for all relay
TransportType string
ListenType string
TP transporter.RelayTransporter
LocalTCPAddr *net.TCPAddr
closeTcpF func() error
cfg *conf.Config
l *zap.SugaredLogger
relayServer transporter.RelayServer
}
func (r *Relay) UniqueID() string {
return r.cfg.Label
}
func NewRelay(cfg *conf.Config, connMgr cmgr.Cmgr) (*Relay, error) {
localTCPAddr, err := net.ResolveTCPAddr("tcp", cfg.Listen)
base := transporter.NewBaseTransporter(cfg, connMgr)
s, err := transporter.NewRelayServer(cfg.ListenType, base)
if err != nil {
return nil, err
}
r := &Relay{
cfg: cfg,
l: zap.S().Named("relay"),
Name: cfg.Label,
LocalTCPAddr: localTCPAddr,
ListenType: cfg.ListenType,
TransportType: cfg.TransportType,
TP: transporter.NewRelayTransporter(cfg, connMgr),
relayServer: s,
cfg: cfg,
l: zap.S().Named("relay"),
}
return r, nil
}
func (r *Relay) ListenAndServe() error {
errCh := make(chan error)
if len(r.cfg.TCPRemotes) > 0 {
switch r.ListenType {
case constant.Listen_RAW:
go func() {
errCh <- r.RunLocalTCPServer()
}()
case constant.Listen_MTCP:
go func() {
errCh <- r.RunLocalMTCPServer()
}()
case constant.Listen_WS:
go func() {
errCh <- r.RunLocalWSServer()
}()
case constant.Listen_WSS:
go func() {
errCh <- r.RunLocalWSSServer()
}()
case constant.Listen_MWSS:
go func() {
errCh <- r.RunLocalMWSSServer()
}()
}
}
go func() {
r.l.Infof("Start TCP Relay Server:%s", r.cfg.DefaultLabel())
errCh <- r.relayServer.ListenAndServe()
}()
return <-errCh
}
func (r *Relay) Close() {
r.l.Infof("Close relay label: %s", r.Name)
if r.closeTcpF != nil {
err := r.closeTcpF()
if err != nil {
r.l.Errorf(err.Error())
}
r.l.Infof("Close TCP Relay Server:%s", r.cfg.DefaultLabel())
if err := r.relayServer.Close(); err != nil {
r.l.Errorf(err.Error())
}
}
func (r *Relay) RunLocalTCPServer() error {
rawServer, err := transporter.NewRawServer(r.LocalTCPAddr.String(), r.TP)
if err != nil {
return err
}
r.closeTcpF = func() error {
return rawServer.Close()
}
r.l.Infof("Start TCP relay Server: %s", r.Name)
return rawServer.ListenAndServe()
}
func (r *Relay) RunLocalMTCPServer() error {
tp := r.TP.(*transporter.RawClient)
mTCPServer := transporter.NewMTCPServer(r.LocalTCPAddr.String(), tp, r.l.Named("MTCPServer"))
r.closeTcpF = func() error {
return mTCPServer.Close()
}
r.l.Infof("Start MTCP relay Server: %s", r.Name)
return mTCPServer.ListenAndServe()
}
func (r *Relay) RunLocalWSServer() error {
tp := r.TP.(*transporter.RawClient)
wsServer := transporter.NewWSServer(r.LocalTCPAddr.String(), tp, r.l.Named("WSServer"))
r.closeTcpF = func() error {
return wsServer.Close()
}
r.l.Infof("Start WS relay Server: %s", r.Name)
return wsServer.ListenAndServe()
}
func (r *Relay) RunLocalWSSServer() error {
tp := r.TP.(*transporter.RawClient)
wssServer := transporter.NewWSSServer(r.LocalTCPAddr.String(), tp, r.l.Named("WSSServer"))
r.closeTcpF = func() error {
return wssServer.Close()
}
r.l.Infof("Start WSS relay Server: %s", r.Name)
return wssServer.ListenAndServe()
}
func (r *Relay) RunLocalMWSSServer() error {
tp := r.TP.(*transporter.RawClient)
mwssServer := transporter.NewMWSSServer(r.LocalTCPAddr.String(), tp, r.l.Named("MWSSServer"))
r.closeTcpF = func() error {
return mwssServer.Close()
}
r.l.Infof("Start MWSS relay Server: %s", r.Name)
return mwssServer.ListenAndServe()
}

View File

@@ -46,18 +46,18 @@ func NewServer(cfg *config.Config) (*Server, error) {
}
func (s *Server) startOneRelay(r *Relay) {
s.relayM.Store(r.Name, r)
s.relayM.Store(r.UniqueID(), r)
// mute closed network error for tcp server and mute http.ErrServerClosed for http server when config reload
if err := r.ListenAndServe(); err != nil &&
!errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
s.l.Errorf("start relay %s meet error: %s", r.Name, err)
s.l.Errorf("start relay %s meet error: %s", r.UniqueID(), err)
s.errCH <- err
}
}
func (s *Server) stopOneRelay(r *Relay) {
r.Close()
s.relayM.Delete(r.Name)
s.relayM.Delete(r.UniqueID())
}
func (s *Server) Start(ctx context.Context) error {

View File

@@ -0,0 +1,55 @@
package transporter
import (
"net"
"github.com/Ehco1996/ehco/internal/cmgr"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
"github.com/Ehco1996/ehco/pkg/lb"
"go.uber.org/zap"
)
type baseTransporter struct {
cmgr cmgr.Cmgr
cfg *conf.Config
tCPRemotes lb.RoundRobin
l *zap.SugaredLogger
}
func NewBaseTransporter(cfg *conf.Config, cmgr cmgr.Cmgr) *baseTransporter {
return &baseTransporter{
cfg: cfg,
cmgr: cmgr,
tCPRemotes: cfg.ToTCPRemotes(),
l: zap.S().Named(cfg.GetLoggerName()),
}
}
func (b *baseTransporter) GetTCPListenAddr() (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", b.cfg.Listen)
}
func (b *baseTransporter) GetRemote() *lb.Node {
return b.tCPRemotes.Next()
}
func (b *baseTransporter) RelayTCPConn(c net.Conn, handshakeF TCPHandShakeF) error {
remote := b.GetRemote()
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc()
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec()
clonedRemote := remote.Clone()
rc, err := handshakeF(clonedRemote)
if err != nil {
return err
}
b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(
b.cfg.Label, c, rc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
b.cmgr.AddConnection(relayConn)
defer b.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
}

View File

@@ -1,49 +0,0 @@
package transporter
import (
"github.com/Ehco1996/ehco/internal/constant"
)
// 全局pool
var BufferPool *BytePool
func init() {
BufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.BUFFER_SIZE)
}
// BytePool implements a leaky pool of []byte in the form of a bounded channel
type BytePool struct {
c chan []byte
size int
}
// NewBytePool creates a new BytePool bounded to the given maxSize, with new
// byte arrays sized based on width.
func NewBytePool(maxSize int, size int) (bp *BytePool) {
return &BytePool{
c: make(chan []byte, maxSize),
size: size,
}
}
// Get gets a []byte from the BytePool, or creates a new one if none are available in the pool.
func (bp *BytePool) Get() (b []byte) {
select {
case b = <-bp.c:
// reuse existing buffer
default:
// create new buffer
b = make([]byte, bp.size)
}
return
}
// Put returns the given Buffer to the BytePool.
func (bp *BytePool) Put(b []byte) {
select {
case bp.c <- b:
// buffer went back into pool
default:
// buffer didn't go back into pool, just discard
}
}

View File

@@ -1,42 +1,54 @@
package transporter
import (
"fmt"
"net"
"github.com/Ehco1996/ehco/internal/cmgr"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/relay/conf"
"github.com/Ehco1996/ehco/pkg/lb"
)
// RelayTransporter
type RelayTransporter interface {
dialRemote(remote *lb.Node) (net.Conn, error)
HandleTCPConn(c net.Conn, remote *lb.Node) error
GetRemote() *lb.Node
type TCPHandShakeF func(remote *lb.Node) (net.Conn, error)
type RelayClient interface {
TCPHandShake(remote *lb.Node) (net.Conn, error)
RelayTCPConn(c net.Conn, handshakeF TCPHandShakeF) error
}
func NewRelayTransporter(cfg *conf.Config, connMgr cmgr.Cmgr) RelayTransporter {
tcpNodeList := make([]*lb.Node, len(cfg.TCPRemotes))
for idx, addr := range cfg.TCPRemotes {
tcpNodeList[idx] = &lb.Node{
Address: addr,
Label: fmt.Sprintf("%s-%s", cfg.Label, addr),
}
func NewRelayClient(relayType string, base *baseTransporter) (RelayClient, error) {
switch relayType {
case constant.RelayTypeRaw:
return newRawClient(base)
case constant.RelayTypeWS:
return newWsClient(base)
case constant.RelayTypeWSS:
return newWssClient(base)
case constant.RelayTypeMWSS:
return newMwssClient(base)
case constant.RelayTypeMTCP:
return newMtcpClient(base)
default:
panic("unsupported transport type")
}
}
type RelayServer interface {
ListenAndServe() error
Close() error
}
func NewRelayServer(relayType string, base *baseTransporter) (RelayServer, error) {
switch relayType {
case constant.RelayTypeRaw:
return newRawServer(base)
case constant.RelayTypeWS:
return newWsServer(base)
case constant.RelayTypeWSS:
return newWssServer(base)
case constant.RelayTypeMWSS:
return newMwssServer(base)
case constant.RelayTypeMTCP:
return newMtcpServer(base)
default:
panic("unsupported transport type")
}
raw := newRawClient(cfg.Label, lb.NewRoundRobin(tcpNodeList), connMgr)
switch cfg.TransportType {
case constant.Transport_RAW:
return raw
case constant.Transport_WS:
return newWsClient(raw)
case constant.Transport_WSS:
return newWSSClient(raw)
case constant.Transport_MWSS:
return newMWSSClient(raw)
case constant.Transport_MTCP:
return newMTCPClient(raw)
}
return nil
}

View File

@@ -136,3 +136,66 @@ func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn
curSM.streamList = append(curSM.streamList, stream)
return stream, nil
}
type muxServer interface {
ListenAndServe() error
Accept() (net.Conn, error)
Close() error
mux(net.Conn)
}
func newMuxServer(listenAddr string, l *zap.SugaredLogger) *muxServerImpl {
return &muxServerImpl{
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
listenAddr: listenAddr,
l: l,
}
}
type muxServerImpl struct {
errChan chan error
connChan chan net.Conn
listenAddr string
l *zap.SugaredLogger
}
func (s *muxServerImpl) Accept() (net.Conn, error) {
select {
case conn := <-s.connChan:
return conn, nil
case err := <-s.errChan:
return nil, err
}
}
func (s *muxServerImpl) mux(conn net.Conn) {
defer conn.Close()
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err)
return
}
defer session.Close() // nolint: errcheck
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr)
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr)
for {
stream, err := session.AcceptStream()
if err != nil {
s.l.Errorf("accept stream err: %s", err)
break
}
select {
case s.connChan <- stream:
default:
stream.Close() // nolint: errcheck
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}

View File

@@ -5,40 +5,33 @@ import (
"net"
"time"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/cmgr"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/pkg/lb"
)
var (
_ RelayClient = &RawClient{}
_ RelayServer = &RawServer{}
)
type RawClient struct {
relayLabel string
cmgr cmgr.Cmgr
tCPRemotes lb.RoundRobin
l *zap.SugaredLogger
*baseTransporter
dialer *net.Dialer
}
func newRawClient(relayLabel string, tcpRemotes lb.RoundRobin, cmgr cmgr.Cmgr) *RawClient {
func newRawClient(base *baseTransporter) (*RawClient, error) {
r := &RawClient{
cmgr: cmgr,
relayLabel: relayLabel,
tCPRemotes: tcpRemotes,
l: zap.S().Named(relayLabel),
baseTransporter: base,
dialer: &net.Dialer{Timeout: constant.DialTimeOut},
}
return r
return r, nil
}
func (raw *RawClient) GetRemote() *lb.Node {
return raw.tCPRemotes.Next()
}
func (raw *RawClient) dialRemote(remote *lb.Node) (net.Conn, error) {
func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := net.Dialer{Timeout: constant.DialTimeOut}
rc, err := d.Dial("tcp", remote.Address)
rc, err := raw.dialer.Dial("tcp", remote.Address)
if err != nil {
return nil, err
}
@@ -48,38 +41,32 @@ func (raw *RawClient) dialRemote(remote *lb.Node) (net.Conn, error) {
return rc, nil
}
func (raw *RawClient) HandleTCPConn(c net.Conn, remote *lb.Node) error {
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc()
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec()
clonedRemote := remote.Clone()
rc, err := raw.dialRemote(clonedRemote)
if err != nil {
return err
}
raw.l.Infof("HandleTCPConn from %s to %s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(raw.relayLabel, c, rc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
raw.cmgr.AddConnection(relayConn)
defer raw.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
}
type RawServer struct {
rtp RelayTransporter
lis *net.TCPListener
l *zap.SugaredLogger
*baseTransporter
localTCPAddr *net.TCPAddr
lis *net.TCPListener
relayer RelayClient
}
func NewRawServer(addr string, rtp RelayTransporter) (*RawServer, error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
func newRawServer(base *baseTransporter) (*RawServer, error) {
addr, err := base.GetTCPListenAddr()
if err != nil {
return nil, err
}
lis, err := net.ListenTCP("tcp", tcpAddr)
lis, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, err
}
return &RawServer{lis: lis, rtp: rtp}, nil
relayer, err := NewRelayClient(base.cfg.TransportType, base)
if err != nil {
return nil, err
}
return &RawServer{
lis: lis,
baseTransporter: base,
localTCPAddr: addr,
relayer: relayer,
}, nil
}
func (s *RawServer) Close() error {
@@ -93,13 +80,8 @@ func (s *RawServer) ListenAndServe() error {
return err
}
go func(c net.Conn) {
remote := s.rtp.GetRemote()
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc()
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec()
if err := s.rtp.HandleTCPConn(c, remote); err != nil {
s.l.Errorf("HandleTCPConn meet error tp:%s from:%s to:%s err:%s",
s.rtp,
c.RemoteAddr(), remote.Address, err)
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(c)
}

View File

@@ -1,4 +1,3 @@
// nolint: errcheck
package transporter
import (
@@ -7,29 +6,32 @@ import (
"time"
"github.com/xtaci/smux"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/pkg/lb"
)
type MTCPClient struct {
var (
_ RelayClient = &MtcpClient{}
_ RelayServer = &MtcpServer{}
)
type MtcpClient struct {
*RawClient
dialer *net.Dialer
mtp *smuxTransporter
muxTP *smuxTransporter
}
func newMTCPClient(raw *RawClient) *MTCPClient {
dialer := &net.Dialer{Timeout: constant.DialTimeOut}
c := &MTCPClient{dialer: dialer, RawClient: raw}
mtp := NewSmuxTransporter(raw.l.Named("mtcp"), c.initNewSession)
c.mtp = mtp
return c
func newMtcpClient(base *baseTransporter) (*MtcpClient, error) {
raw, err := newRawClient(base)
if err != nil {
return nil, err
}
c := &MtcpClient{RawClient: raw}
c.muxTP = NewSmuxTransporter(raw.l.Named("mtcp"), c.initNewSession)
return c, nil
}
func (c *MTCPClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
func (c *MtcpClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, err := c.dialer.Dial("tcp", addr)
if err != nil {
return nil, err
@@ -45,9 +47,9 @@ func (c *MTCPClient) initNewSession(ctx context.Context, addr string) (*smux.Ses
return session, nil
}
func (s *MTCPClient) dialRemote(remote *lb.Node) (net.Conn, error) {
func (s *MtcpClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
mtcpc, err := s.mtp.Dial(context.TODO(), remote.Address)
mtcpc, err := s.muxTP.Dial(context.TODO(), remote.Address)
if err != nil {
return nil, err
}
@@ -57,87 +59,28 @@ func (s *MTCPClient) dialRemote(remote *lb.Node) (net.Conn, error) {
return mtcpc, nil
}
func (s *MTCPClient) HandleTCPConn(c net.Conn, remote *lb.Node) error {
clonedRemote := remote.Clone()
mtcpc, err := s.dialRemote(clonedRemote)
type MtcpServer struct {
*RawServer
*muxServerImpl
}
func newMtcpServer(base *baseTransporter) (*MtcpServer, error) {
raw, err := newRawServer(base)
if err != nil {
return err
return nil, err
}
s.l.Infof("HandleTCPConn from:%s to:%s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(s.relayLabel, c, mtcpc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
s.cmgr.AddConnection(relayConn)
defer s.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
s := &MtcpServer{
RawServer: raw,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mtcp")),
}
return s, nil
}
type MTCPServer struct {
raw *RawClient
listenAddr string
listener net.Listener
l *zap.SugaredLogger
errChan chan error
connChan chan net.Conn
}
func NewMTCPServer(listenAddr string, raw *RawClient, l *zap.SugaredLogger) *MTCPServer {
return &MTCPServer{
l: l,
raw: raw,
listenAddr: listenAddr,
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
}
}
func (s *MTCPServer) mux(conn net.Conn) {
defer conn.Close()
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err)
return
}
defer session.Close()
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr)
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr)
for {
stream, err := session.AcceptStream()
if err != nil {
s.l.Errorf("accept stream err: %s", err)
break
}
select {
case s.connChan <- stream:
default:
stream.Close()
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}
func (s *MTCPServer) Accept() (conn net.Conn, err error) {
select {
case conn = <-s.connChan:
case err = <-s.errChan:
}
return
}
func (s *MTCPServer) ListenAndServe() error {
lis, err := net.Listen("tcp", s.listenAddr)
if err != nil {
return err
}
s.listener = lis
func (s *MtcpServer) ListenAndServe() error {
go func() {
for {
c, err := lis.Accept()
c, err := s.lis.Accept()
if err != nil {
s.errChan <- err
continue
@@ -152,14 +95,13 @@ func (s *MTCPServer) ListenAndServe() error {
return e
}
go func(c net.Conn) {
remote := s.raw.GetRemote()
if err := s.raw.HandleTCPConn(c, remote); err != nil {
s.l.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err)
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}
func (s *MTCPServer) Close() error {
return s.listener.Close()
func (s *MtcpServer) Close() error {
return s.lis.Close()
}

View File

@@ -8,27 +8,35 @@ import (
"github.com/gobwas/ws"
"github.com/labstack/echo/v4"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/web"
"github.com/Ehco1996/ehco/pkg/lb"
)
var (
_ RelayClient = &WsClient{}
_ RelayServer = &WsServer{}
)
type WsClient struct {
*RawClient
*baseTransporter
dialer *ws.Dialer
}
func newWsClient(raw *RawClient) *WsClient {
return &WsClient{RawClient: raw}
func newWsClient(base *baseTransporter) (*WsClient, error) {
s := &WsClient{
baseTransporter: base,
dialer: &ws.Dialer{Timeout: constant.DialTimeOut},
}
return s, nil
}
func (s *WsClient) dialRemote(remote *lb.Node) (net.Conn, error) {
func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := ws.Dialer{Timeout: constant.DialTimeOut}
wsc, _, _, err := d.Dial(context.TODO(), remote.Address+"/handshake/")
wsc, _, _, err := s.dialer.Dial(context.TODO(), remote.Address+"/handshake/")
if err != nil {
return nil, err
}
@@ -38,57 +46,51 @@ func (s *WsClient) dialRemote(remote *lb.Node) (net.Conn, error) {
return wsc, nil
}
func (s *WsClient) HandleTCPConn(c net.Conn, remote *lb.Node) error {
clonedRemote := remote.Clone()
wsc, err := s.dialRemote(clonedRemote)
if err != nil {
return err
}
s.l.Infof("HandleTCPConn from %s to %s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(
s.relayLabel, c, wsc,
conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
s.cmgr.AddConnection(relayConn)
defer s.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
}
type WsServer struct {
*baseTransporter
type WSServer struct {
raw *RawClient
e *echo.Echo
httpServer *http.Server
l *zap.SugaredLogger
relayer RelayClient
}
func NewWSServer(listenAddr string, raw *RawClient, l *zap.SugaredLogger) *WSServer {
s := &WSServer{
l: l,
raw: raw,
httpServer: &http.Server{Addr: listenAddr, ReadHeaderTimeout: 30 * time.Second},
func newWsServer(base *baseTransporter) (*WsServer, error) {
localTCPAddr, err := base.GetTCPListenAddr()
if err != nil {
return nil, err
}
s := &WsServer{
baseTransporter: base,
httpServer: &http.Server{
Addr: localTCPAddr.String(), ReadHeaderTimeout: 30 * time.Second,
},
}
e := web.NewEchoServer()
e.GET("/", echo.WrapHandler(web.MakeIndexF()))
e.GET("/handshake/", echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
s.e = e
s.httpServer.Handler = e
return s
relayer, err := NewRelayClient(base.cfg.TransportType, base)
if err != nil {
return nil, err
}
s.relayer = relayer
return s, nil
}
func (s *WSServer) ListenAndServe() error {
func (s *WsServer) ListenAndServe() error {
return s.e.StartServer(s.httpServer)
}
func (s *WSServer) Close() error {
func (s *WsServer) Close() error {
return s.e.Close()
}
func (s *WSServer) HandleRequest(w http.ResponseWriter, req *http.Request) {
func (s *WsServer) HandleRequest(w http.ResponseWriter, req *http.Request) {
wsc, _, _, err := ws.UpgradeHTTP(req, w)
if err != nil {
return
}
remote := s.raw.GetRemote()
if err := s.raw.HandleTCPConn(wsc, remote); err != nil {
s.l.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", wsc.RemoteAddr(), remote.Address, err)
if err := s.RelayTCPConn(wsc, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}

View File

@@ -1,64 +1,38 @@
// nolint: errcheck
package transporter
import (
"context"
"net"
"time"
"github.com/gobwas/ws"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/metrics"
mytls "github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/pkg/lb"
)
type WSSClient struct {
WsClient
var (
_ RelayClient = &WssClient{}
_ RelayServer = &WssServer{}
)
type WssClient struct {
*WsClient
}
func newWSSClient(raw *RawClient) *WSSClient {
return &WSSClient{*newWsClient(raw)}
}
func (s *WSSClient) dialRemote(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
d := ws.Dialer{TLSConfig: mytls.DefaultTLSConfig}
wssc, _, _, err := d.Dial(context.TODO(), remote.Address+"/handshake/")
func newWssClient(base *baseTransporter) (*WssClient, error) {
wc, err := newWsClient(base)
if err != nil {
println("wss called", err.Error())
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return wssc, nil
// insert tls config
wc.dialer.TLSConfig = mytls.DefaultTLSConfig
return &WssClient{WsClient: wc}, nil
}
func (s *WSSClient) HandleTCPConn(c net.Conn, remote *lb.Node) error {
clonedRemote := remote.Clone()
wssc, err := s.dialRemote(clonedRemote)
type WssServer struct {
*WsServer
}
func newWssServer(base *baseTransporter) (*WssServer, error) {
wsServer, err := newWsServer(base)
if err != nil {
return err
return nil, err
}
s.l.Infof("HandleTCPConn from %s to %s", c.RemoteAddr(), remote.Address)
relayConn := conn.NewRelayConn(s.relayLabel, c, wssc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
s.cmgr.AddConnection(relayConn)
defer s.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
}
type WSSServer struct{ WSServer }
func NewWSSServer(listenAddr string, raw *RawClient, l *zap.SugaredLogger) *WSSServer {
wsServer := NewWSServer(listenAddr, raw, l)
return &WSSServer{WSServer: *wsServer}
}
func (s *WSSServer) ListenAndServe() error {
s.httpServer.TLSConfig = mytls.DefaultTLSConfig
return s.WSServer.ListenAndServe()
// insert tls config
wsServer.httpServer.TLSConfig = mytls.DefaultTLSConfig
return &WssServer{WsServer: wsServer}, nil
}

View File

@@ -1,9 +1,7 @@
// nolint: errcheck
package transporter
import (
"context"
"crypto/tls"
"net"
"net/http"
"time"
@@ -11,31 +9,34 @@ import (
"github.com/gobwas/ws"
"github.com/labstack/echo/v4"
"github.com/xtaci/smux"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/metrics"
mytls "github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/internal/web"
"github.com/Ehco1996/ehco/pkg/lb"
)
type MWSSClient struct {
*RawClient
dialer *ws.Dialer
mtp *smuxTransporter
var (
_ RelayClient = &MwssClient{}
_ RelayServer = &MwssServer{}
_ muxServer = &MwssServer{}
)
type MwssClient struct {
*WssClient
muxTP *smuxTransporter
}
func newMWSSClient(raw *RawClient) *MWSSClient {
dialer := &ws.Dialer{TLSConfig: mytls.DefaultTLSConfig, Timeout: constant.DialTimeOut}
c := &MWSSClient{dialer: dialer, RawClient: raw}
mtp := NewSmuxTransporter(raw.l.Named("mwss"), c.initNewSession)
c.mtp = mtp
return c
func newMwssClient(base *baseTransporter) (*MwssClient, error) {
wc, err := newWssClient(base)
if err != nil {
return nil, err
}
c := &MwssClient{WssClient: wc}
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
return c, nil
}
func (c *MWSSClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
func (c *MwssClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, _, _, err := c.dialer.Dial(ctx, addr)
if err != nil {
return nil, err
@@ -51,68 +52,39 @@ func (c *MWSSClient) initNewSession(ctx context.Context, addr string) (*smux.Ses
return session, nil
}
func (s *MWSSClient) dialRemote(remote *lb.Node) (net.Conn, error) {
func (s *MwssClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
mwssc, err := s.mtp.Dial(context.TODO(), remote.Address+"/handshake/")
mwssc, err := s.muxTP.Dial(context.TODO(), remote.Address+"/handshake/")
if err != nil {
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return mwssc, nil
}
func (s *MWSSClient) HandleTCPConn(c net.Conn, remote *lb.Node) error {
clonedRemote := remote.Clone()
mwsc, err := s.dialRemote(clonedRemote)
type MwssServer struct {
*WssServer
*muxServerImpl
}
func newMwssServer(base *baseTransporter) (*MwssServer, error) {
wssServer, err := newWssServer(base)
if err != nil {
return err
return nil, err
}
s.l.Infof("HandleTCPConn from:%s to:%s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(s.relayLabel, c, mwsc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
s.cmgr.AddConnection(relayConn)
defer s.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
s := &MwssServer{
WssServer: wssServer,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
}
s.e.GET("/handshake/", echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
return s, nil
}
type MWSSServer struct {
raw *RawClient
httpServer *http.Server
l *zap.SugaredLogger
connChan chan net.Conn
errChan chan error
}
func NewMWSSServer(listenAddr string, raw *RawClient, l *zap.SugaredLogger) *MWSSServer {
s := &MWSSServer{
raw: raw,
l: l,
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
}
e := web.NewEchoServer()
e.GET("/", echo.WrapHandler(web.MakeIndexF()))
e.GET("/handshake/", echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
s.httpServer = &http.Server{
Addr: listenAddr,
Handler: e,
TLSConfig: mytls.DefaultTLSConfig,
ReadHeaderTimeout: 30 * time.Second,
}
return s
}
func (s *MWSSServer) ListenAndServe() error {
lis, err := net.Listen("tcp", s.httpServer.Addr)
if err != nil {
return err
}
func (s *MwssServer) ListenAndServe() error {
go func() {
s.errChan <- s.httpServer.Serve(tls.NewListener(lis, s.httpServer.TLSConfig))
s.errChan <- s.e.StartServer(s.httpServer)
}()
for {
@@ -121,15 +93,14 @@ func (s *MWSSServer) ListenAndServe() error {
return e
}
go func(c net.Conn) {
remote := s.raw.GetRemote()
if err := s.raw.HandleTCPConn(c, remote); err != nil {
s.l.Errorf("HandleTCPConn meet error from:%s to:%s err:%s", c.RemoteAddr(), remote.Address, err)
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}
func (s *MWSSServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
func (s *MwssServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
conn, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
s.l.Error(err)
@@ -138,44 +109,6 @@ func (s *MWSSServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
s.mux(conn)
}
func (s *MWSSServer) mux(conn net.Conn) {
defer conn.Close()
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.httpServer.Addr, err)
return
}
defer session.Close()
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.httpServer.Addr)
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.httpServer.Addr)
for {
stream, err := session.AcceptStream()
if err != nil {
s.l.Errorf("accept stream err: %s", err)
break
}
select {
case s.connChan <- stream:
default:
stream.Close()
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}
func (s *MWSSServer) Accept() (conn net.Conn, err error) {
select {
case conn = <-s.connChan:
case err = <-s.errChan:
}
return
}
func (s *MWSSServer) Close() error {
return s.httpServer.Close()
func (s *MwssServer) Close() error {
return s.e.Close()
}