Update On Fri Aug 16 20:33:05 CEST 2024

This commit is contained in:
github-action[bot]
2024-08-16 20:33:05 +02:00
parent 03f9f755df
commit 7390f900fb
199 changed files with 1989 additions and 14787 deletions

View File

@@ -17,7 +17,4 @@ ehco 现在提供 SaaS软件即服务版本这是一个全托管的解
- tcp/udp relay
- tunnel relay (ws/wss/mwss/mtcp)
- proxy server (内嵌了完整版本的 xray)
- 监控报警 (Prometheus/Grafana)
- WebAPI (http://web_host:web_port)
- [更多功能请探索文档](https://docs.ehco-relay.cc/)

View File

@@ -1,36 +0,0 @@
{
"log_level": "debug",
"relay_configs": [
{
"label": "relay-to-http-success",
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["google.com:80"],
"blocked_protocols": ["tls"]
},
{
"label": "relay-to-http-fail",
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["google.com:80"],
"blocked_protocols": ["http"]
},
{
"label": "relay-to-tls-success",
"listen": "127.0.0.1:1236",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["google.com:443"]
},
{
"label": "relay-to-tls-fail",
"listen": "127.0.0.1:1237",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["google.com:443"],
"blocked_protocols": ["tls"]
}
]
}

View File

@@ -1,15 +0,0 @@
{
"web_port": 9000,
"relay_configs": [
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "ws",
"tcp_remotes": ["ws://0.0.0.0:8787"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
}
]
}

View File

@@ -9,71 +9,41 @@
"listen_type": "raw",
"transport_type": "raw",
"label": "relay1",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": ["0.0.0.0:5201"]
"tcp_remotes": [
"0.0.0.0:5201"
]
},
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "ws",
"tcp_remotes": ["ws://0.0.0.0:2443"],
"udp_remotes": ["0.0.0.0:5201"]
"tcp_remotes": [
"ws://0.0.0.0:2443"
]
},
{
"listen": "127.0.0.1:1236",
"listen_type": "raw",
"transport_type": "wss",
"tcp_remotes": ["wss://0.0.0.0:3443"],
"udp_remotes": ["0.0.0.0:5201"]
},
{
"listen": "127.0.0.1:1237",
"listen_type": "raw",
"transport_type": "mwss",
"tcp_remotes": ["wss://0.0.0.0:4443"],
"udp_remotes": ["0.0.0.0:5201"]
},
{
"listen": "127.0.0.1:1238",
"listen_type": "raw",
"transport_type": "mtcp",
"tcp_remotes": ["0.0.0.0:4444"],
"udp_remotes": ["0.0.0.0:5201"]
"tcp_remotes": [
"wss://0.0.0.0:3443"
]
},
{
"listen": "127.0.0.1:2443",
"listen_type": "ws",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": []
"tcp_remotes": [
"0.0.0.0:5201"
]
},
{
"listen": "127.0.0.1:3443",
"listen_type": "wss",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": []
},
{
"listen": "127.0.0.1:4443",
"listen_type": "mwss",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": []
},
{
"listen": "127.0.0.1:4444",
"listen_type": "mtcp",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": []
},
{
"label": "ping_test",
"listen": "127.0.0.1:8888",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["8.8.8.8:5201", "google.com:5201"]
"tcp_remotes": [
"0.0.0.0:5201"
]
}
]
}
}

View File

@@ -1,23 +0,0 @@
{
"log_level": "info",
"relay_configs": [
{
"label": "client",
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": [
"0.0.0.0:1235"
]
},
{
"label": "server",
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": [
"0.0.0.0:5201"
]
}
]
}

View File

@@ -1,24 +0,0 @@
{
"relay_configs": [
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "mws",
"tcp_remotes": ["ws://0.0.0.0:2443"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
},
{
"listen": "127.0.0.1:2443",
"listen_type": "mws",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
}
]
}

View File

@@ -1,14 +0,0 @@
{
"relay_configs": [
{
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"label": "iperf3",
"tcp_remotes": [
"0.0.0.0:5201"
],
"max_read_rate_kbps": 10000
}
]
}

View File

@@ -1,20 +0,0 @@
{
"web_port": 9000,
"log_level": "debug",
"reload_interval": 10,
"relay_configs": [
{
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"udp_remotes": ["0.0.0.0:5201"]
}
],
"sub_configs": [
{
"name": "sub1",
"url": "xxx"
}
]
}

View File

@@ -1,16 +0,0 @@
{
"web_port": 9000,
"web_token": "",
"web_auth_user": "user",
"web_auth_pass": "pass",
"log_level": "info",
"relay_configs": [
{
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"label": "iperf3",
"tcp_remotes": ["0.0.0.0:5201"]
}
]
}

View File

@@ -1,29 +0,0 @@
{
"web_port": 9000,
"log_level": "info",
"enable_ping": true,
"relay_sync_interval": 6,
"relay_configs": [
{
"listen": "127.0.0.1:1234",
"listen_type": "raw",
"transport_type": "raw",
"label": "raw",
"tcp_remotes": ["192.168.31.30:5201"]
},
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "ws",
"label": "ws",
"tcp_remotes": ["ws://192.168.31.30:2443"]
},
{
"listen": "127.0.0.1:1236",
"listen_type": "raw",
"transport_type": "wss",
"label": "wss",
"tcp_remotes": ["wss://192.168.31.31:2443"]
}
]
}

View File

@@ -1,24 +0,0 @@
{
"relay_configs": [
{
"listen": "127.0.0.1:1235",
"listen_type": "raw",
"transport_type": "ws",
"tcp_remotes": ["ws://0.0.0.0:2443"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
},
{
"listen": "127.0.0.1:2443",
"listen_type": "ws",
"transport_type": "raw",
"tcp_remotes": ["0.0.0.0:5201"],
"ws_config": {
"path": "pwd",
"remote_addr": "127.0.0.1:5201"
}
}
]
}

View File

@@ -20,10 +20,10 @@ require (
github.com/sagernet/sing-box v1.9.3
github.com/stretchr/testify v1.9.0
github.com/urfave/cli/v2 v2.27.2
github.com/xtaci/smux v1.5.24
github.com/xtls/xray-core v1.8.16
go.uber.org/atomic v1.11.0
go.uber.org/zap v1.27.0
golang.org/x/sync v0.7.0
golang.org/x/time v0.5.0
google.golang.org/grpc v1.65.0
gopkg.in/yaml.v3 v3.0.1
@@ -115,7 +115,6 @@ require (
golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/net v0.26.0 // indirect
golang.org/x/sync v0.7.0 // indirect
golang.org/x/sys v0.21.0 // indirect
golang.org/x/text v0.16.0 // indirect
golang.org/x/tools v0.22.0 // indirect

View File

@@ -317,8 +317,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk=
github.com/xtaci/smux v1.5.24 h1:77emW9dtnOxxOQ5ltR+8BbsX1kzcOxQ5gB+aaV9hXOY=
github.com/xtaci/smux v1.5.24/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
github.com/xtls/reality v0.0.0-20240429224917-ecc4401070cc h1:0Nj8T1n7F7+v4vRVroaJIvY6R0vNABLfPH+lzPHRJvI=
github.com/xtls/reality v0.0.0-20240429224917-ecc4401070cc/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
github.com/xtls/xray-core v1.8.16 h1:PhbpdREAIvDS7xmxR6Sdpkx0h5ugmf6wIoWECWtJ0kE=

View File

@@ -9,7 +9,6 @@ import (
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay"
"github.com/Ehco1996/ehco/internal/relay/conf"
"github.com/Ehco1996/ehco/internal/tls"
"github.com/Ehco1996/ehco/internal/web"
"github.com/Ehco1996/ehco/pkg/buffer"
"github.com/Ehco1996/ehco/pkg/log"
@@ -42,24 +41,11 @@ func loadConfig() (cfg *config.Config, err error) {
if TCPRemoteAddr != "" {
cfg.RelayConfigs[0].TCPRemotes = []string{TCPRemoteAddr}
}
if UDPRemoteAddr != "" {
cfg.RelayConfigs[0].UDPRemotes = []string{UDPRemoteAddr}
}
if err := cfg.Adjust(); err != nil {
return nil, err
}
}
// init tls when need
for _, cfg := range cfg.RelayConfigs {
if cfg.ListenType == constant.RelayTypeWSS || cfg.ListenType == constant.RelayTypeMWSS ||
cfg.TransportType == constant.RelayTypeWSS || cfg.TransportType == constant.RelayTypeMWSS {
if err := tls.InitTlsCfg(); err != nil {
return nil, err
}
break
}
}
return cfg, nil
}

View File

@@ -10,7 +10,6 @@ var (
LocalAddr string
ListenType constant.RelayType
TCPRemoteAddr string
UDPRemoteAddr string
TransportType constant.RelayType
ConfigPath string
WebPort int
@@ -39,16 +38,10 @@ var RootFlags = []cli.Flag{
},
&cli.StringFlag{
Name: "r,remote",
Usage: "TCP 转发地址,例如 0.0.0.0:5201通过 ws 隧道转发时应为 ws://0.0.0.0:2443",
Usage: "转发地址,例如 0.0.0.0:5201通过 ws 隧道转发时应为 ws://0.0.0.0:2443",
EnvVars: []string{"EHCO_REMOTE_ADDR"},
Destination: &TCPRemoteAddr,
},
&cli.StringFlag{
Name: "ur,udp_remote",
Usage: "UDP 转发地址,例如 0.0.0.0:1234",
EnvVars: []string{"EHCO_UDP_REMOTE_ADDR"},
Destination: &UDPRemoteAddr,
},
&cli.StringFlag{
Name: "tt,transport_type",
Value: "raw",

View File

@@ -16,7 +16,8 @@ const (
ConnectionTypeClosed = "closed"
)
// connection manager interface
// connection manager interface/
// TODO support closed connection
type Cmgr interface {
ListConnections(connType string, page, pageSize int) []conn.RelayConn

View File

@@ -7,9 +7,10 @@ import (
"strings"
"time"
myhttp "github.com/Ehco1996/ehco/pkg/http"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/relay/conf"
"github.com/Ehco1996/ehco/internal/tls"
myhttp "github.com/Ehco1996/ehco/pkg/http"
"github.com/Ehco1996/ehco/pkg/sub"
xConf "github.com/xtls/xray-core/infra/conf"
"go.uber.org/zap"
@@ -121,6 +122,15 @@ func (c *Config) Adjust() error {
}
labelMap[r.Label] = struct{}{}
}
// init tls when need
for _, r := range c.RelayConfigs {
if r.ListenType == constant.RelayTypeWSS {
if err := tls.InitTlsCfg(); err != nil {
return err
}
break
}
}
return nil
}

View File

@@ -3,18 +3,160 @@ package conn
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/pkg/buffer"
"github.com/Ehco1996/ehco/pkg/bytes"
"go.uber.org/zap"
)
var idleTimeout = 30 * time.Second
// RelayConn is the interface that represents a relay connection.
// it contains two connections: clientConn and remoteConn
// clientConn is the connection from the client to the relay server
// remoteConn is the connection from the relay server to the remote server
// and the main function is to transport data between the two connections
type RelayConn interface {
// Transport transports data between the client and the remote connection.
Transport() error
GetRelayLabel() string
GetStats() *Stats
Close() error
}
type RelayConnOption func(*relayConnImpl)
func NewRelayConn(clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn {
rci := &relayConnImpl{
clientConn: clientConn,
remoteConn: remoteConn,
Stats: &Stats{},
}
for _, opt := range opts {
opt(rci)
}
if rci.l == nil {
rci.l = zap.S().Named(rci.RelayLabel)
}
return rci
}
type relayConnImpl struct {
clientConn net.Conn
remoteConn net.Conn
Closed bool `json:"closed"`
Stats *Stats `json:"stats"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time,omitempty"`
// options set those fields
l *zap.SugaredLogger
remote *lb.Node
HandshakeDuration time.Duration
RelayLabel string `json:"relay_label"`
ConnType string `json:"conn_type"`
}
func WithRelayLabel(relayLabel string) RelayConnOption {
return func(rci *relayConnImpl) {
rci.RelayLabel = relayLabel
}
}
func WithHandshakeDuration(duration time.Duration) RelayConnOption {
return func(rci *relayConnImpl) {
rci.HandshakeDuration = duration
}
}
func WithConnType(connType string) RelayConnOption {
return func(rci *relayConnImpl) {
rci.ConnType = connType
}
}
func WithRemote(remote *lb.Node) RelayConnOption {
return func(rci *relayConnImpl) {
rci.remote = remote
}
}
func WithLogger(l *zap.SugaredLogger) RelayConnOption {
return func(rci *relayConnImpl) {
rci.l = l
}
}
func (rc *relayConnImpl) Transport() error {
defer rc.Close() // nolint: errcheck
cl := rc.l.Named(shortHashSHA256(rc.GetFlow()))
cl.Debugf("transport start, stats: %s", rc.Stats.String())
c1 := newInnerConn(rc.clientConn, rc)
c2 := newInnerConn(rc.remoteConn, rc)
rc.StartTime = time.Now().Local()
err := copyConn(c1, c2)
if err != nil {
cl.Errorf("transport error: %s", err.Error())
}
cl.Debugf("transport end, stats: %s", rc.Stats.String())
rc.EndTime = time.Now().Local()
return err
}
func (rc *relayConnImpl) Close() error {
err1 := rc.clientConn.Close()
err2 := rc.remoteConn.Close()
rc.Closed = true
return combineErrorsAndMuteEOF(err1, err2)
}
// functions that for web ui
func (rc *relayConnImpl) GetTime() string {
if rc.EndTime.IsZero() {
return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp))
}
return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp))
}
func (rc *relayConnImpl) GetFlow() string {
return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr())
}
func (rc *relayConnImpl) GetRelayLabel() string {
return rc.RelayLabel
}
func (rc *relayConnImpl) GetStats() *Stats {
return rc.Stats
}
func (rc *relayConnImpl) GetConnType() string {
return rc.ConnType
}
func combineErrorsAndMuteEOF(err1, err2 error) error {
if err1 == io.EOF {
err1 = nil
}
if err2 == io.EOF {
return nil
}
if err1 != nil && err2 != nil {
return errors.Join(err1, err2)
}
if err1 != nil {
return err1
}
return err2
}
type Stats struct {
Up int64
@@ -35,52 +177,69 @@ func (s *Stats) String() string {
)
}
// note that innerConn is a wrapper around net.Conn to allow io.Copy to be used
type innerConn struct {
net.Conn
lastActive time.Time
remoteLabel string
stats *Stats
rc *relayConnImpl
}
func (c *innerConn) setDeadline(isRead bool) {
// set the read deadline to avoid hanging read for non-TCP connections
// because tcp connections have closeWrite/closeRead so no need to set read deadline
if _, ok := c.Conn.(*net.TCPConn); !ok {
deadline := time.Now().Add(idleTimeout)
if isRead {
_ = c.Conn.SetReadDeadline(deadline)
func newInnerConn(conn net.Conn, rc *relayConnImpl) *innerConn {
return &innerConn{Conn: conn, rc: rc, lastActive: time.Now()}
}
func (c *innerConn) recordStats(n int, isRead bool) {
if c.rc == nil {
return
}
if isRead {
metrics.NetWorkTransmitBytes.WithLabelValues(
c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ,
).Add(float64(n))
c.rc.Stats.Record(0, int64(n))
} else {
metrics.NetWorkTransmitBytes.WithLabelValues(
c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE,
).Add(float64(n))
c.rc.Stats.Record(int64(n), 0)
}
}
func (c *innerConn) Read(p []byte) (n int, err error) {
for {
deadline := time.Now().Add(constant.ReadTimeOut)
if err := c.Conn.SetReadDeadline(deadline); err != nil {
return 0, err
}
n, err = c.Conn.Read(p)
if err == nil {
c.recordStats(n, true)
c.lastActive = time.Now()
return
} else {
_ = c.Conn.SetWriteDeadline(deadline)
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
if time.Since(c.lastActive) > constant.IdleTimeOut {
c.rc.l.Debugf("read idle,close remote: %s", c.rc.remote.Label)
return 0, io.EOF
}
continue
}
return n, err
}
}
}
func (c *innerConn) recordStats(n int, isRead bool) {
if isRead {
metrics.NetWorkTransmitBytes.WithLabelValues(
c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ,
).Add(float64(n))
c.stats.Record(0, int64(n))
} else {
metrics.NetWorkTransmitBytes.WithLabelValues(
c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE,
).Add(float64(n))
c.stats.Record(int64(n), 0)
}
}
// 修改Read和Write方法以使用recordStats
func (c *innerConn) Read(p []byte) (n int, err error) {
c.setDeadline(true)
n, err = c.Conn.Read(p)
c.recordStats(n, true) // true for read operation
return
}
func (c *innerConn) Write(p []byte) (n int, err error) {
c.setDeadline(false)
if time.Since(c.lastActive) > constant.IdleTimeOut {
c.rc.l.Debugf("write idle,close remote: %s", c.rc.remote.Label)
return 0, io.EOF
}
n, err = c.Conn.Write(p)
c.recordStats(n, false) // false for write operation
if err != nil {
c.lastActive = time.Now()
}
return
}
@@ -109,10 +268,6 @@ func shortHashSHA256(input string) string {
return hex.EncodeToString(hash)[:7]
}
func connectionName(conn net.Conn) string {
return fmt.Sprintf("l:<%s> r:<%s>", conn.LocalAddr(), conn.RemoteAddr())
}
func copyConn(conn1, conn2 *innerConn) error {
buf := buffer.BufferPool.Get()
defer buffer.BufferPool.Put(buf)
@@ -125,134 +280,13 @@ func copyConn(conn1, conn2 *innerConn) error {
}()
// reverse copy conn2 to conn1,read from conn2 and write to conn1
_, err := io.Copy(conn1, conn2)
buf2 := buffer.BufferPool.Get()
defer buffer.BufferPool.Put(buf2)
_, err := io.CopyBuffer(conn1, conn2, buf2)
_ = conn1.CloseWrite()
err2 := <-errCH
_ = conn1.CloseRead()
_ = conn2.CloseRead()
// handle errors, need to combine errors from both directions
if err != nil && err2 != nil {
err = fmt.Errorf("transport errors in both directions: %v, %v", err, err2)
}
if err != nil {
return err
}
return err2
}
type RelayConnOption func(*relayConnImpl)
type RelayConn interface {
// Transport transports data between the client and the remote server.
// The remoteLabel is the label of the remote server.
Transport(remoteLabel string) error
// GetRelayLabel returns the label of the Relay instance.
GetRelayLabel() string
GetStats() *Stats
Close() error
}
func NewRelayConn(relayName string, clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn {
rci := &relayConnImpl{
RelayLabel: relayName,
clientConn: clientConn,
remoteConn: remoteConn,
}
for _, opt := range opts {
opt(rci)
}
s := &Stats{Up: 0, Down: 0, HandShakeLatency: rci.HandshakeDuration}
rci.Stats = s
return rci
}
type relayConnImpl struct {
RelayLabel string `json:"relay_label"`
Closed bool `json:"closed"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time,omitempty"`
Stats *Stats `json:"stats"`
HandshakeDuration time.Duration
clientConn net.Conn
remoteConn net.Conn
}
func WithHandshakeDuration(duration time.Duration) RelayConnOption {
return func(rci *relayConnImpl) {
rci.HandshakeDuration = duration
}
}
func (rc *relayConnImpl) Transport(remoteLabel string) error {
defer rc.Close() // nolint: errcheck
name := rc.Name()
shortName := fmt.Sprintf("%s-%s", rc.RelayLabel, shortHashSHA256(name))
cl := zap.L().Named(shortName)
cl.Debug("transport start", zap.String("full name", name), zap.String("stats", rc.Stats.String()))
c1 := &innerConn{
stats: rc.Stats,
remoteLabel: remoteLabel,
Conn: rc.clientConn,
}
c2 := &innerConn{
stats: rc.Stats,
remoteLabel: remoteLabel,
Conn: rc.remoteConn,
}
rc.StartTime = time.Now().Local()
err := copyConn(c1, c2)
if err != nil {
cl.Error("transport error", zap.Error(err))
}
cl.Debug("transport end", zap.String("stats", rc.Stats.String()))
rc.EndTime = time.Now().Local()
return err
}
func (rc *relayConnImpl) GetTime() string {
if rc.EndTime.IsZero() {
return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp))
}
return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp))
}
func (rc *relayConnImpl) Name() string {
return fmt.Sprintf("c1:[%s] c2:[%s]", connectionName(rc.clientConn), connectionName(rc.remoteConn))
}
func (rc *relayConnImpl) Flow() string {
return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr())
}
func (rc *relayConnImpl) GetRelayLabel() string {
return rc.RelayLabel
}
func (rc *relayConnImpl) GetStats() *Stats {
return rc.Stats
}
func (rc *relayConnImpl) Close() error {
err1 := rc.clientConn.Close()
err2 := rc.remoteConn.Close()
rc.Closed = true
return combineErrors(err1, err2)
}
func combineErrors(err1, err2 error) error {
if err1 != nil && err2 != nil {
return fmt.Errorf("combineErrors: %v, %v", err1, err2)
}
if err1 != nil {
return err1
}
return err2
return combineErrorsAndMuteEOF(err, err2)
}

View File

@@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/stretchr/testify/assert"
)
@@ -18,11 +19,9 @@ func TestInnerConn_ReadWrite(t *testing.T) {
serverConn.SetDeadline(time.Now().Add(1 * time.Second))
defer clientConn.Close()
defer serverConn.Close()
innerC := &innerConn{Conn: clientConn, stats: &Stats{}, remoteLabel: "test"}
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
innerC := newInnerConn(clientConn, &rc)
errChan := make(chan error, 1)
go func() {
_, err := innerC.Write(testData)
errChan <- err
@@ -39,7 +38,7 @@ func TestInnerConn_ReadWrite(t *testing.T) {
if err := <-errChan; err != nil {
t.Fatalf("write err: %v", err)
}
assert.Equal(t, int64(len(testData)), innerC.stats.Up)
assert.Equal(t, int64(len(testData)), rc.Stats.Up)
errChan = make(chan error, 1)
clientConn.SetDeadline(time.Now().Add(1 * time.Second))
@@ -64,7 +63,7 @@ func TestInnerConn_ReadWrite(t *testing.T) {
if err := <-errChan; err != nil {
t.Fatalf("write error: %v", err)
}
assert.Equal(t, int64(len(testData)), innerC.stats.Down)
assert.Equal(t, int64(len(testData)), rc.Stats.Down)
}
func TestCopyTCPConn(t *testing.T) {
@@ -96,8 +95,9 @@ func TestCopyTCPConn(t *testing.T) {
assert.NoError(t, err)
defer remoteConn.Close()
c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}}
c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}}
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
c1 := newInnerConn(clientConn, &rc)
c2 := newInnerConn(remoteConn, &rc)
done := make(chan struct{})
go func() {
@@ -155,8 +155,9 @@ func TestCopyUDPConn(t *testing.T) {
assert.NoError(t, err)
defer remoteConn.Close()
c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}}
c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}}
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
c1 := newInnerConn(clientConn, &rc)
c2 := newInnerConn(remoteConn, &rc)
done := make(chan struct{})
go func() {

View File

@@ -0,0 +1,187 @@
//nolint:errcheck
package conn
import (
"context"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/pkg/buffer"
)
var _ net.Conn = &uc{}
type uc struct {
conn *net.UDPConn
addr *net.UDPAddr
msgCh chan []byte
lastActivity atomic.Value
listener *UDPListener
}
func (c *uc) Read(b []byte) (int, error) {
select {
case msg := <-c.msgCh:
n := copy(b, msg)
c.lastActivity.Store(time.Now())
return n, nil
default:
if time.Since(c.lastActivity.Load().(time.Time)) > constant.IdleTimeOut {
return 0, io.EOF
}
return 0, nil
}
}
func (c *uc) Write(b []byte) (int, error) {
n, err := c.conn.WriteToUDP(b, c.addr)
c.lastActivity.Store(time.Now())
return n, err
}
func (c *uc) Close() error {
c.listener.connsMu.Lock()
delete(c.listener.conns, c.addr.String())
c.listener.connsMu.Unlock()
close(c.msgCh)
return nil
}
func (c *uc) LocalAddr() net.Addr {
return c.conn.LocalAddr()
}
func (c *uc) RemoteAddr() net.Addr {
return c.addr
}
func (c *uc) SetDeadline(t time.Time) error {
return nil
}
func (c *uc) SetReadDeadline(t time.Time) error {
return nil
}
func (c *uc) SetWriteDeadline(t time.Time) error {
return nil
}
type UDPListener struct {
listenAddr *net.UDPAddr
listenConn *net.UDPConn
conns map[string]*uc
connsMu sync.RWMutex
connCh chan *uc
msgCh chan []byte
errCh chan error
ctx context.Context
cancel context.CancelFunc
closed atomic.Bool
}
func NewUDPListener(ctx context.Context, addr string) (*UDPListener, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
ctx, cancel := context.WithCancel(ctx)
l := &UDPListener{
listenConn: conn,
listenAddr: udpAddr,
conns: make(map[string]*uc),
connCh: make(chan *uc),
msgCh: make(chan []byte),
errCh: make(chan error),
ctx: ctx,
cancel: cancel,
}
go l.listen()
return l, nil
}
func (l *UDPListener) listen() {
defer l.listenConn.Close()
for {
if l.closed.Load() {
return
}
buf := buffer.UDPBufferPool.Get()
n, addr, err := l.listenConn.ReadFromUDP(buf)
if err != nil {
if !l.closed.Load() {
select {
case l.errCh <- err:
default:
}
}
buffer.UDPBufferPool.Put(buf)
continue
}
l.connsMu.RLock()
udpConn, exists := l.conns[addr.String()]
l.connsMu.RUnlock()
if !exists {
l.connsMu.Lock()
udpConn = &uc{
conn: l.listenConn,
addr: addr,
listener: l,
msgCh: make(chan []byte, 10),
lastActivity: atomic.Value{},
}
udpConn.lastActivity.Store(time.Now())
l.conns[addr.String()] = udpConn
l.connCh <- udpConn
l.connsMu.Unlock()
}
select {
case udpConn.msgCh <- buf[:n]:
default:
buffer.UDPBufferPool.Put(buf)
}
}
}
func (l *UDPListener) Accept() (*uc, error) {
select {
case conn := <-l.connCh:
return conn, nil
case err := <-l.errCh:
return nil, err
case <-l.ctx.Done():
return nil, l.ctx.Err()
}
}
func (l *UDPListener) Close() error {
if !l.closed.CompareAndSwap(false, true) {
return nil
}
l.cancel()
l.closed.Store(true)
return l.listenConn.Close()
}

View File

@@ -12,6 +12,7 @@ import (
"go.uber.org/zap"
)
// wsConn represents a WebSocket connection to relay(io.Copy)
type wsConn struct {
conn net.Conn
isServer bool
@@ -29,7 +30,7 @@ func (c *wsConn) Read(b []byte) (n int, err error) {
}
if header.Length > int64(cap(c.buf)) {
zap.S().Warnf("ws payload size:%d is larger than buffer size:%d", header.Length, cap(c.buf))
c.buf = make([]byte, header.Length)
return 0, fmt.Errorf("buffer size:%d too small to transport ws payload size:%d", len(b), header.Length)
}
payload := c.buf[:header.Length]
_, err = io.ReadFull(c.conn, payload)

View File

@@ -6,7 +6,9 @@ type RelayType string
var (
// allow change in test
IdleTimeOut = 10 * time.Second
// TODO Set to Relay Config
ReadTimeOut = 5 * time.Second
IdleTimeOut = 30 * time.Second
Version = "1.1.5-dev"
GitBranch string
@@ -20,24 +22,17 @@ const (
SniffTimeOut = 300 * time.Millisecond
SmuxGCDuration = 30 * time.Second
SmuxMaxAliveDuration = 10 * time.Minute
SmuxMaxStreamCnt = 5
// todo add udp buffer size
// todo,support config in relay config
BUFFER_POOL_SIZE = 1024 // support 512 connections
BUFFER_SIZE = 20 * 1024 // 20KB the maximum packet size of shadowsocks is about 16 KiB
BUFFER_SIZE = 40 * 1024 // 40KB ,the maximum packet size of shadowsocks is about 16 KiB so this is enough
UDPBufSize = 1500 // use default max mtu 1500
)
// relay type
const (
// tcp relay
RelayTypeRaw RelayType = "raw"
RelayTypeMTCP RelayType = "mtcp"
// direct relay
RelayTypeRaw RelayType = "raw"
// ws relay
RelayTypeWS RelayType = "ws"
RelayTypeMWS RelayType = "mws"
RelayTypeWSS RelayType = "wss"
RelayTypeMWSS RelayType = "mwss"
RelayTypeWS RelayType = "ws"
RelayTypeWSS RelayType = "wss"
)

View File

@@ -23,24 +23,51 @@ type WSConfig struct {
RemoteAddr string `json:"remote_addr,omitempty"`
}
func (w *WSConfig) Clone() *WSConfig {
return &WSConfig{
Path: w.Path,
RemoteAddr: w.RemoteAddr,
}
}
type Options struct {
WSConfig *WSConfig `json:"ws_config,omitempty"`
EnableUDP bool `json:"enable_udp,omitempty"`
EnableMultipathTCP bool `json:"enable_multipath_tcp,omitempty"`
MaxConnection int `json:"max_connection,omitempty"`
BlockedProtocols []string `json:"blocked_protocols,omitempty"`
MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"`
}
func (o *Options) Clone() *Options {
opt := &Options{
EnableUDP: o.EnableUDP,
EnableMultipathTCP: o.EnableMultipathTCP,
MaxConnection: o.MaxConnection,
MaxReadRateKbps: o.MaxReadRateKbps,
BlockedProtocols: make([]string, len(o.BlockedProtocols)),
}
copy(opt.BlockedProtocols, o.BlockedProtocols)
if o.WSConfig != nil {
opt.WSConfig = o.WSConfig.Clone()
}
return opt
}
type Config struct {
Label string `json:"label,omitempty"`
Listen string `json:"listen"`
ListenType constant.RelayType `json:"listen_type"`
TransportType constant.RelayType `json:"transport_type"`
TCPRemotes []string `json:"tcp_remotes"`
UDPRemotes []string `json:"udp_remotes"`
TCPRemotes []string `json:"tcp_remotes"` // TODO rename to remotes
MaxConnection int `json:"max_connection,omitempty"`
BlockedProtocols []string `json:"blocked_protocols,omitempty"`
MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"`
WSConfig *WSConfig `json:"ws_config,omitempty"`
Options *Options `json:"options,omitempty"`
}
func (r *Config) GetWSHandShakePath() string {
if r.WSConfig != nil && r.WSConfig.Path != "" {
return r.WSConfig.Path
if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.Path != "" {
return r.Options.WSConfig.Path
}
return WS_HANDSHAKE_PATH
}
@@ -50,8 +77,8 @@ func (r *Config) GetWSRemoteAddr(baseAddr string) (string, error) {
if err != nil {
return "", err
}
if r.WSConfig != nil && r.WSConfig.RemoteAddr != "" {
addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.WSConfig.RemoteAddr)
if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.RemoteAddr != "" {
addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.Options.WSConfig.RemoteAddr)
}
return addr, nil
}
@@ -79,17 +106,7 @@ func (r *Config) Validate() error {
}
}
for _, addr := range r.UDPRemotes {
if addr == "" {
return fmt.Errorf("invalid udp remote addr:%s", addr)
}
}
if len(r.TCPRemotes) == 0 && len(r.UDPRemotes) == 0 {
return errors.New("both tcp and udp remotes are empty")
}
for _, protocol := range r.BlockedProtocols {
for _, protocol := range r.Options.BlockedProtocols {
if protocol != ProtocolHTTP && protocol != ProtocolTLS {
return fmt.Errorf("invalid blocked protocol:%s", protocol)
}
@@ -103,11 +120,10 @@ func (r *Config) Clone() *Config {
ListenType: r.ListenType,
TransportType: r.TransportType,
Label: r.Label,
Options: r.Options.Clone(),
}
new.TCPRemotes = make([]string, len(r.TCPRemotes))
copy(new.TCPRemotes, r.TCPRemotes)
new.UDPRemotes = make([]string, len(r.UDPRemotes))
copy(new.UDPRemotes, r.UDPRemotes)
return new
}
@@ -121,26 +137,17 @@ func (r *Config) Different(new *Config) bool {
if len(r.TCPRemotes) != len(new.TCPRemotes) {
return true
}
for i, addr := range r.TCPRemotes {
if addr != new.TCPRemotes[i] {
return true
}
}
if len(r.UDPRemotes) != len(new.UDPRemotes) {
return true
}
for i, addr := range r.UDPRemotes {
if addr != new.UDPRemotes[i] {
return true
}
}
return false
}
// todo make this shorter and more readable
func (r *Config) DefaultLabel() string {
defaultLabel := fmt.Sprintf("<At=%s TCP-To=%s TP=%s>",
defaultLabel := fmt.Sprintf("<At=%s To=%s TP=%s>",
r.Listen, r.TCPRemotes, r.TransportType)
return defaultLabel
}
@@ -150,6 +157,12 @@ func (r *Config) Adjust() error {
r.Label = r.DefaultLabel()
zap.S().Debugf("label is empty, set default label:%s", r.Label)
}
if r.Options == nil {
r.Options = &Options{
WSConfig: &WSConfig{},
EnableMultipathTCP: true, // default enable multipath tcp
}
}
return nil
}
@@ -171,20 +184,14 @@ func (r *Config) GetLoggerName() string {
func (r *Config) validateType() error {
if r.ListenType != constant.RelayTypeRaw &&
r.ListenType != constant.RelayTypeWS &&
r.ListenType != constant.RelayTypeMWS &&
r.ListenType != constant.RelayTypeWSS &&
r.ListenType != constant.RelayTypeMTCP &&
r.ListenType != constant.RelayTypeMWSS {
r.ListenType != constant.RelayTypeWSS {
return fmt.Errorf("invalid listen type:%s", r.ListenType)
}
if r.TransportType != constant.RelayTypeRaw &&
r.TransportType != constant.RelayTypeWS &&
r.TransportType != constant.RelayTypeMWS &&
r.TransportType != constant.RelayTypeWSS &&
r.TransportType != constant.RelayTypeMTCP &&
r.TransportType != constant.RelayTypeMWSS {
return fmt.Errorf("invalid transport type:%s", r.ListenType)
r.TransportType != constant.RelayTypeWSS {
return fmt.Errorf("invalid transport type:%s", r.TransportType)
}
return nil
}

View File

@@ -1,6 +1,8 @@
package relay
import (
"context"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/cmgr"
@@ -33,17 +35,17 @@ func NewRelay(cfg *conf.Config, cmgr cmgr.Cmgr) (*Relay, error) {
return r, nil
}
func (r *Relay) ListenAndServe() error {
func (r *Relay) ListenAndServe(ctx context.Context) error {
errCh := make(chan error)
go func() {
r.l.Infof("Start TCP Relay Server:%s", r.cfg.DefaultLabel())
errCh <- r.relayServer.ListenAndServe()
r.l.Infof("Start Relay Server(%s):%s", r.cfg.ListenType, r.cfg.DefaultLabel())
errCh <- r.relayServer.ListenAndServe(ctx)
}()
return <-errCh
}
func (r *Relay) Close() {
r.l.Infof("Close TCP Relay Server:%s", r.cfg.DefaultLabel())
r.l.Infof("Close Relay Server:%s", r.cfg.DefaultLabel())
if err := r.relayServer.Close(); err != nil {
r.l.Errorf(err.Error())
}

View File

@@ -46,10 +46,10 @@ func NewServer(cfg *config.Config) (*Server, error) {
return s, nil
}
func (s *Server) startOneRelay(r *Relay) {
func (s *Server) startOneRelay(ctx context.Context, r *Relay) {
s.relayM.Store(r.UniqueID(), r)
// mute closed network error for tcp server and mute http.ErrServerClosed for http server when config reload
if err := r.ListenAndServe(); err != nil &&
if err := r.ListenAndServe(ctx); err != nil &&
!errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
s.l.Errorf("start relay %s meet error: %s", r.UniqueID(), err)
s.errCH <- err
@@ -68,7 +68,7 @@ func (s *Server) Start(ctx context.Context) error {
if err != nil {
return err
}
go s.startOneRelay(r)
go s.startOneRelay(ctx, r)
}
if s.cfg.PATH != "" && (s.cfg.ReloadInterval > 0 || len(s.cfg.SubConfigs) > 0) {

View File

@@ -1,6 +1,8 @@
package relay
import (
"context"
"github.com/Ehco1996/ehco/internal/glue"
"github.com/Ehco1996/ehco/internal/relay/conf"
"go.uber.org/zap"
@@ -48,7 +50,7 @@ func (s *Server) Reload(force bool) error {
s.l.Error("new relay meet error", zap.Error(err))
continue
}
go s.startOneRelay(r)
go s.startOneRelay(context.TODO(), r)
} else {
// when label not change, check if config changed
oldCfg, ok := oldRelayCfgM[newCfg.Label]
@@ -66,7 +68,7 @@ func (s *Server) Reload(force bool) error {
s.l.Error("new relay meet error", zap.Error(err))
continue
}
go s.startOneRelay(r)
go s.startOneRelay(context.TODO(), r)
}
}
}

View File

@@ -18,97 +18,160 @@ import (
"go.uber.org/zap"
)
type baseTransporter struct {
cfg *conf.Config
l *zap.SugaredLogger
var _ RelayServer = &BaseRelayServer{}
cmgr cmgr.Cmgr
tCPRemotes lb.RoundRobin
relayer RelayClient
type BaseRelayServer struct {
cmgr cmgr.Cmgr
cfg *conf.Config
l *zap.SugaredLogger
remotes lb.RoundRobin
relayer RelayClient
}
func NewBaseTransporter(cfg *conf.Config, cmgr cmgr.Cmgr) (*baseTransporter, error) {
func newBaseRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (*BaseRelayServer, error) {
relayer, err := newRelayClient(cfg)
if err != nil {
return nil, err
}
return &baseTransporter{
cfg: cfg,
cmgr: cmgr,
tCPRemotes: cfg.ToTCPRemotes(),
l: zap.S().Named(cfg.GetLoggerName()),
relayer: relayer,
return &BaseRelayServer{
relayer: relayer,
cfg: cfg,
cmgr: cmgr,
remotes: cfg.ToTCPRemotes(),
l: zap.S().Named(cfg.GetLoggerName()),
}, nil
}
func (b *baseTransporter) GetTCPListenAddr() (*net.TCPAddr, error) {
return net.ResolveTCPAddr("tcp", b.cfg.Listen)
}
func (b *baseTransporter) GetRemote() *lb.Node {
return b.tCPRemotes.Next()
}
func (b *baseTransporter) RelayTCPConn(c net.Conn, handshakeF TCPHandShakeF) error {
remote := b.GetRemote()
func (b *BaseRelayServer) RelayTCPConn(ctx context.Context, c net.Conn) error {
remote := b.remotes.Next().Clone()
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc()
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec()
// check limit
if b.cfg.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.MaxConnection {
c.Close()
return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.MaxConnection)
if err := b.checkConnectionLimit(); err != nil {
return err
}
// sniff protocol
if len(b.cfg.BlockedProtocols) > 0 {
buffer := buf.NewPacket()
ctx := context.TODO()
sniffMetadata, err := sniff.PeekStream(
ctx, c, buffer, constant.SniffTimeOut,
sniff.TLSClientHello, sniff.HTTPHost)
if err != nil {
// this mean no protocol sniffed
b.l.Debug("sniff error: %s", err)
}
if sniffMetadata != nil {
b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol)
for _, p := range b.cfg.BlockedProtocols {
if sniffMetadata.Protocol == p {
c.Close()
return fmt.Errorf("relay:%s want to relay blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol)
}
}
}
if !buffer.IsEmpty() {
c = bufio.NewCachedConn(c, buffer)
} else {
buffer.Release()
}
}
// rate limit
if b.cfg.MaxReadRateKbps > 0 {
c = conn.NewRateLimitedConn(c, b.cfg.MaxReadRateKbps)
}
clonedRemote := remote.Clone()
rc, err := handshakeF(clonedRemote)
var err error
c, err = b.sniffAndBlockProtocol(c)
if err != nil {
return err
}
c = b.applyRateLimit(c)
rc, err := b.relayer.HandShake(ctx, remote, true)
if err != nil {
return fmt.Errorf("handshake error: %w", err)
}
defer rc.Close()
b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address)
relayConn := conn.NewRelayConn(
b.cfg.Label, c, rc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
b.cmgr.AddConnection(relayConn)
defer b.cmgr.RemoveConnection(relayConn)
return relayConn.Transport(remote.Label)
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_TCP)
}
func (b *baseTransporter) HealthCheck(ctx context.Context) (int64, error) {
remote := b.GetRemote().Clone()
err := b.relayer.HealthCheck(ctx, remote)
func (b *BaseRelayServer) RelayUDPConn(ctx context.Context, c net.Conn) error {
remote := b.remotes.Next().Clone()
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Inc()
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Dec()
rc, err := b.relayer.HandShake(ctx, remote, false)
if err != nil {
return fmt.Errorf("handshake error: %w", err)
}
defer rc.Close()
b.l.Infof("RelayUDPConn from %s to %s", c.LocalAddr(), remote.Address)
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_UDP)
}
func (b *BaseRelayServer) checkConnectionLimit() error {
if b.cfg.Options.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.Options.MaxConnection {
return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.Options.MaxConnection)
}
return nil
}
func (b *BaseRelayServer) sniffAndBlockProtocol(c net.Conn) (net.Conn, error) {
if len(b.cfg.Options.BlockedProtocols) == 0 {
return c, nil
}
buffer := buf.NewPacket()
ctx, cancel := context.WithTimeout(context.Background(), constant.SniffTimeOut)
defer cancel()
sniffMetadata, err := sniff.PeekStream(ctx, c, buffer, constant.SniffTimeOut, sniff.TLSClientHello, sniff.HTTPHost)
if err != nil {
b.l.Debugf("sniff error: %s", err)
return c, nil
}
if sniffMetadata != nil {
b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol)
for _, p := range b.cfg.Options.BlockedProtocols {
if sniffMetadata.Protocol == p {
return c, fmt.Errorf("relay:%s blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol)
}
}
}
if !buffer.IsEmpty() {
return bufio.NewCachedConn(c, buffer), nil
} else {
buffer.Release()
}
return c, nil
}
func (b *BaseRelayServer) applyRateLimit(c net.Conn) net.Conn {
if b.cfg.Options.MaxReadRateKbps > 0 {
return conn.NewRateLimitedConn(c, b.cfg.Options.MaxReadRateKbps)
}
return c
}
func (b *BaseRelayServer) handleRelayConn(c, rc net.Conn, remote *lb.Node, connType string) error {
opts := []conn.RelayConnOption{
conn.WithLogger(b.l),
conn.WithRemote(remote),
conn.WithConnType(connType),
conn.WithRelayLabel(b.cfg.Label),
conn.WithHandshakeDuration(remote.HandShakeDuration),
}
relayConn := conn.NewRelayConn(c, rc, opts...)
b.cmgr.AddConnection(relayConn)
defer b.cmgr.RemoveConnection(relayConn)
return relayConn.Transport()
}
func (b *BaseRelayServer) HealthCheck(ctx context.Context) (int64, error) {
remote := b.remotes.Next().Clone()
// us tcp handshake to check health
_, err := b.relayer.HandShake(ctx, remote, true)
return int64(remote.HandShakeDuration.Milliseconds()), err
}
func (b *BaseRelayServer) Close() error {
return fmt.Errorf("not implemented")
}
func (b *BaseRelayServer) ListenAndServe(ctx context.Context) error {
return fmt.Errorf("not implemented")
}
func NewNetDialer(cfg *conf.Config) *net.Dialer {
dialer := &net.Dialer{Timeout: constant.DialTimeOut}
dialer.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
return dialer
}
func NewTCPListener(ctx context.Context, cfg *conf.Config) (net.Listener, error) {
addr, err := net.ResolveTCPAddr("tcp", cfg.Listen)
if err != nil {
return nil, err
}
lcfg := net.ListenConfig{}
lcfg.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
return lcfg.Listen(ctx, "tcp", addr.String())
}

View File

@@ -11,56 +11,45 @@ import (
"github.com/Ehco1996/ehco/internal/relay/conf"
)
type TCPHandShakeF func(remote *lb.Node) (net.Conn, error)
// TODO opt this interface
type RelayClient interface {
HealthCheck(ctx context.Context, remote *lb.Node) error
TCPHandShake(remote *lb.Node) (net.Conn, error)
HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error)
}
func newRelayClient(cfg *conf.Config) (RelayClient, error) {
switch cfg.TransportType {
case constant.RelayTypeRaw:
return newRawClient(cfg)
case constant.RelayTypeMTCP:
return newMtcpClient(cfg)
case constant.RelayTypeWS:
return newWsClient(cfg)
case constant.RelayTypeMWS:
return newMwsClient(cfg)
case constant.RelayTypeWSS:
return newWssClient(cfg)
case constant.RelayTypeMWSS:
return newMwssClient(cfg)
default:
return nil, fmt.Errorf("unsupported transport type: %s", cfg.TransportType)
}
}
type RelayServer interface {
ListenAndServe() error
ListenAndServe(ctx context.Context) error
Close() error
RelayTCPConn(ctx context.Context, c net.Conn) error
RelayUDPConn(ctx context.Context, c net.Conn) error
HealthCheck(ctx context.Context) (int64, error) // latency in ms
}
func NewRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (RelayServer, error) {
base, err := NewBaseTransporter(cfg, cmgr)
base, err := newBaseRelayServer(cfg, cmgr)
if err != nil {
return nil, err
}
switch cfg.ListenType {
case constant.RelayTypeRaw:
return newRawServer(base)
case constant.RelayTypeMTCP:
return newMtcpServer(base)
case constant.RelayTypeWS:
return newWsServer(base)
case constant.RelayTypeMWS:
return newMwsServer(base)
case constant.RelayTypeWSS:
return newWssServer(base)
case constant.RelayTypeMWSS:
return newMwssServer(base)
default:
panic("unsupported transport type" + cfg.ListenType)
}

View File

@@ -1,201 +0,0 @@
// nolint: errcheck
package transporter
import (
"context"
"net"
"sync"
"time"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/xtaci/smux"
"go.uber.org/zap"
)
type smuxTransporter struct {
sessionMutex sync.Mutex
gcTicker *time.Ticker
l *zap.SugaredLogger
// remote addr -> SessionWithMetrics
sessionM map[string][]*SessionWithMetrics
initSessionF func(ctx context.Context, addr string) (*smux.Session, error)
}
type SessionWithMetrics struct {
session *smux.Session
createdTime time.Time
streamList []*smux.Stream
}
func (sm *SessionWithMetrics) CanNotServeNewStream() bool {
return sm.session.IsClosed() ||
sm.session.NumStreams() >= constant.SmuxMaxStreamCnt ||
time.Since(sm.createdTime) > constant.SmuxMaxAliveDuration
}
func streamDead(s *smux.Stream) bool {
select {
case _, ok := <-s.GetDieCh():
return !ok // 如果接收到值且通道未关闭,则 Stream 未死
default:
return true // 如果通道已经关闭,则 Stream 死了
}
}
func (sm *SessionWithMetrics) canCloseSession(remoteAddr string, l *zap.SugaredLogger) bool {
for _, s := range sm.streamList {
if !streamDead(s) {
return false
}
l.Debugf("session: %s stream: %d is not dead", remoteAddr, s.ID())
}
return true
}
func NewSmuxTransporter(
l *zap.SugaredLogger,
initSessionF func(ctx context.Context, addr string) (*smux.Session, error),
) *smuxTransporter {
tr := &smuxTransporter{
l: l,
initSessionF: initSessionF,
sessionM: make(map[string][]*SessionWithMetrics),
gcTicker: time.NewTicker(constant.SmuxGCDuration),
}
// start gc thread for close idle sessions
go tr.gc()
return tr
}
func (tr *smuxTransporter) gc() {
for range tr.gcTicker.C {
tr.sessionMutex.Lock()
for addr, sl := range tr.sessionM {
tr.l.Debugf("start doing gc for remote addr: %s total session count %d", addr, len(sl))
for idx := range sl {
sm := sl[idx]
if sm.CanNotServeNewStream() && sm.canCloseSession(addr, tr.l) {
tr.l.Debugf("close idle session:%s stream cnt %d",
sm.session.LocalAddr().String(), sm.session.NumStreams())
sm.session.Close()
}
}
newList := []*SessionWithMetrics{}
for _, s := range sl {
if !s.session.IsClosed() {
newList = append(newList, s)
}
}
tr.sessionM[addr] = newList
tr.l.Debugf("finish gc for remote addr: %s total session count %d", addr, len(sl))
}
tr.sessionMutex.Unlock()
}
}
func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
tr.sessionMutex.Lock()
defer tr.sessionMutex.Unlock()
var session *smux.Session
var curSM *SessionWithMetrics
sessionList := tr.sessionM[addr]
for _, sm := range sessionList {
if sm.CanNotServeNewStream() {
continue
} else {
tr.l.Debugf("use session: %s total stream count: %d remote addr: %s",
sm.session.LocalAddr().String(), sm.session.NumStreams(), addr)
session = sm.session
curSM = sm
break
}
}
// create new one
if session == nil {
session, err = tr.initSessionF(ctx, addr)
if err != nil {
return nil, err
}
sm := &SessionWithMetrics{session: session, createdTime: time.Now(), streamList: []*smux.Stream{}}
sessionList = append(sessionList, sm)
tr.sessionM[addr] = sessionList
curSM = sm
}
stream, err := session.OpenStream()
if err != nil {
tr.l.Errorf("open stream meet error:%s", err)
session.Close()
return nil, err
}
curSM.streamList = append(curSM.streamList, stream)
return stream, nil
}
type muxServer interface {
ListenAndServe() error
Accept() (net.Conn, error)
Close() error
mux(net.Conn)
}
func newMuxServer(listenAddr string, l *zap.SugaredLogger) *muxServerImpl {
return &muxServerImpl{
errChan: make(chan error, 1),
connChan: make(chan net.Conn, 1024),
listenAddr: listenAddr,
l: l,
}
}
type muxServerImpl struct {
errChan chan error
connChan chan net.Conn
listenAddr string
l *zap.SugaredLogger
}
func (s *muxServerImpl) Accept() (net.Conn, error) {
select {
case conn := <-s.connChan:
return conn, nil
case err := <-s.errChan:
return nil, err
}
}
func (s *muxServerImpl) mux(conn net.Conn) {
defer conn.Close()
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Server(conn, cfg)
if err != nil {
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err)
return
}
defer session.Close() // nolint: errcheck
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr)
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr)
for {
stream, err := session.AcceptStream()
if err != nil {
s.l.Errorf("accept stream err: %s", err)
break
}
select {
case s.connChan <- stream:
default:
stream.Close() // nolint: errcheck
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
}
}
}

View File

@@ -3,10 +3,11 @@ package transporter
import (
"context"
"errors"
"net"
"time"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
@@ -26,17 +27,22 @@ type RawClient struct {
func newRawClient(cfg *conf.Config) (*RawClient, error) {
r := &RawClient{
l: zap.S().Named("raw"),
cfg: cfg,
dialer: &net.Dialer{Timeout: constant.DialTimeOut},
dialer: NewNetDialer(cfg),
l: zap.S().Named(string(cfg.TransportType)),
}
r.dialer.SetMultipathTCP(true)
return r, nil
}
func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
func (raw *RawClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) {
t1 := time.Now()
rc, err := raw.dialer.Dial("tcp", remote.Address)
var rc net.Conn
var err error
if isTCP {
rc, err = raw.dialer.DialContext(ctx, "tcp", remote.Address)
} else {
rc, err = raw.dialer.DialContext(ctx, "udp", remote.Address)
}
if err != nil {
return nil, err
}
@@ -46,55 +52,72 @@ func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
return rc, nil
}
func (raw *RawClient) HealthCheck(ctx context.Context, remote *lb.Node) error {
l := zap.S().Named("health-check")
l.Infof("start send req to %s", remote.Address)
c, err := raw.TCPHandShake(remote)
if err != nil {
l.Errorf("send req to %s meet error:%s", remote.Address, err)
return err
}
c.Close()
return nil
}
type RawServer struct {
*baseTransporter
lis net.Listener
*BaseRelayServer
tcpLis net.Listener
udpLis *conn.UDPListener
}
func newRawServer(base *baseTransporter) (*RawServer, error) {
addr, err := base.GetTCPListenAddr()
if err != nil {
return nil, err
}
cfg := net.ListenConfig{}
cfg.SetMultipathTCP(true)
lis, err := cfg.Listen(context.TODO(), "tcp", addr.String())
if err != nil {
return nil, err
}
return &RawServer{
lis: lis,
baseTransporter: base,
}, nil
func newRawServer(bs *BaseRelayServer) (*RawServer, error) {
rs := &RawServer{BaseRelayServer: bs}
return rs, nil
}
func (s *RawServer) Close() error {
return s.lis.Close()
err := s.tcpLis.Close()
if s.udpLis != nil {
err2 := s.udpLis.Close()
err = errors.Join(err, err2)
}
return err
}
func (s *RawServer) ListenAndServe() error {
func (s *RawServer) ListenAndServe(ctx context.Context) error {
ts, err := NewTCPListener(ctx, s.cfg)
if err != nil {
return err
}
s.tcpLis = ts
if s.cfg.Options != nil && s.cfg.Options.EnableUDP {
udpLis, err := conn.NewUDPListener(ctx, s.cfg.Listen)
if err != nil {
return err
}
s.udpLis = udpLis
}
if s.udpLis != nil {
go s.listenUDP(ctx)
}
for {
c, err := s.lis.Accept()
c, err := s.tcpLis.Accept()
if err != nil {
return err
}
go func(c net.Conn) {
defer c.Close()
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
if err := s.RelayTCPConn(ctx, c); err != nil {
s.l.Errorf("RelayTCPConn meet error: %s", err.Error())
}
}(c)
}
}
func (s *RawServer) listenUDP(ctx context.Context) error {
s.l.Infof("Start UDP server at %s", s.cfg.Listen)
for {
c, err := s.udpLis.Accept()
if err != nil {
s.l.Errorf("UDP accept error: %v", err)
return err
}
go func() {
if err := s.RelayUDPConn(ctx, c); err != nil {
s.l.Errorf("RelayUDPConn meet error: %s", err.Error())
}
}()
}
}

View File

@@ -1,109 +0,0 @@
package transporter
import (
"context"
"net"
"time"
"github.com/xtaci/smux"
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
)
var (
_ RelayClient = &MtcpClient{}
_ RelayServer = &MtcpServer{}
)
type MtcpClient struct {
*RawClient
muxTP *smuxTransporter
}
func newMtcpClient(cfg *conf.Config) (*MtcpClient, error) {
raw, err := newRawClient(cfg)
if err != nil {
return nil, err
}
c := &MtcpClient{RawClient: raw}
c.muxTP = NewSmuxTransporter(zap.S().Named("mtcp"), c.initNewSession)
return c, nil
}
func (c *MtcpClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, err := c.dialer.Dial("tcp", addr)
if err != nil {
return nil, err
}
// stream multiplex
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Client(rc, cfg)
if err != nil {
return nil, err
}
c.l.Infof("init new session to: %s", rc.RemoteAddr())
return session, nil
}
func (s *MtcpClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
mtcpc, err := s.muxTP.Dial(context.TODO(), remote.Address)
if err != nil {
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return mtcpc, nil
}
type MtcpServer struct {
*RawServer
*muxServerImpl
}
func newMtcpServer(base *baseTransporter) (*MtcpServer, error) {
raw, err := newRawServer(base)
if err != nil {
return nil, err
}
s := &MtcpServer{
RawServer: raw,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mtcp")),
}
return s, nil
}
func (s *MtcpServer) ListenAndServe() error {
go func() {
for {
c, err := s.lis.Accept()
if err != nil {
s.errChan <- err
continue
}
go s.mux(c)
}
}()
for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}
func (s *MtcpServer) Close() error {
return s.lis.Close()
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"net"
"net/http"
"net/url"
"time"
"github.com/gobwas/ws"
@@ -11,7 +12,6 @@ import (
"go.uber.org/zap"
"github.com/Ehco1996/ehco/internal/conn"
"github.com/Ehco1996/ehco/internal/constant"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
@@ -24,27 +24,48 @@ var (
)
type WsClient struct {
dialer *ws.Dialer
cfg *conf.Config
l *zap.SugaredLogger
dialer *ws.Dialer
cfg *conf.Config
netDialer *net.Dialer
l *zap.SugaredLogger
}
func newWsClient(cfg *conf.Config) (*WsClient, error) {
s := &WsClient{
cfg: cfg,
l: zap.S().Named(string(cfg.TransportType)),
dialer: &ws.Dialer{Timeout: constant.DialTimeOut},
cfg: cfg,
netDialer: NewNetDialer(cfg),
l: zap.S().Named(string(cfg.TransportType)),
dialer: &ws.DefaultDialer, // todo config buffer size
}
s.dialer.NetDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.netDialer.Dial(network, addr)
}
return s, nil
}
func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
func (s *WsClient) addUDPQueryParam(addr string) string {
u, err := url.Parse(addr)
if err != nil {
s.l.Errorf("Failed to parse URL: %v", err)
return addr
}
q := u.Query()
q.Set("type", "udp")
u.RawQuery = q.Encode()
return u.String()
}
func (s *WsClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) {
t1 := time.Now()
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
if err != nil {
return nil, err
}
wsc, _, _, err := s.dialer.Dial(context.TODO(), addr)
if !isTCP {
addr = s.addUDPQueryParam(addr)
}
wsc, _, _, err := s.dialer.Dial(ctx, addr)
if err != nil {
return nil, err
}
@@ -55,61 +76,50 @@ func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
return c, nil
}
func (s *WsClient) HealthCheck(ctx context.Context, remote *lb.Node) error {
l := zap.S().Named("health-check")
l.Infof("start send req to %s", remote.Address)
c, err := s.TCPHandShake(remote)
if err != nil {
l.Errorf("send req to %s meet error:%s", remote.Address, err)
return err
}
c.Close()
return nil
}
type WsServer struct {
*baseTransporter
e *echo.Echo
*BaseRelayServer
httpServer *http.Server
}
func newWsServer(base *baseTransporter) (*WsServer, error) {
localTCPAddr, err := base.GetTCPListenAddr()
if err != nil {
return nil, err
}
s := &WsServer{
baseTransporter: base,
httpServer: &http.Server{
Addr: localTCPAddr.String(), ReadHeaderTimeout: 30 * time.Second,
},
}
func newWsServer(bs *BaseRelayServer) (*WsServer, error) {
s := &WsServer{BaseRelayServer: bs}
e := web.NewEchoServer()
e.Use(web.NginxLogMiddleware(zap.S().Named("ws-server")))
e.GET("/", echo.WrapHandler(web.MakeIndexF()))
e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
s.e = e
e.GET(bs.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.handleRequest)))
s.httpServer = &http.Server{Handler: e}
return s, nil
}
func (s *WsServer) ListenAndServe() error {
return s.e.StartServer(s.httpServer)
}
func (s *WsServer) Close() error {
return s.e.Close()
}
func (s *WsServer) HandleRequest(w http.ResponseWriter, req *http.Request) {
func (s *WsServer) handleRequest(w http.ResponseWriter, req *http.Request) {
// todo use bufio.ReadWriter
wsc, _, _, err := ws.UpgradeHTTP(req, w)
if err != nil {
return
}
if err := s.RelayTCPConn(conn.NewWSConn(wsc, true), s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
if req.URL.Query().Get("type") == "udp" {
if !s.cfg.Options.EnableUDP {
s.l.Error("udp not support but request with udp type")
wsc.Close()
return
}
err = s.RelayUDPConn(req.Context(), conn.NewWSConn(wsc, true))
} else {
err = s.RelayTCPConn(req.Context(), conn.NewWSConn(wsc, true))
}
if err != nil {
s.l.Errorf("handleRequest meet error:%s", err)
}
}
func (s *WsServer) ListenAndServe(ctx context.Context) error {
listener, err := NewTCPListener(ctx, s.cfg)
if err != nil {
return err
}
return s.httpServer.Serve(listener)
}
func (s *WsServer) Close() error {
return s.httpServer.Close()
}

View File

@@ -1,122 +0,0 @@
// NOTE CAN NOT use real ws frame to transport smux frame
// err: accept stream err: buffer size:8 too small to transport ws payload size:45
// so this transport just use ws protocol to handshake and then use smux protocol to transport
package transporter
import (
"context"
"net"
"net/http"
"time"
"github.com/gobwas/ws"
"github.com/labstack/echo/v4"
"github.com/xtaci/smux"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
)
var (
_ RelayClient = &MwsClient{}
_ RelayServer = &MwsServer{}
_ muxServer = &MwsServer{}
)
type MwsClient struct {
*WssClient
muxTP *smuxTransporter
}
func newMwsClient(cfg *conf.Config) (*MwsClient, error) {
wc, err := newWssClient(cfg)
if err != nil {
return nil, err
}
c := &MwsClient{WssClient: wc}
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
return c, nil
}
func (c *MwsClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, _, _, err := c.dialer.Dial(ctx, addr)
if err != nil {
return nil, err
}
// stream multiplex
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Client(rc, cfg)
if err != nil {
return nil, err
}
c.l.Infof("init new session to: %s", rc.RemoteAddr())
return session, nil
}
func (s *MwsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
if err != nil {
return nil, err
}
mwssc, err := s.muxTP.Dial(context.TODO(), addr)
if err != nil {
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return mwssc, nil
}
type MwsServer struct {
*WsServer
*muxServerImpl
}
func newMwsServer(base *baseTransporter) (*MwsServer, error) {
wsServer, err := newWsServer(base)
if err != nil {
return nil, err
}
s := &MwsServer{
WsServer: wsServer,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
}
s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
return s, nil
}
func (s *MwsServer) ListenAndServe() error {
go func() {
s.errChan <- s.e.StartServer(s.httpServer)
}()
for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}
func (s *MwsServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
c, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
s.l.Error(err)
return
}
s.mux(c)
}
func (s *MwsServer) Close() error {
return s.e.Close()
}

View File

@@ -1,6 +1,9 @@
package transporter
import (
"context"
"crypto/tls"
"github.com/Ehco1996/ehco/internal/relay/conf"
mytls "github.com/Ehco1996/ehco/internal/tls"
)
@@ -28,12 +31,19 @@ type WssServer struct {
*WsServer
}
func newWssServer(base *baseTransporter) (*WssServer, error) {
wsServer, err := newWsServer(base)
func newWssServer(bs *BaseRelayServer) (*WssServer, error) {
wsServer, err := newWsServer(bs)
if err != nil {
return nil, err
}
// insert tls config
wsServer.httpServer.TLSConfig = mytls.DefaultTLSConfig
return &WssServer{WsServer: wsServer}, nil
}
func (s *WssServer) ListenAndServe(ctx context.Context) error {
listener, err := NewTCPListener(ctx, s.cfg)
if err != nil {
return err
}
tlsListener := tls.NewListener(listener, mytls.DefaultTLSConfig)
return s.httpServer.Serve(tlsListener)
}

View File

@@ -1,122 +0,0 @@
// NOTE CAN NOT use real ws frame to transport smux frame
// err: accept stream err: buffer size:8 too small to transport ws payload size:45
// so this transport just use ws protocol to handshake and then use smux protocol to transport
package transporter
import (
"context"
"net"
"net/http"
"time"
"github.com/gobwas/ws"
"github.com/labstack/echo/v4"
"github.com/xtaci/smux"
"github.com/Ehco1996/ehco/internal/lb"
"github.com/Ehco1996/ehco/internal/metrics"
"github.com/Ehco1996/ehco/internal/relay/conf"
)
var (
_ RelayClient = &MwssClient{}
_ RelayServer = &MwssServer{}
_ muxServer = &MwssServer{}
)
type MwssClient struct {
*WssClient
muxTP *smuxTransporter
}
func newMwssClient(cfg *conf.Config) (*MwssClient, error) {
wc, err := newWssClient(cfg)
if err != nil {
return nil, err
}
c := &MwssClient{WssClient: wc}
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
return c, nil
}
func (c *MwssClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
rc, _, _, err := c.dialer.Dial(ctx, addr)
if err != nil {
return nil, err
}
// stream multiplex
cfg := smux.DefaultConfig()
cfg.KeepAliveDisabled = true
session, err := smux.Client(rc, cfg)
if err != nil {
return nil, err
}
c.l.Infof("init new session to: %s", rc.RemoteAddr())
return session, nil
}
func (s *MwssClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
t1 := time.Now()
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
if err != nil {
return nil, err
}
mwssc, err := s.muxTP.Dial(context.TODO(), addr)
if err != nil {
return nil, err
}
latency := time.Since(t1)
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
remote.HandShakeDuration = latency
return mwssc, nil
}
type MwssServer struct {
*WssServer
*muxServerImpl
}
func newMwssServer(base *baseTransporter) (*MwssServer, error) {
wssServer, err := newWssServer(base)
if err != nil {
return nil, err
}
s := &MwssServer{
WssServer: wssServer,
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
}
s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
return s, nil
}
func (s *MwssServer) ListenAndServe() error {
go func() {
s.errChan <- s.e.StartServer(s.httpServer)
}()
for {
conn, e := s.Accept()
if e != nil {
return e
}
go func(c net.Conn) {
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
s.l.Errorf("RelayTCPConn error: %s", err.Error())
}
}(conn)
}
}
func (s *MwssServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
c, _, _, err := ws.UpgradeHTTP(r, w)
if err != nil {
s.l.Error(err)
return
}
s.mux(c)
}
func (s *MwssServer) Close() error {
return s.e.Close()
}

View File

@@ -28,6 +28,7 @@
<thead>
<tr>
<th>Relay Label</th>
<th>Type</th>
<th>Flow</th>
<th>Stats</th>
<th>Time</th>
@@ -37,7 +38,8 @@
{{range .ConnectionList}}
<tr>
<td>{{.RelayLabel}}</td>
<td>{{.Flow}}</td>
<td>{{.ConnType}}</td>
<td>{{.GetFlow}}</td>
<td>{{.Stats}}</td>
<td>{{.GetTime}}</td>
</tr>

View File

@@ -6,11 +6,13 @@ import (
// 全局pool
var (
BufferPool *BytePool
BufferPool *BytePool
UDPBufferPool *BytePool
)
func init() {
BufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.BUFFER_SIZE)
UDPBufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.UDPBufSize)
}
// BytePool implements a leaky pool of []byte in the form of a bounded channel

View File

@@ -35,7 +35,7 @@ func (b *readerImpl) parsePingInfo(metricMap map[string]*dto.MetricFamily, nm *N
metric, ok := metricMap["ehco_ping_response_duration_seconds"]
if !ok {
// this metric is optional when enable_ping = false
zap.S().Warn("ping metric not found")
zap.S().Debug("ping metric not found")
return nil
}
for _, m := range metric.Metric {

View File

@@ -179,7 +179,9 @@ func (c *ClashSub) ToRelayConfigs(listenHost string) ([]*relay_cfg.Config, error
}
rc.TCPRemotes = append(rc.TCPRemotes, remote)
if proxy.UDP {
rc.UDPRemotes = append(rc.UDPRemotes, remote)
rc.Options = &relay_cfg.Options{
EnableUDP: true,
}
}
}
relayConfigs = append(relayConfigs, rc)

View File

@@ -139,7 +139,9 @@ func (p *Proxies) ToRelayConfig(listenHost string, listenPort string, newName st
TCPRemotes: []string{remoteAddr},
}
if p.UDP {
r.UDPRemotes = []string{remoteAddr}
r.Options = &relay_cfg.Options{
EnableUDP: true,
}
}
if err := r.Validate(); err != nil {
return nil, err

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"context"
"fmt"
"net"
"os"
"testing"
"time"
@@ -28,8 +27,7 @@ const (
ECHO_PORT = 9002
ECHO_SERVER = "0.0.0.0:9002"
RAW_LISTEN = "0.0.0.0:1234"
RAW_LISTEN_WITH_MAX_CONNECTION = "0.0.0.0:2234"
RAW_LISTEN = "0.0.0.0:1234"
WS_LISTEN = "0.0.0.0:1235"
WS_REMOTE = "ws://0.0.0.0:2000"
@@ -38,22 +36,15 @@ const (
WSS_LISTEN = "0.0.0.0:1236"
WSS_REMOTE = "wss://0.0.0.0:2001"
WSS_SERVER = "0.0.0.0:2001"
MWSS_LISTEN = "0.0.0.0:1237"
MWSS_REMOTE = "wss://0.0.0.0:2002"
MWSS_SERVER = "0.0.0.0:2002"
MTCP_LISTEN = "0.0.0.0:1238"
MTCP_REMOTE = "0.0.0.0:2003"
MTCP_SERVER = "0.0.0.0:2003"
MWS_LISTEN = "0.0.0.0:1239"
MWS_REMOTE = "ws://0.0.0.0:2004"
MSS_SERVER = "0.0.0.0:2004"
)
func TestMain(m *testing.M) {
// Setup
// change the idle timeout to 1 second to make connection close faster in test
constant.IdleTimeOut = time.Second
constant.ReadTimeOut = time.Second
_ = log.InitGlobalLogger("debug")
_ = tls.InitTlsCfg()
@@ -78,38 +69,35 @@ func TestMain(m *testing.M) {
func startRelayServers() []*relay.Relay {
cfg := config.Config{
PATH: "",
RelayConfigs: []*conf.Config{
// raw cfg
// raw
{
Listen: RAW_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{ECHO_SERVER},
UDPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
Options: &conf.Options{
EnableUDP: true,
},
},
// raw cfg with max connection
{
Listen: RAW_LISTEN_WITH_MAX_CONNECTION,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{ECHO_SERVER},
UDPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
MaxConnection: 1,
},
// ws
{
Listen: WS_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{WS_REMOTE},
TransportType: constant.RelayTypeWS,
Options: &conf.Options{
EnableUDP: true,
},
},
{
Listen: WS_SERVER,
ListenType: constant.RelayTypeWS,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
Options: &conf.Options{
EnableUDP: true,
},
},
// wss
@@ -118,65 +106,30 @@ func startRelayServers() []*relay.Relay {
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{WSS_REMOTE},
TransportType: constant.RelayTypeWSS,
Options: &conf.Options{
EnableUDP: true,
},
},
{
Listen: WSS_SERVER,
ListenType: constant.RelayTypeWSS,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
},
// mwss
{
Listen: MWSS_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{MWSS_REMOTE},
TransportType: constant.RelayTypeMWSS,
},
{
Listen: MWSS_SERVER,
ListenType: constant.RelayTypeMWSS,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
},
// mtcp
{
Listen: MTCP_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{MTCP_REMOTE},
TransportType: constant.RelayTypeMTCP,
},
{
Listen: MTCP_SERVER,
ListenType: constant.RelayTypeMTCP,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
},
// mws
{
Listen: MWS_LISTEN,
ListenType: constant.RelayTypeRaw,
TCPRemotes: []string{MWS_REMOTE},
TransportType: constant.RelayTypeMWS,
},
{
Listen: MSS_SERVER,
ListenType: constant.RelayTypeMWS,
TCPRemotes: []string{ECHO_SERVER},
TransportType: constant.RelayTypeRaw,
Options: &conf.Options{
EnableUDP: true,
},
},
},
}
var servers []*relay.Relay
for _, c := range cfg.RelayConfigs {
c.Adjust()
r, err := relay.NewRelay(c, cmgr.NewCmgr(cmgr.DummyConfig))
if err != nil {
zap.S().Fatal(err)
}
go r.ListenAndServe()
go r.ListenAndServe(context.TODO())
servers = append(servers, r)
}
@@ -194,16 +147,14 @@ func TestRelay(t *testing.T) {
{"Raw", RAW_LISTEN, "raw"},
{"WS", WS_LISTEN, "ws"},
{"WSS", WSS_LISTEN, "wss"},
{"MWSS", MWSS_LISTEN, "mwss"},
{"MTCP", MTCP_LISTEN, "mtcp"},
{"MWS", MWS_LISTEN, "mws"},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testRelayCommon(t, tc.address, tc.protocol, false)
testTCPRelay(t, tc.address, tc.protocol, false)
testUDPRelay(t, tc.address, false)
})
}
}
@@ -214,21 +165,22 @@ func TestRelayConcurrent(t *testing.T) {
address string
concurrency int
}{
{"MWSS", MWSS_LISTEN, 10},
{"MTCP", MTCP_LISTEN, 10},
{"MWS", MWS_LISTEN, 10},
{"Raw", RAW_LISTEN, 10},
{"WS", WS_LISTEN, 10},
{"WSS", WSS_LISTEN, 10},
}
for _, tc := range testCases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testRelayCommon(t, tc.address, tc.name, true, tc.concurrency)
testTCPRelay(t, tc.address, tc.name, true, tc.concurrency)
testUDPRelay(t, tc.address, true, tc.concurrency)
})
}
}
func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) {
func testTCPRelay(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) {
t.Helper()
msg := []byte("hello")
@@ -264,40 +216,42 @@ func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, co
t.Logf("Test TCP over %s done!", protocol)
}
func TestRelayWithMaxConnectionCount(t *testing.T) {
msg := []byte("hello")
func testUDPRelay(t *testing.T, address string, concurrent bool, concurrency ...int) {
t.Helper()
msg := []byte("hello udp")
// First connection will be accepted
go func() {
err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION)
require.NoError(t, err, "First connection should be accepted")
}()
runTest := func() error {
res := echo.SendUdpMsg(msg, address)
if !bytes.Equal(msg, res) {
return fmt.Errorf("response mismatch: got %s, want %s", res, msg)
}
return nil
}
// Wait for first connection
time.Sleep(time.Second)
// Second connection should be rejected
err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION)
require.Error(t, err, "Second connection should be rejected")
if concurrent {
n := 10
if len(concurrency) > 0 {
n = concurrency[0]
}
g, ctx := errgroup.WithContext(context.Background())
for i := 0; i < n; i++ {
g.Go(func() error {
select {
case <-ctx.Done():
return ctx.Err()
default:
return runTest()
}
})
}
require.NoError(t, g.Wait(), "Concurrent test failed")
} else {
require.NoError(t, runTest(), "Single test failed")
}
t.Logf("Test UDP over %s done!", address)
}
func TestRelayWithDeadline(t *testing.T) {
logger, _ := zap.NewDevelopment()
msg := []byte("hello")
conn, err := net.Dial("tcp", RAW_LISTEN)
if err != nil {
logger.Sugar().Fatal(err)
}
defer conn.Close()
if _, err := conn.Write(msg); err != nil {
logger.Sugar().Fatal(err)
}
buf := make([]byte, len(msg))
constant.IdleTimeOut = time.Second // change for test
time.Sleep(constant.IdleTimeOut)
_, err = conn.Read(buf)
if err != nil {
logger.Sugar().Fatal("need error here")
}
func TestRelayIdleTimeout(t *testing.T) {
err := echo.EchoTcpMsgLong([]byte("hello"), time.Second, RAW_LISTEN)
require.Error(t, err, "Connection should be rejected")
}