mirror of
https://github.com/bolucat/Archive.git
synced 2025-12-24 13:28:37 +08:00
Update On Fri Aug 16 20:33:05 CEST 2024
This commit is contained in:
@@ -17,7 +17,4 @@ ehco 现在提供 SaaS(软件即服务)版本!这是一个全托管的解
|
||||
|
||||
- tcp/udp relay
|
||||
- tunnel relay (ws/wss/mwss/mtcp)
|
||||
- proxy server (内嵌了完整版本的 xray)
|
||||
- 监控报警 (Prometheus/Grafana)
|
||||
- WebAPI (http://web_host:web_port)
|
||||
- [更多功能请探索文档](https://docs.ehco-relay.cc/)
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
{
|
||||
"log_level": "debug",
|
||||
"relay_configs": [
|
||||
{
|
||||
"label": "relay-to-http-success",
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["google.com:80"],
|
||||
"blocked_protocols": ["tls"]
|
||||
},
|
||||
{
|
||||
"label": "relay-to-http-fail",
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["google.com:80"],
|
||||
"blocked_protocols": ["http"]
|
||||
},
|
||||
{
|
||||
"label": "relay-to-tls-success",
|
||||
"listen": "127.0.0.1:1236",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["google.com:443"]
|
||||
},
|
||||
{
|
||||
"label": "relay-to-tls-fail",
|
||||
"listen": "127.0.0.1:1237",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["google.com:443"],
|
||||
"blocked_protocols": ["tls"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"web_port": 9000,
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "ws",
|
||||
"tcp_remotes": ["ws://0.0.0.0:8787"],
|
||||
"ws_config": {
|
||||
"path": "pwd",
|
||||
"remote_addr": "127.0.0.1:5201"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -9,71 +9,41 @@
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"label": "relay1",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:5201"
|
||||
]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "ws",
|
||||
"tcp_remotes": ["ws://0.0.0.0:2443"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
"tcp_remotes": [
|
||||
"ws://0.0.0.0:2443"
|
||||
]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1236",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "wss",
|
||||
"tcp_remotes": ["wss://0.0.0.0:3443"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1237",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "mwss",
|
||||
"tcp_remotes": ["wss://0.0.0.0:4443"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1238",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "mtcp",
|
||||
"tcp_remotes": ["0.0.0.0:4444"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
"tcp_remotes": [
|
||||
"wss://0.0.0.0:3443"
|
||||
]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:2443",
|
||||
"listen_type": "ws",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": []
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:5201"
|
||||
]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:3443",
|
||||
"listen_type": "wss",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": []
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:4443",
|
||||
"listen_type": "mwss",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": []
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:4444",
|
||||
"listen_type": "mtcp",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": []
|
||||
},
|
||||
{
|
||||
"label": "ping_test",
|
||||
"listen": "127.0.0.1:8888",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["8.8.8.8:5201", "google.com:5201"]
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:5201"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -1,23 +0,0 @@
|
||||
{
|
||||
"log_level": "info",
|
||||
"relay_configs": [
|
||||
{
|
||||
"label": "client",
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:1235"
|
||||
]
|
||||
},
|
||||
{
|
||||
"label": "server",
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:5201"
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
{
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "mws",
|
||||
"tcp_remotes": ["ws://0.0.0.0:2443"],
|
||||
"ws_config": {
|
||||
"path": "pwd",
|
||||
"remote_addr": "127.0.0.1:5201"
|
||||
}
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:2443",
|
||||
"listen_type": "mws",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"ws_config": {
|
||||
"path": "pwd",
|
||||
"remote_addr": "127.0.0.1:5201"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
{
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"label": "iperf3",
|
||||
"tcp_remotes": [
|
||||
"0.0.0.0:5201"
|
||||
],
|
||||
"max_read_rate_kbps": 10000
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
{
|
||||
"web_port": 9000,
|
||||
"log_level": "debug",
|
||||
"reload_interval": 10,
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"udp_remotes": ["0.0.0.0:5201"]
|
||||
}
|
||||
],
|
||||
"sub_configs": [
|
||||
{
|
||||
"name": "sub1",
|
||||
"url": "xxx"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"web_port": 9000,
|
||||
"web_token": "",
|
||||
"web_auth_user": "user",
|
||||
"web_auth_pass": "pass",
|
||||
"log_level": "info",
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"label": "iperf3",
|
||||
"tcp_remotes": ["0.0.0.0:5201"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,29 +0,0 @@
|
||||
{
|
||||
"web_port": 9000,
|
||||
"log_level": "info",
|
||||
"enable_ping": true,
|
||||
"relay_sync_interval": 6,
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1234",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "raw",
|
||||
"label": "raw",
|
||||
"tcp_remotes": ["192.168.31.30:5201"]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "ws",
|
||||
"label": "ws",
|
||||
"tcp_remotes": ["ws://192.168.31.30:2443"]
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:1236",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "wss",
|
||||
"label": "wss",
|
||||
"tcp_remotes": ["wss://192.168.31.31:2443"]
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
{
|
||||
"relay_configs": [
|
||||
{
|
||||
"listen": "127.0.0.1:1235",
|
||||
"listen_type": "raw",
|
||||
"transport_type": "ws",
|
||||
"tcp_remotes": ["ws://0.0.0.0:2443"],
|
||||
"ws_config": {
|
||||
"path": "pwd",
|
||||
"remote_addr": "127.0.0.1:5201"
|
||||
}
|
||||
},
|
||||
{
|
||||
"listen": "127.0.0.1:2443",
|
||||
"listen_type": "ws",
|
||||
"transport_type": "raw",
|
||||
"tcp_remotes": ["0.0.0.0:5201"],
|
||||
"ws_config": {
|
||||
"path": "pwd",
|
||||
"remote_addr": "127.0.0.1:5201"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -20,10 +20,10 @@ require (
|
||||
github.com/sagernet/sing-box v1.9.3
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/urfave/cli/v2 v2.27.2
|
||||
github.com/xtaci/smux v1.5.24
|
||||
github.com/xtls/xray-core v1.8.16
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/zap v1.27.0
|
||||
golang.org/x/sync v0.7.0
|
||||
golang.org/x/time v0.5.0
|
||||
google.golang.org/grpc v1.65.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -115,7 +115,6 @@ require (
|
||||
golang.org/x/exp v0.0.0-20240531132922-fd00a4e0eefc // indirect
|
||||
golang.org/x/mod v0.18.0 // indirect
|
||||
golang.org/x/net v0.26.0 // indirect
|
||||
golang.org/x/sync v0.7.0 // indirect
|
||||
golang.org/x/sys v0.21.0 // indirect
|
||||
golang.org/x/text v0.16.0 // indirect
|
||||
golang.org/x/tools v0.22.0 // indirect
|
||||
|
||||
@@ -317,8 +317,6 @@ github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8
|
||||
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
|
||||
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw=
|
||||
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk=
|
||||
github.com/xtaci/smux v1.5.24 h1:77emW9dtnOxxOQ5ltR+8BbsX1kzcOxQ5gB+aaV9hXOY=
|
||||
github.com/xtaci/smux v1.5.24/go.mod h1:OMlQbT5vcgl2gb49mFkYo6SMf+zP3rcjcwQz7ZU7IGY=
|
||||
github.com/xtls/reality v0.0.0-20240429224917-ecc4401070cc h1:0Nj8T1n7F7+v4vRVroaJIvY6R0vNABLfPH+lzPHRJvI=
|
||||
github.com/xtls/reality v0.0.0-20240429224917-ecc4401070cc/go.mod h1:dm4y/1QwzjGaK17ofi0Vs6NpKAHegZky8qk6J2JJZAE=
|
||||
github.com/xtls/xray-core v1.8.16 h1:PhbpdREAIvDS7xmxR6Sdpkx0h5ugmf6wIoWECWtJ0kE=
|
||||
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
"github.com/Ehco1996/ehco/internal/tls"
|
||||
"github.com/Ehco1996/ehco/internal/web"
|
||||
"github.com/Ehco1996/ehco/pkg/buffer"
|
||||
"github.com/Ehco1996/ehco/pkg/log"
|
||||
@@ -42,24 +41,11 @@ func loadConfig() (cfg *config.Config, err error) {
|
||||
if TCPRemoteAddr != "" {
|
||||
cfg.RelayConfigs[0].TCPRemotes = []string{TCPRemoteAddr}
|
||||
}
|
||||
if UDPRemoteAddr != "" {
|
||||
cfg.RelayConfigs[0].UDPRemotes = []string{UDPRemoteAddr}
|
||||
}
|
||||
if err := cfg.Adjust(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// init tls when need
|
||||
for _, cfg := range cfg.RelayConfigs {
|
||||
if cfg.ListenType == constant.RelayTypeWSS || cfg.ListenType == constant.RelayTypeMWSS ||
|
||||
cfg.TransportType == constant.RelayTypeWSS || cfg.TransportType == constant.RelayTypeMWSS {
|
||||
if err := tls.InitTlsCfg(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ var (
|
||||
LocalAddr string
|
||||
ListenType constant.RelayType
|
||||
TCPRemoteAddr string
|
||||
UDPRemoteAddr string
|
||||
TransportType constant.RelayType
|
||||
ConfigPath string
|
||||
WebPort int
|
||||
@@ -39,16 +38,10 @@ var RootFlags = []cli.Flag{
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "r,remote",
|
||||
Usage: "TCP 转发地址,例如 0.0.0.0:5201,通过 ws 隧道转发时应为 ws://0.0.0.0:2443",
|
||||
Usage: "转发地址,例如 0.0.0.0:5201,通过 ws 隧道转发时应为 ws://0.0.0.0:2443",
|
||||
EnvVars: []string{"EHCO_REMOTE_ADDR"},
|
||||
Destination: &TCPRemoteAddr,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "ur,udp_remote",
|
||||
Usage: "UDP 转发地址,例如 0.0.0.0:1234",
|
||||
EnvVars: []string{"EHCO_UDP_REMOTE_ADDR"},
|
||||
Destination: &UDPRemoteAddr,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "tt,transport_type",
|
||||
Value: "raw",
|
||||
|
||||
@@ -16,7 +16,8 @@ const (
|
||||
ConnectionTypeClosed = "closed"
|
||||
)
|
||||
|
||||
// connection manager interface
|
||||
// connection manager interface/
|
||||
// TODO support closed connection
|
||||
type Cmgr interface {
|
||||
ListConnections(connType string, page, pageSize int) []conn.RelayConn
|
||||
|
||||
|
||||
@@ -7,9 +7,10 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
myhttp "github.com/Ehco1996/ehco/pkg/http"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
"github.com/Ehco1996/ehco/internal/tls"
|
||||
myhttp "github.com/Ehco1996/ehco/pkg/http"
|
||||
"github.com/Ehco1996/ehco/pkg/sub"
|
||||
xConf "github.com/xtls/xray-core/infra/conf"
|
||||
"go.uber.org/zap"
|
||||
@@ -121,6 +122,15 @@ func (c *Config) Adjust() error {
|
||||
}
|
||||
labelMap[r.Label] = struct{}{}
|
||||
}
|
||||
// init tls when need
|
||||
for _, r := range c.RelayConfigs {
|
||||
if r.ListenType == constant.RelayTypeWSS {
|
||||
if err := tls.InitTlsCfg(); err != nil {
|
||||
return err
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,18 +3,160 @@ package conn
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/pkg/buffer"
|
||||
"github.com/Ehco1996/ehco/pkg/bytes"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var idleTimeout = 30 * time.Second
|
||||
// RelayConn is the interface that represents a relay connection.
|
||||
// it contains two connections: clientConn and remoteConn
|
||||
// clientConn is the connection from the client to the relay server
|
||||
// remoteConn is the connection from the relay server to the remote server
|
||||
// and the main function is to transport data between the two connections
|
||||
type RelayConn interface {
|
||||
// Transport transports data between the client and the remote connection.
|
||||
Transport() error
|
||||
GetRelayLabel() string
|
||||
GetStats() *Stats
|
||||
Close() error
|
||||
}
|
||||
|
||||
type RelayConnOption func(*relayConnImpl)
|
||||
|
||||
func NewRelayConn(clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn {
|
||||
rci := &relayConnImpl{
|
||||
clientConn: clientConn,
|
||||
remoteConn: remoteConn,
|
||||
Stats: &Stats{},
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(rci)
|
||||
}
|
||||
if rci.l == nil {
|
||||
rci.l = zap.S().Named(rci.RelayLabel)
|
||||
}
|
||||
return rci
|
||||
}
|
||||
|
||||
type relayConnImpl struct {
|
||||
clientConn net.Conn
|
||||
remoteConn net.Conn
|
||||
|
||||
Closed bool `json:"closed"`
|
||||
|
||||
Stats *Stats `json:"stats"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time,omitempty"`
|
||||
|
||||
// options set those fields
|
||||
l *zap.SugaredLogger
|
||||
remote *lb.Node
|
||||
HandshakeDuration time.Duration
|
||||
RelayLabel string `json:"relay_label"`
|
||||
ConnType string `json:"conn_type"`
|
||||
}
|
||||
|
||||
func WithRelayLabel(relayLabel string) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.RelayLabel = relayLabel
|
||||
}
|
||||
}
|
||||
|
||||
func WithHandshakeDuration(duration time.Duration) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.HandshakeDuration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func WithConnType(connType string) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.ConnType = connType
|
||||
}
|
||||
}
|
||||
|
||||
func WithRemote(remote *lb.Node) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.remote = remote
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(l *zap.SugaredLogger) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.l = l
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Transport() error {
|
||||
defer rc.Close() // nolint: errcheck
|
||||
cl := rc.l.Named(shortHashSHA256(rc.GetFlow()))
|
||||
cl.Debugf("transport start, stats: %s", rc.Stats.String())
|
||||
c1 := newInnerConn(rc.clientConn, rc)
|
||||
c2 := newInnerConn(rc.remoteConn, rc)
|
||||
rc.StartTime = time.Now().Local()
|
||||
err := copyConn(c1, c2)
|
||||
if err != nil {
|
||||
cl.Errorf("transport error: %s", err.Error())
|
||||
}
|
||||
cl.Debugf("transport end, stats: %s", rc.Stats.String())
|
||||
rc.EndTime = time.Now().Local()
|
||||
return err
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Close() error {
|
||||
err1 := rc.clientConn.Close()
|
||||
err2 := rc.remoteConn.Close()
|
||||
rc.Closed = true
|
||||
return combineErrorsAndMuteEOF(err1, err2)
|
||||
}
|
||||
|
||||
// functions that for web ui
|
||||
func (rc *relayConnImpl) GetTime() string {
|
||||
if rc.EndTime.IsZero() {
|
||||
return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp))
|
||||
}
|
||||
return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp))
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetFlow() string {
|
||||
return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetRelayLabel() string {
|
||||
return rc.RelayLabel
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetStats() *Stats {
|
||||
return rc.Stats
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetConnType() string {
|
||||
return rc.ConnType
|
||||
}
|
||||
|
||||
func combineErrorsAndMuteEOF(err1, err2 error) error {
|
||||
if err1 == io.EOF {
|
||||
err1 = nil
|
||||
}
|
||||
if err2 == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err1 != nil && err2 != nil {
|
||||
return errors.Join(err1, err2)
|
||||
}
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
type Stats struct {
|
||||
Up int64
|
||||
@@ -35,52 +177,69 @@ func (s *Stats) String() string {
|
||||
)
|
||||
}
|
||||
|
||||
// note that innerConn is a wrapper around net.Conn to allow io.Copy to be used
|
||||
type innerConn struct {
|
||||
net.Conn
|
||||
lastActive time.Time
|
||||
|
||||
remoteLabel string
|
||||
stats *Stats
|
||||
rc *relayConnImpl
|
||||
}
|
||||
|
||||
func (c *innerConn) setDeadline(isRead bool) {
|
||||
// set the read deadline to avoid hanging read for non-TCP connections
|
||||
// because tcp connections have closeWrite/closeRead so no need to set read deadline
|
||||
if _, ok := c.Conn.(*net.TCPConn); !ok {
|
||||
deadline := time.Now().Add(idleTimeout)
|
||||
if isRead {
|
||||
_ = c.Conn.SetReadDeadline(deadline)
|
||||
func newInnerConn(conn net.Conn, rc *relayConnImpl) *innerConn {
|
||||
return &innerConn{Conn: conn, rc: rc, lastActive: time.Now()}
|
||||
}
|
||||
|
||||
func (c *innerConn) recordStats(n int, isRead bool) {
|
||||
if c.rc == nil {
|
||||
return
|
||||
}
|
||||
if isRead {
|
||||
metrics.NetWorkTransmitBytes.WithLabelValues(
|
||||
c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ,
|
||||
).Add(float64(n))
|
||||
c.rc.Stats.Record(0, int64(n))
|
||||
} else {
|
||||
metrics.NetWorkTransmitBytes.WithLabelValues(
|
||||
c.rc.remote.Label, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE,
|
||||
).Add(float64(n))
|
||||
c.rc.Stats.Record(int64(n), 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *innerConn) Read(p []byte) (n int, err error) {
|
||||
for {
|
||||
deadline := time.Now().Add(constant.ReadTimeOut)
|
||||
if err := c.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n, err = c.Conn.Read(p)
|
||||
if err == nil {
|
||||
c.recordStats(n, true)
|
||||
c.lastActive = time.Now()
|
||||
return
|
||||
} else {
|
||||
_ = c.Conn.SetWriteDeadline(deadline)
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
if time.Since(c.lastActive) > constant.IdleTimeOut {
|
||||
c.rc.l.Debugf("read idle,close remote: %s", c.rc.remote.Label)
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *innerConn) recordStats(n int, isRead bool) {
|
||||
if isRead {
|
||||
metrics.NetWorkTransmitBytes.WithLabelValues(
|
||||
c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_READ,
|
||||
).Add(float64(n))
|
||||
c.stats.Record(0, int64(n))
|
||||
} else {
|
||||
metrics.NetWorkTransmitBytes.WithLabelValues(
|
||||
c.remoteLabel, metrics.METRIC_CONN_TYPE_TCP, metrics.METRIC_CONN_FLOW_WRITE,
|
||||
).Add(float64(n))
|
||||
c.stats.Record(int64(n), 0)
|
||||
}
|
||||
}
|
||||
|
||||
// 修改Read和Write方法以使用recordStats
|
||||
func (c *innerConn) Read(p []byte) (n int, err error) {
|
||||
c.setDeadline(true)
|
||||
n, err = c.Conn.Read(p)
|
||||
c.recordStats(n, true) // true for read operation
|
||||
return
|
||||
}
|
||||
|
||||
func (c *innerConn) Write(p []byte) (n int, err error) {
|
||||
c.setDeadline(false)
|
||||
if time.Since(c.lastActive) > constant.IdleTimeOut {
|
||||
c.rc.l.Debugf("write idle,close remote: %s", c.rc.remote.Label)
|
||||
return 0, io.EOF
|
||||
}
|
||||
n, err = c.Conn.Write(p)
|
||||
c.recordStats(n, false) // false for write operation
|
||||
if err != nil {
|
||||
c.lastActive = time.Now()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -109,10 +268,6 @@ func shortHashSHA256(input string) string {
|
||||
return hex.EncodeToString(hash)[:7]
|
||||
}
|
||||
|
||||
func connectionName(conn net.Conn) string {
|
||||
return fmt.Sprintf("l:<%s> r:<%s>", conn.LocalAddr(), conn.RemoteAddr())
|
||||
}
|
||||
|
||||
func copyConn(conn1, conn2 *innerConn) error {
|
||||
buf := buffer.BufferPool.Get()
|
||||
defer buffer.BufferPool.Put(buf)
|
||||
@@ -125,134 +280,13 @@ func copyConn(conn1, conn2 *innerConn) error {
|
||||
}()
|
||||
|
||||
// reverse copy conn2 to conn1,read from conn2 and write to conn1
|
||||
_, err := io.Copy(conn1, conn2)
|
||||
buf2 := buffer.BufferPool.Get()
|
||||
defer buffer.BufferPool.Put(buf2)
|
||||
_, err := io.CopyBuffer(conn1, conn2, buf2)
|
||||
_ = conn1.CloseWrite()
|
||||
|
||||
err2 := <-errCH
|
||||
_ = conn1.CloseRead()
|
||||
_ = conn2.CloseRead()
|
||||
|
||||
// handle errors, need to combine errors from both directions
|
||||
if err != nil && err2 != nil {
|
||||
err = fmt.Errorf("transport errors in both directions: %v, %v", err, err2)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return err2
|
||||
}
|
||||
|
||||
type RelayConnOption func(*relayConnImpl)
|
||||
|
||||
type RelayConn interface {
|
||||
// Transport transports data between the client and the remote server.
|
||||
// The remoteLabel is the label of the remote server.
|
||||
Transport(remoteLabel string) error
|
||||
|
||||
// GetRelayLabel returns the label of the Relay instance.
|
||||
GetRelayLabel() string
|
||||
|
||||
GetStats() *Stats
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
func NewRelayConn(relayName string, clientConn, remoteConn net.Conn, opts ...RelayConnOption) RelayConn {
|
||||
rci := &relayConnImpl{
|
||||
RelayLabel: relayName,
|
||||
clientConn: clientConn,
|
||||
remoteConn: remoteConn,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(rci)
|
||||
}
|
||||
s := &Stats{Up: 0, Down: 0, HandShakeLatency: rci.HandshakeDuration}
|
||||
rci.Stats = s
|
||||
return rci
|
||||
}
|
||||
|
||||
type relayConnImpl struct {
|
||||
RelayLabel string `json:"relay_label"`
|
||||
Closed bool `json:"closed"`
|
||||
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time,omitempty"`
|
||||
|
||||
Stats *Stats `json:"stats"`
|
||||
HandshakeDuration time.Duration
|
||||
|
||||
clientConn net.Conn
|
||||
remoteConn net.Conn
|
||||
}
|
||||
|
||||
func WithHandshakeDuration(duration time.Duration) RelayConnOption {
|
||||
return func(rci *relayConnImpl) {
|
||||
rci.HandshakeDuration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Transport(remoteLabel string) error {
|
||||
defer rc.Close() // nolint: errcheck
|
||||
name := rc.Name()
|
||||
shortName := fmt.Sprintf("%s-%s", rc.RelayLabel, shortHashSHA256(name))
|
||||
cl := zap.L().Named(shortName)
|
||||
cl.Debug("transport start", zap.String("full name", name), zap.String("stats", rc.Stats.String()))
|
||||
c1 := &innerConn{
|
||||
stats: rc.Stats,
|
||||
remoteLabel: remoteLabel,
|
||||
Conn: rc.clientConn,
|
||||
}
|
||||
c2 := &innerConn{
|
||||
stats: rc.Stats,
|
||||
remoteLabel: remoteLabel,
|
||||
Conn: rc.remoteConn,
|
||||
}
|
||||
rc.StartTime = time.Now().Local()
|
||||
err := copyConn(c1, c2)
|
||||
if err != nil {
|
||||
cl.Error("transport error", zap.Error(err))
|
||||
}
|
||||
cl.Debug("transport end", zap.String("stats", rc.Stats.String()))
|
||||
rc.EndTime = time.Now().Local()
|
||||
return err
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetTime() string {
|
||||
if rc.EndTime.IsZero() {
|
||||
return fmt.Sprintf("%s - N/A", rc.StartTime.Format(time.Stamp))
|
||||
}
|
||||
return fmt.Sprintf("%s - %s", rc.StartTime.Format(time.Stamp), rc.EndTime.Format(time.Stamp))
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Name() string {
|
||||
return fmt.Sprintf("c1:[%s] c2:[%s]", connectionName(rc.clientConn), connectionName(rc.remoteConn))
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Flow() string {
|
||||
return fmt.Sprintf("%s <-> %s", rc.clientConn.RemoteAddr(), rc.remoteConn.RemoteAddr())
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetRelayLabel() string {
|
||||
return rc.RelayLabel
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) GetStats() *Stats {
|
||||
return rc.Stats
|
||||
}
|
||||
|
||||
func (rc *relayConnImpl) Close() error {
|
||||
err1 := rc.clientConn.Close()
|
||||
err2 := rc.remoteConn.Close()
|
||||
rc.Closed = true
|
||||
return combineErrors(err1, err2)
|
||||
}
|
||||
|
||||
func combineErrors(err1, err2 error) error {
|
||||
if err1 != nil && err2 != nil {
|
||||
return fmt.Errorf("combineErrors: %v, %v", err1, err2)
|
||||
}
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
return combineErrorsAndMuteEOF(err, err2)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
@@ -18,11 +19,9 @@ func TestInnerConn_ReadWrite(t *testing.T) {
|
||||
serverConn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
defer clientConn.Close()
|
||||
defer serverConn.Close()
|
||||
|
||||
innerC := &innerConn{Conn: clientConn, stats: &Stats{}, remoteLabel: "test"}
|
||||
|
||||
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
|
||||
innerC := newInnerConn(clientConn, &rc)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
_, err := innerC.Write(testData)
|
||||
errChan <- err
|
||||
@@ -39,7 +38,7 @@ func TestInnerConn_ReadWrite(t *testing.T) {
|
||||
if err := <-errChan; err != nil {
|
||||
t.Fatalf("write err: %v", err)
|
||||
}
|
||||
assert.Equal(t, int64(len(testData)), innerC.stats.Up)
|
||||
assert.Equal(t, int64(len(testData)), rc.Stats.Up)
|
||||
|
||||
errChan = make(chan error, 1)
|
||||
clientConn.SetDeadline(time.Now().Add(1 * time.Second))
|
||||
@@ -64,7 +63,7 @@ func TestInnerConn_ReadWrite(t *testing.T) {
|
||||
if err := <-errChan; err != nil {
|
||||
t.Fatalf("write error: %v", err)
|
||||
}
|
||||
assert.Equal(t, int64(len(testData)), innerC.stats.Down)
|
||||
assert.Equal(t, int64(len(testData)), rc.Stats.Down)
|
||||
}
|
||||
|
||||
func TestCopyTCPConn(t *testing.T) {
|
||||
@@ -96,8 +95,9 @@ func TestCopyTCPConn(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
defer remoteConn.Close()
|
||||
|
||||
c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}}
|
||||
c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}}
|
||||
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
|
||||
c1 := newInnerConn(clientConn, &rc)
|
||||
c2 := newInnerConn(remoteConn, &rc)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
@@ -155,8 +155,9 @@ func TestCopyUDPConn(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
defer remoteConn.Close()
|
||||
|
||||
c1 := &innerConn{Conn: clientConn, remoteLabel: "client", stats: &Stats{}}
|
||||
c2 := &innerConn{Conn: remoteConn, remoteLabel: "server", stats: &Stats{}}
|
||||
rc := relayConnImpl{Stats: &Stats{}, remote: &lb.Node{Label: "client"}}
|
||||
c1 := newInnerConn(clientConn, &rc)
|
||||
c2 := newInnerConn(remoteConn, &rc)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
|
||||
187
echo/internal/conn/udp_listener.go
Normal file
187
echo/internal/conn/udp_listener.go
Normal file
@@ -0,0 +1,187 @@
|
||||
//nolint:errcheck
|
||||
package conn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/Ehco1996/ehco/pkg/buffer"
|
||||
)
|
||||
|
||||
var _ net.Conn = &uc{}
|
||||
|
||||
type uc struct {
|
||||
conn *net.UDPConn
|
||||
addr *net.UDPAddr
|
||||
|
||||
msgCh chan []byte
|
||||
|
||||
lastActivity atomic.Value
|
||||
|
||||
listener *UDPListener
|
||||
}
|
||||
|
||||
func (c *uc) Read(b []byte) (int, error) {
|
||||
select {
|
||||
case msg := <-c.msgCh:
|
||||
n := copy(b, msg)
|
||||
c.lastActivity.Store(time.Now())
|
||||
return n, nil
|
||||
default:
|
||||
if time.Since(c.lastActivity.Load().(time.Time)) > constant.IdleTimeOut {
|
||||
return 0, io.EOF
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *uc) Write(b []byte) (int, error) {
|
||||
n, err := c.conn.WriteToUDP(b, c.addr)
|
||||
c.lastActivity.Store(time.Now())
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *uc) Close() error {
|
||||
c.listener.connsMu.Lock()
|
||||
delete(c.listener.conns, c.addr.String())
|
||||
c.listener.connsMu.Unlock()
|
||||
close(c.msgCh)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *uc) LocalAddr() net.Addr {
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (c *uc) RemoteAddr() net.Addr {
|
||||
return c.addr
|
||||
}
|
||||
|
||||
func (c *uc) SetDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *uc) SetReadDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *uc) SetWriteDeadline(t time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type UDPListener struct {
|
||||
listenAddr *net.UDPAddr
|
||||
listenConn *net.UDPConn
|
||||
|
||||
conns map[string]*uc
|
||||
connsMu sync.RWMutex
|
||||
connCh chan *uc
|
||||
msgCh chan []byte
|
||||
errCh chan error
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
func NewUDPListener(ctx context.Context, addr string) (*UDPListener, error) {
|
||||
udpAddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
l := &UDPListener{
|
||||
listenConn: conn,
|
||||
listenAddr: udpAddr,
|
||||
|
||||
conns: make(map[string]*uc),
|
||||
connCh: make(chan *uc),
|
||||
msgCh: make(chan []byte),
|
||||
errCh: make(chan error),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
go l.listen()
|
||||
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (l *UDPListener) listen() {
|
||||
defer l.listenConn.Close()
|
||||
for {
|
||||
if l.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
buf := buffer.UDPBufferPool.Get()
|
||||
n, addr, err := l.listenConn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
if !l.closed.Load() {
|
||||
select {
|
||||
case l.errCh <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
buffer.UDPBufferPool.Put(buf)
|
||||
continue
|
||||
}
|
||||
|
||||
l.connsMu.RLock()
|
||||
udpConn, exists := l.conns[addr.String()]
|
||||
l.connsMu.RUnlock()
|
||||
if !exists {
|
||||
l.connsMu.Lock()
|
||||
udpConn = &uc{
|
||||
conn: l.listenConn,
|
||||
addr: addr,
|
||||
listener: l,
|
||||
msgCh: make(chan []byte, 10),
|
||||
lastActivity: atomic.Value{},
|
||||
}
|
||||
udpConn.lastActivity.Store(time.Now())
|
||||
l.conns[addr.String()] = udpConn
|
||||
l.connCh <- udpConn
|
||||
l.connsMu.Unlock()
|
||||
}
|
||||
|
||||
select {
|
||||
case udpConn.msgCh <- buf[:n]:
|
||||
default:
|
||||
buffer.UDPBufferPool.Put(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UDPListener) Accept() (*uc, error) {
|
||||
select {
|
||||
case conn := <-l.connCh:
|
||||
return conn, nil
|
||||
case err := <-l.errCh:
|
||||
return nil, err
|
||||
case <-l.ctx.Done():
|
||||
return nil, l.ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (l *UDPListener) Close() error {
|
||||
if !l.closed.CompareAndSwap(false, true) {
|
||||
return nil
|
||||
}
|
||||
l.cancel()
|
||||
l.closed.Store(true)
|
||||
return l.listenConn.Close()
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// wsConn represents a WebSocket connection to relay(io.Copy)
|
||||
type wsConn struct {
|
||||
conn net.Conn
|
||||
isServer bool
|
||||
@@ -29,7 +30,7 @@ func (c *wsConn) Read(b []byte) (n int, err error) {
|
||||
}
|
||||
if header.Length > int64(cap(c.buf)) {
|
||||
zap.S().Warnf("ws payload size:%d is larger than buffer size:%d", header.Length, cap(c.buf))
|
||||
c.buf = make([]byte, header.Length)
|
||||
return 0, fmt.Errorf("buffer size:%d too small to transport ws payload size:%d", len(b), header.Length)
|
||||
}
|
||||
payload := c.buf[:header.Length]
|
||||
_, err = io.ReadFull(c.conn, payload)
|
||||
|
||||
@@ -6,7 +6,9 @@ type RelayType string
|
||||
|
||||
var (
|
||||
// allow change in test
|
||||
IdleTimeOut = 10 * time.Second
|
||||
// TODO Set to Relay Config
|
||||
ReadTimeOut = 5 * time.Second
|
||||
IdleTimeOut = 30 * time.Second
|
||||
|
||||
Version = "1.1.5-dev"
|
||||
GitBranch string
|
||||
@@ -20,24 +22,17 @@ const (
|
||||
|
||||
SniffTimeOut = 300 * time.Millisecond
|
||||
|
||||
SmuxGCDuration = 30 * time.Second
|
||||
SmuxMaxAliveDuration = 10 * time.Minute
|
||||
SmuxMaxStreamCnt = 5
|
||||
|
||||
// todo add udp buffer size
|
||||
// todo,support config in relay config
|
||||
BUFFER_POOL_SIZE = 1024 // support 512 connections
|
||||
BUFFER_SIZE = 20 * 1024 // 20KB the maximum packet size of shadowsocks is about 16 KiB
|
||||
BUFFER_SIZE = 40 * 1024 // 40KB ,the maximum packet size of shadowsocks is about 16 KiB so this is enough
|
||||
UDPBufSize = 1500 // use default max mtu 1500
|
||||
)
|
||||
|
||||
// relay type
|
||||
const (
|
||||
// tcp relay
|
||||
RelayTypeRaw RelayType = "raw"
|
||||
RelayTypeMTCP RelayType = "mtcp"
|
||||
|
||||
// direct relay
|
||||
RelayTypeRaw RelayType = "raw"
|
||||
// ws relay
|
||||
RelayTypeWS RelayType = "ws"
|
||||
RelayTypeMWS RelayType = "mws"
|
||||
RelayTypeWSS RelayType = "wss"
|
||||
RelayTypeMWSS RelayType = "mwss"
|
||||
RelayTypeWS RelayType = "ws"
|
||||
RelayTypeWSS RelayType = "wss"
|
||||
)
|
||||
|
||||
@@ -23,24 +23,51 @@ type WSConfig struct {
|
||||
RemoteAddr string `json:"remote_addr,omitempty"`
|
||||
}
|
||||
|
||||
func (w *WSConfig) Clone() *WSConfig {
|
||||
return &WSConfig{
|
||||
Path: w.Path,
|
||||
RemoteAddr: w.RemoteAddr,
|
||||
}
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
WSConfig *WSConfig `json:"ws_config,omitempty"`
|
||||
EnableUDP bool `json:"enable_udp,omitempty"`
|
||||
EnableMultipathTCP bool `json:"enable_multipath_tcp,omitempty"`
|
||||
|
||||
MaxConnection int `json:"max_connection,omitempty"`
|
||||
BlockedProtocols []string `json:"blocked_protocols,omitempty"`
|
||||
MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"`
|
||||
}
|
||||
|
||||
func (o *Options) Clone() *Options {
|
||||
opt := &Options{
|
||||
EnableUDP: o.EnableUDP,
|
||||
EnableMultipathTCP: o.EnableMultipathTCP,
|
||||
MaxConnection: o.MaxConnection,
|
||||
MaxReadRateKbps: o.MaxReadRateKbps,
|
||||
BlockedProtocols: make([]string, len(o.BlockedProtocols)),
|
||||
}
|
||||
copy(opt.BlockedProtocols, o.BlockedProtocols)
|
||||
if o.WSConfig != nil {
|
||||
opt.WSConfig = o.WSConfig.Clone()
|
||||
}
|
||||
return opt
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Label string `json:"label,omitempty"`
|
||||
Listen string `json:"listen"`
|
||||
ListenType constant.RelayType `json:"listen_type"`
|
||||
TransportType constant.RelayType `json:"transport_type"`
|
||||
TCPRemotes []string `json:"tcp_remotes"`
|
||||
UDPRemotes []string `json:"udp_remotes"`
|
||||
TCPRemotes []string `json:"tcp_remotes"` // TODO rename to remotes
|
||||
|
||||
MaxConnection int `json:"max_connection,omitempty"`
|
||||
BlockedProtocols []string `json:"blocked_protocols,omitempty"`
|
||||
MaxReadRateKbps int64 `json:"max_read_rate_kbps,omitempty"`
|
||||
|
||||
WSConfig *WSConfig `json:"ws_config,omitempty"`
|
||||
Options *Options `json:"options,omitempty"`
|
||||
}
|
||||
|
||||
func (r *Config) GetWSHandShakePath() string {
|
||||
if r.WSConfig != nil && r.WSConfig.Path != "" {
|
||||
return r.WSConfig.Path
|
||||
if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.Path != "" {
|
||||
return r.Options.WSConfig.Path
|
||||
}
|
||||
return WS_HANDSHAKE_PATH
|
||||
}
|
||||
@@ -50,8 +77,8 @@ func (r *Config) GetWSRemoteAddr(baseAddr string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if r.WSConfig != nil && r.WSConfig.RemoteAddr != "" {
|
||||
addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.WSConfig.RemoteAddr)
|
||||
if r.Options != nil && r.Options.WSConfig != nil && r.Options.WSConfig.RemoteAddr != "" {
|
||||
addr += fmt.Sprintf("?%s=%s", WS_QUERY_REMOTE_ADDR, r.Options.WSConfig.RemoteAddr)
|
||||
}
|
||||
return addr, nil
|
||||
}
|
||||
@@ -79,17 +106,7 @@ func (r *Config) Validate() error {
|
||||
}
|
||||
}
|
||||
|
||||
for _, addr := range r.UDPRemotes {
|
||||
if addr == "" {
|
||||
return fmt.Errorf("invalid udp remote addr:%s", addr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.TCPRemotes) == 0 && len(r.UDPRemotes) == 0 {
|
||||
return errors.New("both tcp and udp remotes are empty")
|
||||
}
|
||||
|
||||
for _, protocol := range r.BlockedProtocols {
|
||||
for _, protocol := range r.Options.BlockedProtocols {
|
||||
if protocol != ProtocolHTTP && protocol != ProtocolTLS {
|
||||
return fmt.Errorf("invalid blocked protocol:%s", protocol)
|
||||
}
|
||||
@@ -103,11 +120,10 @@ func (r *Config) Clone() *Config {
|
||||
ListenType: r.ListenType,
|
||||
TransportType: r.TransportType,
|
||||
Label: r.Label,
|
||||
Options: r.Options.Clone(),
|
||||
}
|
||||
new.TCPRemotes = make([]string, len(r.TCPRemotes))
|
||||
copy(new.TCPRemotes, r.TCPRemotes)
|
||||
new.UDPRemotes = make([]string, len(r.UDPRemotes))
|
||||
copy(new.UDPRemotes, r.UDPRemotes)
|
||||
return new
|
||||
}
|
||||
|
||||
@@ -121,26 +137,17 @@ func (r *Config) Different(new *Config) bool {
|
||||
if len(r.TCPRemotes) != len(new.TCPRemotes) {
|
||||
return true
|
||||
}
|
||||
|
||||
for i, addr := range r.TCPRemotes {
|
||||
if addr != new.TCPRemotes[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if len(r.UDPRemotes) != len(new.UDPRemotes) {
|
||||
return true
|
||||
}
|
||||
for i, addr := range r.UDPRemotes {
|
||||
if addr != new.UDPRemotes[i] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// todo make this shorter and more readable
|
||||
func (r *Config) DefaultLabel() string {
|
||||
defaultLabel := fmt.Sprintf("<At=%s TCP-To=%s TP=%s>",
|
||||
defaultLabel := fmt.Sprintf("<At=%s To=%s TP=%s>",
|
||||
r.Listen, r.TCPRemotes, r.TransportType)
|
||||
return defaultLabel
|
||||
}
|
||||
@@ -150,6 +157,12 @@ func (r *Config) Adjust() error {
|
||||
r.Label = r.DefaultLabel()
|
||||
zap.S().Debugf("label is empty, set default label:%s", r.Label)
|
||||
}
|
||||
if r.Options == nil {
|
||||
r.Options = &Options{
|
||||
WSConfig: &WSConfig{},
|
||||
EnableMultipathTCP: true, // default enable multipath tcp
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -171,20 +184,14 @@ func (r *Config) GetLoggerName() string {
|
||||
func (r *Config) validateType() error {
|
||||
if r.ListenType != constant.RelayTypeRaw &&
|
||||
r.ListenType != constant.RelayTypeWS &&
|
||||
r.ListenType != constant.RelayTypeMWS &&
|
||||
r.ListenType != constant.RelayTypeWSS &&
|
||||
r.ListenType != constant.RelayTypeMTCP &&
|
||||
r.ListenType != constant.RelayTypeMWSS {
|
||||
r.ListenType != constant.RelayTypeWSS {
|
||||
return fmt.Errorf("invalid listen type:%s", r.ListenType)
|
||||
}
|
||||
|
||||
if r.TransportType != constant.RelayTypeRaw &&
|
||||
r.TransportType != constant.RelayTypeWS &&
|
||||
r.TransportType != constant.RelayTypeMWS &&
|
||||
r.TransportType != constant.RelayTypeWSS &&
|
||||
r.TransportType != constant.RelayTypeMTCP &&
|
||||
r.TransportType != constant.RelayTypeMWSS {
|
||||
return fmt.Errorf("invalid transport type:%s", r.ListenType)
|
||||
r.TransportType != constant.RelayTypeWSS {
|
||||
return fmt.Errorf("invalid transport type:%s", r.TransportType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/cmgr"
|
||||
@@ -33,17 +35,17 @@ func NewRelay(cfg *conf.Config, cmgr cmgr.Cmgr) (*Relay, error) {
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (r *Relay) ListenAndServe() error {
|
||||
func (r *Relay) ListenAndServe(ctx context.Context) error {
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
r.l.Infof("Start TCP Relay Server:%s", r.cfg.DefaultLabel())
|
||||
errCh <- r.relayServer.ListenAndServe()
|
||||
r.l.Infof("Start Relay Server(%s):%s", r.cfg.ListenType, r.cfg.DefaultLabel())
|
||||
errCh <- r.relayServer.ListenAndServe(ctx)
|
||||
}()
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (r *Relay) Close() {
|
||||
r.l.Infof("Close TCP Relay Server:%s", r.cfg.DefaultLabel())
|
||||
r.l.Infof("Close Relay Server:%s", r.cfg.DefaultLabel())
|
||||
if err := r.relayServer.Close(); err != nil {
|
||||
r.l.Errorf(err.Error())
|
||||
}
|
||||
|
||||
@@ -46,10 +46,10 @@ func NewServer(cfg *config.Config) (*Server, error) {
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Server) startOneRelay(r *Relay) {
|
||||
func (s *Server) startOneRelay(ctx context.Context, r *Relay) {
|
||||
s.relayM.Store(r.UniqueID(), r)
|
||||
// mute closed network error for tcp server and mute http.ErrServerClosed for http server when config reload
|
||||
if err := r.ListenAndServe(); err != nil &&
|
||||
if err := r.ListenAndServe(ctx); err != nil &&
|
||||
!errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
s.l.Errorf("start relay %s meet error: %s", r.UniqueID(), err)
|
||||
s.errCH <- err
|
||||
@@ -68,7 +68,7 @@ func (s *Server) Start(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go s.startOneRelay(r)
|
||||
go s.startOneRelay(ctx, r)
|
||||
}
|
||||
|
||||
if s.cfg.PATH != "" && (s.cfg.ReloadInterval > 0 || len(s.cfg.SubConfigs) > 0) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/glue"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
"go.uber.org/zap"
|
||||
@@ -48,7 +50,7 @@ func (s *Server) Reload(force bool) error {
|
||||
s.l.Error("new relay meet error", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
go s.startOneRelay(r)
|
||||
go s.startOneRelay(context.TODO(), r)
|
||||
} else {
|
||||
// when label not change, check if config changed
|
||||
oldCfg, ok := oldRelayCfgM[newCfg.Label]
|
||||
@@ -66,7 +68,7 @@ func (s *Server) Reload(force bool) error {
|
||||
s.l.Error("new relay meet error", zap.Error(err))
|
||||
continue
|
||||
}
|
||||
go s.startOneRelay(r)
|
||||
go s.startOneRelay(context.TODO(), r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,97 +18,160 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type baseTransporter struct {
|
||||
cfg *conf.Config
|
||||
l *zap.SugaredLogger
|
||||
var _ RelayServer = &BaseRelayServer{}
|
||||
|
||||
cmgr cmgr.Cmgr
|
||||
tCPRemotes lb.RoundRobin
|
||||
relayer RelayClient
|
||||
type BaseRelayServer struct {
|
||||
cmgr cmgr.Cmgr
|
||||
cfg *conf.Config
|
||||
l *zap.SugaredLogger
|
||||
|
||||
remotes lb.RoundRobin
|
||||
relayer RelayClient
|
||||
}
|
||||
|
||||
func NewBaseTransporter(cfg *conf.Config, cmgr cmgr.Cmgr) (*baseTransporter, error) {
|
||||
func newBaseRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (*BaseRelayServer, error) {
|
||||
relayer, err := newRelayClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &baseTransporter{
|
||||
cfg: cfg,
|
||||
cmgr: cmgr,
|
||||
tCPRemotes: cfg.ToTCPRemotes(),
|
||||
l: zap.S().Named(cfg.GetLoggerName()),
|
||||
relayer: relayer,
|
||||
return &BaseRelayServer{
|
||||
relayer: relayer,
|
||||
cfg: cfg,
|
||||
cmgr: cmgr,
|
||||
remotes: cfg.ToTCPRemotes(),
|
||||
l: zap.S().Named(cfg.GetLoggerName()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *baseTransporter) GetTCPListenAddr() (*net.TCPAddr, error) {
|
||||
return net.ResolveTCPAddr("tcp", b.cfg.Listen)
|
||||
}
|
||||
|
||||
func (b *baseTransporter) GetRemote() *lb.Node {
|
||||
return b.tCPRemotes.Next()
|
||||
}
|
||||
|
||||
func (b *baseTransporter) RelayTCPConn(c net.Conn, handshakeF TCPHandShakeF) error {
|
||||
remote := b.GetRemote()
|
||||
func (b *BaseRelayServer) RelayTCPConn(ctx context.Context, c net.Conn) error {
|
||||
remote := b.remotes.Next().Clone()
|
||||
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Inc()
|
||||
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_TCP).Dec()
|
||||
|
||||
// check limit
|
||||
if b.cfg.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.MaxConnection {
|
||||
c.Close()
|
||||
return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.MaxConnection)
|
||||
if err := b.checkConnectionLimit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sniff protocol
|
||||
if len(b.cfg.BlockedProtocols) > 0 {
|
||||
buffer := buf.NewPacket()
|
||||
ctx := context.TODO()
|
||||
sniffMetadata, err := sniff.PeekStream(
|
||||
ctx, c, buffer, constant.SniffTimeOut,
|
||||
sniff.TLSClientHello, sniff.HTTPHost)
|
||||
if err != nil {
|
||||
// this mean no protocol sniffed
|
||||
b.l.Debug("sniff error: %s", err)
|
||||
}
|
||||
if sniffMetadata != nil {
|
||||
b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol)
|
||||
for _, p := range b.cfg.BlockedProtocols {
|
||||
if sniffMetadata.Protocol == p {
|
||||
c.Close()
|
||||
return fmt.Errorf("relay:%s want to relay blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !buffer.IsEmpty() {
|
||||
c = bufio.NewCachedConn(c, buffer)
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
}
|
||||
|
||||
// rate limit
|
||||
if b.cfg.MaxReadRateKbps > 0 {
|
||||
c = conn.NewRateLimitedConn(c, b.cfg.MaxReadRateKbps)
|
||||
}
|
||||
|
||||
clonedRemote := remote.Clone()
|
||||
rc, err := handshakeF(clonedRemote)
|
||||
var err error
|
||||
c, err = b.sniffAndBlockProtocol(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c = b.applyRateLimit(c)
|
||||
|
||||
rc, err := b.relayer.HandShake(ctx, remote, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handshake error: %w", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
b.l.Infof("RelayTCPConn from %s to %s", c.LocalAddr(), remote.Address)
|
||||
relayConn := conn.NewRelayConn(
|
||||
b.cfg.Label, c, rc, conn.WithHandshakeDuration(clonedRemote.HandShakeDuration))
|
||||
b.cmgr.AddConnection(relayConn)
|
||||
defer b.cmgr.RemoveConnection(relayConn)
|
||||
return relayConn.Transport(remote.Label)
|
||||
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_TCP)
|
||||
}
|
||||
|
||||
func (b *baseTransporter) HealthCheck(ctx context.Context) (int64, error) {
|
||||
remote := b.GetRemote().Clone()
|
||||
err := b.relayer.HealthCheck(ctx, remote)
|
||||
func (b *BaseRelayServer) RelayUDPConn(ctx context.Context, c net.Conn) error {
|
||||
remote := b.remotes.Next().Clone()
|
||||
metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Inc()
|
||||
defer metrics.CurConnectionCount.WithLabelValues(remote.Label, metrics.METRIC_CONN_TYPE_UDP).Dec()
|
||||
|
||||
rc, err := b.relayer.HandShake(ctx, remote, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("handshake error: %w", err)
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
b.l.Infof("RelayUDPConn from %s to %s", c.LocalAddr(), remote.Address)
|
||||
return b.handleRelayConn(c, rc, remote, metrics.METRIC_CONN_TYPE_UDP)
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) checkConnectionLimit() error {
|
||||
if b.cfg.Options.MaxConnection > 0 && b.cmgr.CountConnection(cmgr.ConnectionTypeActive) >= b.cfg.Options.MaxConnection {
|
||||
return fmt.Errorf("relay:%s active connection count exceed limit %d", b.cfg.Label, b.cfg.Options.MaxConnection)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) sniffAndBlockProtocol(c net.Conn) (net.Conn, error) {
|
||||
if len(b.cfg.Options.BlockedProtocols) == 0 {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
buffer := buf.NewPacket()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), constant.SniffTimeOut)
|
||||
defer cancel()
|
||||
|
||||
sniffMetadata, err := sniff.PeekStream(ctx, c, buffer, constant.SniffTimeOut, sniff.TLSClientHello, sniff.HTTPHost)
|
||||
if err != nil {
|
||||
b.l.Debugf("sniff error: %s", err)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
if sniffMetadata != nil {
|
||||
b.l.Infof("sniffed protocol: %s", sniffMetadata.Protocol)
|
||||
for _, p := range b.cfg.Options.BlockedProtocols {
|
||||
if sniffMetadata.Protocol == p {
|
||||
return c, fmt.Errorf("relay:%s blocked protocol:%s", b.cfg.Label, sniffMetadata.Protocol)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !buffer.IsEmpty() {
|
||||
return bufio.NewCachedConn(c, buffer), nil
|
||||
} else {
|
||||
buffer.Release()
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) applyRateLimit(c net.Conn) net.Conn {
|
||||
if b.cfg.Options.MaxReadRateKbps > 0 {
|
||||
return conn.NewRateLimitedConn(c, b.cfg.Options.MaxReadRateKbps)
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) handleRelayConn(c, rc net.Conn, remote *lb.Node, connType string) error {
|
||||
opts := []conn.RelayConnOption{
|
||||
conn.WithLogger(b.l),
|
||||
conn.WithRemote(remote),
|
||||
conn.WithConnType(connType),
|
||||
conn.WithRelayLabel(b.cfg.Label),
|
||||
conn.WithHandshakeDuration(remote.HandShakeDuration),
|
||||
}
|
||||
relayConn := conn.NewRelayConn(c, rc, opts...)
|
||||
b.cmgr.AddConnection(relayConn)
|
||||
defer b.cmgr.RemoveConnection(relayConn)
|
||||
return relayConn.Transport()
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) HealthCheck(ctx context.Context) (int64, error) {
|
||||
remote := b.remotes.Next().Clone()
|
||||
// us tcp handshake to check health
|
||||
_, err := b.relayer.HandShake(ctx, remote, true)
|
||||
return int64(remote.HandShakeDuration.Milliseconds()), err
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) Close() error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (b *BaseRelayServer) ListenAndServe(ctx context.Context) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func NewNetDialer(cfg *conf.Config) *net.Dialer {
|
||||
dialer := &net.Dialer{Timeout: constant.DialTimeOut}
|
||||
dialer.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
|
||||
return dialer
|
||||
}
|
||||
|
||||
func NewTCPListener(ctx context.Context, cfg *conf.Config) (net.Listener, error) {
|
||||
addr, err := net.ResolveTCPAddr("tcp", cfg.Listen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lcfg := net.ListenConfig{}
|
||||
lcfg.SetMultipathTCP(cfg.Options.EnableMultipathTCP)
|
||||
return lcfg.Listen(ctx, "tcp", addr.String())
|
||||
}
|
||||
|
||||
@@ -11,56 +11,45 @@ import (
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
)
|
||||
|
||||
type TCPHandShakeF func(remote *lb.Node) (net.Conn, error)
|
||||
|
||||
// TODO opt this interface
|
||||
type RelayClient interface {
|
||||
HealthCheck(ctx context.Context, remote *lb.Node) error
|
||||
TCPHandShake(remote *lb.Node) (net.Conn, error)
|
||||
HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error)
|
||||
}
|
||||
|
||||
func newRelayClient(cfg *conf.Config) (RelayClient, error) {
|
||||
switch cfg.TransportType {
|
||||
case constant.RelayTypeRaw:
|
||||
return newRawClient(cfg)
|
||||
case constant.RelayTypeMTCP:
|
||||
return newMtcpClient(cfg)
|
||||
case constant.RelayTypeWS:
|
||||
return newWsClient(cfg)
|
||||
case constant.RelayTypeMWS:
|
||||
return newMwsClient(cfg)
|
||||
case constant.RelayTypeWSS:
|
||||
return newWssClient(cfg)
|
||||
case constant.RelayTypeMWSS:
|
||||
return newMwssClient(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported transport type: %s", cfg.TransportType)
|
||||
}
|
||||
}
|
||||
|
||||
type RelayServer interface {
|
||||
ListenAndServe() error
|
||||
ListenAndServe(ctx context.Context) error
|
||||
Close() error
|
||||
|
||||
RelayTCPConn(ctx context.Context, c net.Conn) error
|
||||
RelayUDPConn(ctx context.Context, c net.Conn) error
|
||||
HealthCheck(ctx context.Context) (int64, error) // latency in ms
|
||||
}
|
||||
|
||||
func NewRelayServer(cfg *conf.Config, cmgr cmgr.Cmgr) (RelayServer, error) {
|
||||
base, err := NewBaseTransporter(cfg, cmgr)
|
||||
base, err := newBaseRelayServer(cfg, cmgr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch cfg.ListenType {
|
||||
case constant.RelayTypeRaw:
|
||||
return newRawServer(base)
|
||||
case constant.RelayTypeMTCP:
|
||||
return newMtcpServer(base)
|
||||
case constant.RelayTypeWS:
|
||||
return newWsServer(base)
|
||||
case constant.RelayTypeMWS:
|
||||
return newMwsServer(base)
|
||||
case constant.RelayTypeWSS:
|
||||
return newWssServer(base)
|
||||
case constant.RelayTypeMWSS:
|
||||
return newMwssServer(base)
|
||||
default:
|
||||
panic("unsupported transport type" + cfg.ListenType)
|
||||
}
|
||||
|
||||
@@ -1,201 +0,0 @@
|
||||
// nolint: errcheck
|
||||
package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/xtaci/smux"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type smuxTransporter struct {
|
||||
sessionMutex sync.Mutex
|
||||
|
||||
gcTicker *time.Ticker
|
||||
l *zap.SugaredLogger
|
||||
|
||||
// remote addr -> SessionWithMetrics
|
||||
sessionM map[string][]*SessionWithMetrics
|
||||
|
||||
initSessionF func(ctx context.Context, addr string) (*smux.Session, error)
|
||||
}
|
||||
|
||||
type SessionWithMetrics struct {
|
||||
session *smux.Session
|
||||
|
||||
createdTime time.Time
|
||||
streamList []*smux.Stream
|
||||
}
|
||||
|
||||
func (sm *SessionWithMetrics) CanNotServeNewStream() bool {
|
||||
return sm.session.IsClosed() ||
|
||||
sm.session.NumStreams() >= constant.SmuxMaxStreamCnt ||
|
||||
time.Since(sm.createdTime) > constant.SmuxMaxAliveDuration
|
||||
}
|
||||
|
||||
func streamDead(s *smux.Stream) bool {
|
||||
select {
|
||||
case _, ok := <-s.GetDieCh():
|
||||
return !ok // 如果接收到值且通道未关闭,则 Stream 未死
|
||||
default:
|
||||
return true // 如果通道已经关闭,则 Stream 死了
|
||||
}
|
||||
}
|
||||
|
||||
func (sm *SessionWithMetrics) canCloseSession(remoteAddr string, l *zap.SugaredLogger) bool {
|
||||
for _, s := range sm.streamList {
|
||||
if !streamDead(s) {
|
||||
return false
|
||||
}
|
||||
l.Debugf("session: %s stream: %d is not dead", remoteAddr, s.ID())
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func NewSmuxTransporter(
|
||||
l *zap.SugaredLogger,
|
||||
initSessionF func(ctx context.Context, addr string) (*smux.Session, error),
|
||||
) *smuxTransporter {
|
||||
tr := &smuxTransporter{
|
||||
l: l,
|
||||
initSessionF: initSessionF,
|
||||
sessionM: make(map[string][]*SessionWithMetrics),
|
||||
gcTicker: time.NewTicker(constant.SmuxGCDuration),
|
||||
}
|
||||
// start gc thread for close idle sessions
|
||||
go tr.gc()
|
||||
return tr
|
||||
}
|
||||
|
||||
func (tr *smuxTransporter) gc() {
|
||||
for range tr.gcTicker.C {
|
||||
tr.sessionMutex.Lock()
|
||||
for addr, sl := range tr.sessionM {
|
||||
tr.l.Debugf("start doing gc for remote addr: %s total session count %d", addr, len(sl))
|
||||
for idx := range sl {
|
||||
sm := sl[idx]
|
||||
if sm.CanNotServeNewStream() && sm.canCloseSession(addr, tr.l) {
|
||||
tr.l.Debugf("close idle session:%s stream cnt %d",
|
||||
sm.session.LocalAddr().String(), sm.session.NumStreams())
|
||||
sm.session.Close()
|
||||
}
|
||||
}
|
||||
newList := []*SessionWithMetrics{}
|
||||
for _, s := range sl {
|
||||
if !s.session.IsClosed() {
|
||||
newList = append(newList, s)
|
||||
}
|
||||
}
|
||||
tr.sessionM[addr] = newList
|
||||
tr.l.Debugf("finish gc for remote addr: %s total session count %d", addr, len(sl))
|
||||
}
|
||||
tr.sessionMutex.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (tr *smuxTransporter) Dial(ctx context.Context, addr string) (conn net.Conn, err error) {
|
||||
tr.sessionMutex.Lock()
|
||||
defer tr.sessionMutex.Unlock()
|
||||
var session *smux.Session
|
||||
var curSM *SessionWithMetrics
|
||||
|
||||
sessionList := tr.sessionM[addr]
|
||||
for _, sm := range sessionList {
|
||||
if sm.CanNotServeNewStream() {
|
||||
continue
|
||||
} else {
|
||||
tr.l.Debugf("use session: %s total stream count: %d remote addr: %s",
|
||||
sm.session.LocalAddr().String(), sm.session.NumStreams(), addr)
|
||||
session = sm.session
|
||||
curSM = sm
|
||||
break
|
||||
}
|
||||
}
|
||||
// create new one
|
||||
if session == nil {
|
||||
session, err = tr.initSessionF(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sm := &SessionWithMetrics{session: session, createdTime: time.Now(), streamList: []*smux.Stream{}}
|
||||
sessionList = append(sessionList, sm)
|
||||
tr.sessionM[addr] = sessionList
|
||||
curSM = sm
|
||||
}
|
||||
|
||||
stream, err := session.OpenStream()
|
||||
if err != nil {
|
||||
tr.l.Errorf("open stream meet error:%s", err)
|
||||
session.Close()
|
||||
return nil, err
|
||||
}
|
||||
curSM.streamList = append(curSM.streamList, stream)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
type muxServer interface {
|
||||
ListenAndServe() error
|
||||
Accept() (net.Conn, error)
|
||||
Close() error
|
||||
mux(net.Conn)
|
||||
}
|
||||
|
||||
func newMuxServer(listenAddr string, l *zap.SugaredLogger) *muxServerImpl {
|
||||
return &muxServerImpl{
|
||||
errChan: make(chan error, 1),
|
||||
connChan: make(chan net.Conn, 1024),
|
||||
listenAddr: listenAddr,
|
||||
l: l,
|
||||
}
|
||||
}
|
||||
|
||||
type muxServerImpl struct {
|
||||
errChan chan error
|
||||
connChan chan net.Conn
|
||||
|
||||
listenAddr string
|
||||
l *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func (s *muxServerImpl) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-s.connChan:
|
||||
return conn, nil
|
||||
case err := <-s.errChan:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *muxServerImpl) mux(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.KeepAliveDisabled = true
|
||||
session, err := smux.Server(conn, cfg)
|
||||
if err != nil {
|
||||
s.l.Debugf("server err %s - %s : %s", conn.RemoteAddr(), s.listenAddr, err)
|
||||
return
|
||||
}
|
||||
defer session.Close() // nolint: errcheck
|
||||
|
||||
s.l.Debugf("session init %s %s", conn.RemoteAddr(), s.listenAddr)
|
||||
defer s.l.Debugf("session close %s >-< %s", conn.RemoteAddr(), s.listenAddr)
|
||||
|
||||
for {
|
||||
stream, err := session.AcceptStream()
|
||||
if err != nil {
|
||||
s.l.Errorf("accept stream err: %s", err)
|
||||
break
|
||||
}
|
||||
select {
|
||||
case s.connChan <- stream:
|
||||
default:
|
||||
stream.Close() // nolint: errcheck
|
||||
s.l.Infof("%s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr())
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -3,10 +3,11 @@ package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/Ehco1996/ehco/internal/conn"
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
@@ -26,17 +27,22 @@ type RawClient struct {
|
||||
|
||||
func newRawClient(cfg *conf.Config) (*RawClient, error) {
|
||||
r := &RawClient{
|
||||
l: zap.S().Named("raw"),
|
||||
cfg: cfg,
|
||||
dialer: &net.Dialer{Timeout: constant.DialTimeOut},
|
||||
dialer: NewNetDialer(cfg),
|
||||
l: zap.S().Named(string(cfg.TransportType)),
|
||||
}
|
||||
r.dialer.SetMultipathTCP(true)
|
||||
return r, nil
|
||||
}
|
||||
|
||||
func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
func (raw *RawClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) {
|
||||
t1 := time.Now()
|
||||
rc, err := raw.dialer.Dial("tcp", remote.Address)
|
||||
var rc net.Conn
|
||||
var err error
|
||||
if isTCP {
|
||||
rc, err = raw.dialer.DialContext(ctx, "tcp", remote.Address)
|
||||
} else {
|
||||
rc, err = raw.dialer.DialContext(ctx, "udp", remote.Address)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -46,55 +52,72 @@ func (raw *RawClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (raw *RawClient) HealthCheck(ctx context.Context, remote *lb.Node) error {
|
||||
l := zap.S().Named("health-check")
|
||||
l.Infof("start send req to %s", remote.Address)
|
||||
c, err := raw.TCPHandShake(remote)
|
||||
if err != nil {
|
||||
l.Errorf("send req to %s meet error:%s", remote.Address, err)
|
||||
return err
|
||||
}
|
||||
c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type RawServer struct {
|
||||
*baseTransporter
|
||||
lis net.Listener
|
||||
*BaseRelayServer
|
||||
|
||||
tcpLis net.Listener
|
||||
udpLis *conn.UDPListener
|
||||
}
|
||||
|
||||
func newRawServer(base *baseTransporter) (*RawServer, error) {
|
||||
addr, err := base.GetTCPListenAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg := net.ListenConfig{}
|
||||
cfg.SetMultipathTCP(true)
|
||||
lis, err := cfg.Listen(context.TODO(), "tcp", addr.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &RawServer{
|
||||
lis: lis,
|
||||
baseTransporter: base,
|
||||
}, nil
|
||||
func newRawServer(bs *BaseRelayServer) (*RawServer, error) {
|
||||
rs := &RawServer{BaseRelayServer: bs}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func (s *RawServer) Close() error {
|
||||
return s.lis.Close()
|
||||
err := s.tcpLis.Close()
|
||||
if s.udpLis != nil {
|
||||
err2 := s.udpLis.Close()
|
||||
err = errors.Join(err, err2)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *RawServer) ListenAndServe() error {
|
||||
func (s *RawServer) ListenAndServe(ctx context.Context) error {
|
||||
ts, err := NewTCPListener(ctx, s.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.tcpLis = ts
|
||||
|
||||
if s.cfg.Options != nil && s.cfg.Options.EnableUDP {
|
||||
udpLis, err := conn.NewUDPListener(ctx, s.cfg.Listen)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.udpLis = udpLis
|
||||
}
|
||||
|
||||
if s.udpLis != nil {
|
||||
go s.listenUDP(ctx)
|
||||
}
|
||||
for {
|
||||
c, err := s.lis.Accept()
|
||||
c, err := s.tcpLis.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer c.Close()
|
||||
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
|
||||
s.l.Errorf("RelayTCPConn error: %s", err.Error())
|
||||
if err := s.RelayTCPConn(ctx, c); err != nil {
|
||||
s.l.Errorf("RelayTCPConn meet error: %s", err.Error())
|
||||
}
|
||||
}(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RawServer) listenUDP(ctx context.Context) error {
|
||||
s.l.Infof("Start UDP server at %s", s.cfg.Listen)
|
||||
for {
|
||||
c, err := s.udpLis.Accept()
|
||||
if err != nil {
|
||||
s.l.Errorf("UDP accept error: %v", err)
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
if err := s.RelayUDPConn(ctx, c); err != nil {
|
||||
s.l.Errorf("RelayUDPConn meet error: %s", err.Error())
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,109 +0,0 @@
|
||||
package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/xtaci/smux"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
)
|
||||
|
||||
var (
|
||||
_ RelayClient = &MtcpClient{}
|
||||
_ RelayServer = &MtcpServer{}
|
||||
)
|
||||
|
||||
type MtcpClient struct {
|
||||
*RawClient
|
||||
muxTP *smuxTransporter
|
||||
}
|
||||
|
||||
func newMtcpClient(cfg *conf.Config) (*MtcpClient, error) {
|
||||
raw, err := newRawClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &MtcpClient{RawClient: raw}
|
||||
c.muxTP = NewSmuxTransporter(zap.S().Named("mtcp"), c.initNewSession)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *MtcpClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
|
||||
rc, err := c.dialer.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// stream multiplex
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.KeepAliveDisabled = true
|
||||
session, err := smux.Client(rc, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.l.Infof("init new session to: %s", rc.RemoteAddr())
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *MtcpClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
t1 := time.Now()
|
||||
mtcpc, err := s.muxTP.Dial(context.TODO(), remote.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latency := time.Since(t1)
|
||||
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
|
||||
remote.HandShakeDuration = latency
|
||||
return mtcpc, nil
|
||||
}
|
||||
|
||||
type MtcpServer struct {
|
||||
*RawServer
|
||||
*muxServerImpl
|
||||
}
|
||||
|
||||
func newMtcpServer(base *baseTransporter) (*MtcpServer, error) {
|
||||
raw, err := newRawServer(base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &MtcpServer{
|
||||
RawServer: raw,
|
||||
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mtcp")),
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MtcpServer) ListenAndServe() error {
|
||||
go func() {
|
||||
for {
|
||||
c, err := s.lis.Accept()
|
||||
if err != nil {
|
||||
s.errChan <- err
|
||||
continue
|
||||
}
|
||||
go s.mux(c)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, e := s.Accept()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
|
||||
s.l.Errorf("RelayTCPConn error: %s", err.Error())
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MtcpServer) Close() error {
|
||||
return s.lis.Close()
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
@@ -11,7 +12,6 @@ import (
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/conn"
|
||||
"github.com/Ehco1996/ehco/internal/constant"
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
@@ -24,27 +24,48 @@ var (
|
||||
)
|
||||
|
||||
type WsClient struct {
|
||||
dialer *ws.Dialer
|
||||
cfg *conf.Config
|
||||
l *zap.SugaredLogger
|
||||
dialer *ws.Dialer
|
||||
cfg *conf.Config
|
||||
netDialer *net.Dialer
|
||||
l *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func newWsClient(cfg *conf.Config) (*WsClient, error) {
|
||||
s := &WsClient{
|
||||
cfg: cfg,
|
||||
l: zap.S().Named(string(cfg.TransportType)),
|
||||
dialer: &ws.Dialer{Timeout: constant.DialTimeOut},
|
||||
cfg: cfg,
|
||||
netDialer: NewNetDialer(cfg),
|
||||
l: zap.S().Named(string(cfg.TransportType)),
|
||||
dialer: &ws.DefaultDialer, // todo config buffer size
|
||||
}
|
||||
s.dialer.NetDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.netDialer.Dial(network, addr)
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
func (s *WsClient) addUDPQueryParam(addr string) string {
|
||||
u, err := url.Parse(addr)
|
||||
if err != nil {
|
||||
s.l.Errorf("Failed to parse URL: %v", err)
|
||||
return addr
|
||||
}
|
||||
q := u.Query()
|
||||
q.Set("type", "udp")
|
||||
u.RawQuery = q.Encode()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
func (s *WsClient) HandShake(ctx context.Context, remote *lb.Node, isTCP bool) (net.Conn, error) {
|
||||
t1 := time.Now()
|
||||
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wsc, _, _, err := s.dialer.Dial(context.TODO(), addr)
|
||||
if !isTCP {
|
||||
addr = s.addUDPQueryParam(addr)
|
||||
}
|
||||
|
||||
wsc, _, _, err := s.dialer.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -55,61 +76,50 @@ func (s *WsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *WsClient) HealthCheck(ctx context.Context, remote *lb.Node) error {
|
||||
l := zap.S().Named("health-check")
|
||||
l.Infof("start send req to %s", remote.Address)
|
||||
c, err := s.TCPHandShake(remote)
|
||||
if err != nil {
|
||||
l.Errorf("send req to %s meet error:%s", remote.Address, err)
|
||||
return err
|
||||
}
|
||||
c.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
type WsServer struct {
|
||||
*baseTransporter
|
||||
|
||||
e *echo.Echo
|
||||
*BaseRelayServer
|
||||
httpServer *http.Server
|
||||
}
|
||||
|
||||
func newWsServer(base *baseTransporter) (*WsServer, error) {
|
||||
localTCPAddr, err := base.GetTCPListenAddr()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &WsServer{
|
||||
baseTransporter: base,
|
||||
httpServer: &http.Server{
|
||||
Addr: localTCPAddr.String(), ReadHeaderTimeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
func newWsServer(bs *BaseRelayServer) (*WsServer, error) {
|
||||
s := &WsServer{BaseRelayServer: bs}
|
||||
e := web.NewEchoServer()
|
||||
e.Use(web.NginxLogMiddleware(zap.S().Named("ws-server")))
|
||||
|
||||
e.GET("/", echo.WrapHandler(web.MakeIndexF()))
|
||||
e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
|
||||
|
||||
s.e = e
|
||||
e.GET(bs.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.handleRequest)))
|
||||
s.httpServer = &http.Server{Handler: e}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *WsServer) ListenAndServe() error {
|
||||
return s.e.StartServer(s.httpServer)
|
||||
}
|
||||
|
||||
func (s *WsServer) Close() error {
|
||||
return s.e.Close()
|
||||
}
|
||||
|
||||
func (s *WsServer) HandleRequest(w http.ResponseWriter, req *http.Request) {
|
||||
func (s *WsServer) handleRequest(w http.ResponseWriter, req *http.Request) {
|
||||
// todo use bufio.ReadWriter
|
||||
wsc, _, _, err := ws.UpgradeHTTP(req, w)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.RelayTCPConn(conn.NewWSConn(wsc, true), s.relayer.TCPHandShake); err != nil {
|
||||
s.l.Errorf("RelayTCPConn error: %s", err.Error())
|
||||
if req.URL.Query().Get("type") == "udp" {
|
||||
if !s.cfg.Options.EnableUDP {
|
||||
s.l.Error("udp not support but request with udp type")
|
||||
wsc.Close()
|
||||
return
|
||||
}
|
||||
err = s.RelayUDPConn(req.Context(), conn.NewWSConn(wsc, true))
|
||||
} else {
|
||||
err = s.RelayTCPConn(req.Context(), conn.NewWSConn(wsc, true))
|
||||
}
|
||||
if err != nil {
|
||||
s.l.Errorf("handleRequest meet error:%s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *WsServer) ListenAndServe(ctx context.Context) error {
|
||||
listener, err := NewTCPListener(ctx, s.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.httpServer.Serve(listener)
|
||||
}
|
||||
|
||||
func (s *WsServer) Close() error {
|
||||
return s.httpServer.Close()
|
||||
}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
// NOTE CAN NOT use real ws frame to transport smux frame
|
||||
// err: accept stream err: buffer size:8 too small to transport ws payload size:45
|
||||
// so this transport just use ws protocol to handshake and then use smux protocol to transport
|
||||
package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/xtaci/smux"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
)
|
||||
|
||||
var (
|
||||
_ RelayClient = &MwsClient{}
|
||||
_ RelayServer = &MwsServer{}
|
||||
_ muxServer = &MwsServer{}
|
||||
)
|
||||
|
||||
type MwsClient struct {
|
||||
*WssClient
|
||||
|
||||
muxTP *smuxTransporter
|
||||
}
|
||||
|
||||
func newMwsClient(cfg *conf.Config) (*MwsClient, error) {
|
||||
wc, err := newWssClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &MwsClient{WssClient: wc}
|
||||
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *MwsClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
|
||||
rc, _, _, err := c.dialer.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// stream multiplex
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.KeepAliveDisabled = true
|
||||
session, err := smux.Client(rc, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.l.Infof("init new session to: %s", rc.RemoteAddr())
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *MwsClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
t1 := time.Now()
|
||||
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mwssc, err := s.muxTP.Dial(context.TODO(), addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latency := time.Since(t1)
|
||||
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
|
||||
remote.HandShakeDuration = latency
|
||||
return mwssc, nil
|
||||
}
|
||||
|
||||
type MwsServer struct {
|
||||
*WsServer
|
||||
*muxServerImpl
|
||||
}
|
||||
|
||||
func newMwsServer(base *baseTransporter) (*MwsServer, error) {
|
||||
wsServer, err := newWsServer(base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &MwsServer{
|
||||
WsServer: wsServer,
|
||||
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
|
||||
}
|
||||
s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MwsServer) ListenAndServe() error {
|
||||
go func() {
|
||||
s.errChan <- s.e.StartServer(s.httpServer)
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, e := s.Accept()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
|
||||
s.l.Errorf("RelayTCPConn error: %s", err.Error())
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MwsServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
c, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
s.l.Error(err)
|
||||
return
|
||||
}
|
||||
s.mux(c)
|
||||
}
|
||||
|
||||
func (s *MwsServer) Close() error {
|
||||
return s.e.Close()
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
mytls "github.com/Ehco1996/ehco/internal/tls"
|
||||
)
|
||||
@@ -28,12 +31,19 @@ type WssServer struct {
|
||||
*WsServer
|
||||
}
|
||||
|
||||
func newWssServer(base *baseTransporter) (*WssServer, error) {
|
||||
wsServer, err := newWsServer(base)
|
||||
func newWssServer(bs *BaseRelayServer) (*WssServer, error) {
|
||||
wsServer, err := newWsServer(bs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// insert tls config
|
||||
wsServer.httpServer.TLSConfig = mytls.DefaultTLSConfig
|
||||
return &WssServer{WsServer: wsServer}, nil
|
||||
}
|
||||
|
||||
func (s *WssServer) ListenAndServe(ctx context.Context) error {
|
||||
listener, err := NewTCPListener(ctx, s.cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tlsListener := tls.NewListener(listener, mytls.DefaultTLSConfig)
|
||||
return s.httpServer.Serve(tlsListener)
|
||||
}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
// NOTE CAN NOT use real ws frame to transport smux frame
|
||||
// err: accept stream err: buffer size:8 too small to transport ws payload size:45
|
||||
// so this transport just use ws protocol to handshake and then use smux protocol to transport
|
||||
package transporter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gobwas/ws"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/xtaci/smux"
|
||||
|
||||
"github.com/Ehco1996/ehco/internal/lb"
|
||||
"github.com/Ehco1996/ehco/internal/metrics"
|
||||
"github.com/Ehco1996/ehco/internal/relay/conf"
|
||||
)
|
||||
|
||||
var (
|
||||
_ RelayClient = &MwssClient{}
|
||||
_ RelayServer = &MwssServer{}
|
||||
_ muxServer = &MwssServer{}
|
||||
)
|
||||
|
||||
type MwssClient struct {
|
||||
*WssClient
|
||||
|
||||
muxTP *smuxTransporter
|
||||
}
|
||||
|
||||
func newMwssClient(cfg *conf.Config) (*MwssClient, error) {
|
||||
wc, err := newWssClient(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &MwssClient{WssClient: wc}
|
||||
c.muxTP = NewSmuxTransporter(c.l.Named("mwss"), c.initNewSession)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *MwssClient) initNewSession(ctx context.Context, addr string) (*smux.Session, error) {
|
||||
rc, _, _, err := c.dialer.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// stream multiplex
|
||||
cfg := smux.DefaultConfig()
|
||||
cfg.KeepAliveDisabled = true
|
||||
session, err := smux.Client(rc, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.l.Infof("init new session to: %s", rc.RemoteAddr())
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *MwssClient) TCPHandShake(remote *lb.Node) (net.Conn, error) {
|
||||
t1 := time.Now()
|
||||
addr, err := s.cfg.GetWSRemoteAddr(remote.Address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mwssc, err := s.muxTP.Dial(context.TODO(), addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
latency := time.Since(t1)
|
||||
metrics.HandShakeDuration.WithLabelValues(remote.Label).Observe(float64(latency.Milliseconds()))
|
||||
remote.HandShakeDuration = latency
|
||||
return mwssc, nil
|
||||
}
|
||||
|
||||
type MwssServer struct {
|
||||
*WssServer
|
||||
*muxServerImpl
|
||||
}
|
||||
|
||||
func newMwssServer(base *baseTransporter) (*MwssServer, error) {
|
||||
wssServer, err := newWssServer(base)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &MwssServer{
|
||||
WssServer: wssServer,
|
||||
muxServerImpl: newMuxServer(base.cfg.Listen, base.l.Named("mwss")),
|
||||
}
|
||||
s.e.GET(base.cfg.GetWSHandShakePath(), echo.WrapHandler(http.HandlerFunc(s.HandleRequest)))
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *MwssServer) ListenAndServe() error {
|
||||
go func() {
|
||||
s.errChan <- s.e.StartServer(s.httpServer)
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, e := s.Accept()
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
if err := s.RelayTCPConn(c, s.relayer.TCPHandShake); err != nil {
|
||||
s.l.Errorf("RelayTCPConn error: %s", err.Error())
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MwssServer) HandleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
c, _, _, err := ws.UpgradeHTTP(r, w)
|
||||
if err != nil {
|
||||
s.l.Error(err)
|
||||
return
|
||||
}
|
||||
s.mux(c)
|
||||
}
|
||||
|
||||
func (s *MwssServer) Close() error {
|
||||
return s.e.Close()
|
||||
}
|
||||
@@ -28,6 +28,7 @@
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Relay Label</th>
|
||||
<th>Type</th>
|
||||
<th>Flow</th>
|
||||
<th>Stats</th>
|
||||
<th>Time</th>
|
||||
@@ -37,7 +38,8 @@
|
||||
{{range .ConnectionList}}
|
||||
<tr>
|
||||
<td>{{.RelayLabel}}</td>
|
||||
<td>{{.Flow}}</td>
|
||||
<td>{{.ConnType}}</td>
|
||||
<td>{{.GetFlow}}</td>
|
||||
<td>{{.Stats}}</td>
|
||||
<td>{{.GetTime}}</td>
|
||||
</tr>
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
|
||||
// 全局pool
|
||||
var (
|
||||
BufferPool *BytePool
|
||||
BufferPool *BytePool
|
||||
UDPBufferPool *BytePool
|
||||
)
|
||||
|
||||
func init() {
|
||||
BufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.BUFFER_SIZE)
|
||||
UDPBufferPool = NewBytePool(constant.BUFFER_POOL_SIZE, constant.UDPBufSize)
|
||||
}
|
||||
|
||||
// BytePool implements a leaky pool of []byte in the form of a bounded channel
|
||||
|
||||
@@ -35,7 +35,7 @@ func (b *readerImpl) parsePingInfo(metricMap map[string]*dto.MetricFamily, nm *N
|
||||
metric, ok := metricMap["ehco_ping_response_duration_seconds"]
|
||||
if !ok {
|
||||
// this metric is optional when enable_ping = false
|
||||
zap.S().Warn("ping metric not found")
|
||||
zap.S().Debug("ping metric not found")
|
||||
return nil
|
||||
}
|
||||
for _, m := range metric.Metric {
|
||||
|
||||
@@ -179,7 +179,9 @@ func (c *ClashSub) ToRelayConfigs(listenHost string) ([]*relay_cfg.Config, error
|
||||
}
|
||||
rc.TCPRemotes = append(rc.TCPRemotes, remote)
|
||||
if proxy.UDP {
|
||||
rc.UDPRemotes = append(rc.UDPRemotes, remote)
|
||||
rc.Options = &relay_cfg.Options{
|
||||
EnableUDP: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
relayConfigs = append(relayConfigs, rc)
|
||||
|
||||
@@ -139,7 +139,9 @@ func (p *Proxies) ToRelayConfig(listenHost string, listenPort string, newName st
|
||||
TCPRemotes: []string{remoteAddr},
|
||||
}
|
||||
if p.UDP {
|
||||
r.UDPRemotes = []string{remoteAddr}
|
||||
r.Options = &relay_cfg.Options{
|
||||
EnableUDP: true,
|
||||
}
|
||||
}
|
||||
if err := r.Validate(); err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -28,8 +27,7 @@ const (
|
||||
ECHO_PORT = 9002
|
||||
ECHO_SERVER = "0.0.0.0:9002"
|
||||
|
||||
RAW_LISTEN = "0.0.0.0:1234"
|
||||
RAW_LISTEN_WITH_MAX_CONNECTION = "0.0.0.0:2234"
|
||||
RAW_LISTEN = "0.0.0.0:1234"
|
||||
|
||||
WS_LISTEN = "0.0.0.0:1235"
|
||||
WS_REMOTE = "ws://0.0.0.0:2000"
|
||||
@@ -38,22 +36,15 @@ const (
|
||||
WSS_LISTEN = "0.0.0.0:1236"
|
||||
WSS_REMOTE = "wss://0.0.0.0:2001"
|
||||
WSS_SERVER = "0.0.0.0:2001"
|
||||
|
||||
MWSS_LISTEN = "0.0.0.0:1237"
|
||||
MWSS_REMOTE = "wss://0.0.0.0:2002"
|
||||
MWSS_SERVER = "0.0.0.0:2002"
|
||||
|
||||
MTCP_LISTEN = "0.0.0.0:1238"
|
||||
MTCP_REMOTE = "0.0.0.0:2003"
|
||||
MTCP_SERVER = "0.0.0.0:2003"
|
||||
|
||||
MWS_LISTEN = "0.0.0.0:1239"
|
||||
MWS_REMOTE = "ws://0.0.0.0:2004"
|
||||
MSS_SERVER = "0.0.0.0:2004"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Setup
|
||||
|
||||
// change the idle timeout to 1 second to make connection close faster in test
|
||||
constant.IdleTimeOut = time.Second
|
||||
constant.ReadTimeOut = time.Second
|
||||
|
||||
_ = log.InitGlobalLogger("debug")
|
||||
_ = tls.InitTlsCfg()
|
||||
|
||||
@@ -78,38 +69,35 @@ func TestMain(m *testing.M) {
|
||||
|
||||
func startRelayServers() []*relay.Relay {
|
||||
cfg := config.Config{
|
||||
PATH: "",
|
||||
RelayConfigs: []*conf.Config{
|
||||
// raw cfg
|
||||
// raw
|
||||
{
|
||||
Listen: RAW_LISTEN,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
UDPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
Options: &conf.Options{
|
||||
EnableUDP: true,
|
||||
},
|
||||
},
|
||||
// raw cfg with max connection
|
||||
{
|
||||
Listen: RAW_LISTEN_WITH_MAX_CONNECTION,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
UDPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
MaxConnection: 1,
|
||||
},
|
||||
|
||||
// ws
|
||||
{
|
||||
Listen: WS_LISTEN,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{WS_REMOTE},
|
||||
TransportType: constant.RelayTypeWS,
|
||||
Options: &conf.Options{
|
||||
EnableUDP: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Listen: WS_SERVER,
|
||||
ListenType: constant.RelayTypeWS,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
Options: &conf.Options{
|
||||
EnableUDP: true,
|
||||
},
|
||||
},
|
||||
|
||||
// wss
|
||||
@@ -118,65 +106,30 @@ func startRelayServers() []*relay.Relay {
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{WSS_REMOTE},
|
||||
TransportType: constant.RelayTypeWSS,
|
||||
Options: &conf.Options{
|
||||
EnableUDP: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
Listen: WSS_SERVER,
|
||||
ListenType: constant.RelayTypeWSS,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
},
|
||||
|
||||
// mwss
|
||||
{
|
||||
Listen: MWSS_LISTEN,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{MWSS_REMOTE},
|
||||
TransportType: constant.RelayTypeMWSS,
|
||||
},
|
||||
{
|
||||
Listen: MWSS_SERVER,
|
||||
ListenType: constant.RelayTypeMWSS,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
},
|
||||
|
||||
// mtcp
|
||||
{
|
||||
Listen: MTCP_LISTEN,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{MTCP_REMOTE},
|
||||
TransportType: constant.RelayTypeMTCP,
|
||||
},
|
||||
{
|
||||
Listen: MTCP_SERVER,
|
||||
ListenType: constant.RelayTypeMTCP,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
},
|
||||
|
||||
// mws
|
||||
{
|
||||
Listen: MWS_LISTEN,
|
||||
ListenType: constant.RelayTypeRaw,
|
||||
TCPRemotes: []string{MWS_REMOTE},
|
||||
TransportType: constant.RelayTypeMWS,
|
||||
},
|
||||
{
|
||||
Listen: MSS_SERVER,
|
||||
ListenType: constant.RelayTypeMWS,
|
||||
TCPRemotes: []string{ECHO_SERVER},
|
||||
TransportType: constant.RelayTypeRaw,
|
||||
Options: &conf.Options{
|
||||
EnableUDP: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var servers []*relay.Relay
|
||||
for _, c := range cfg.RelayConfigs {
|
||||
c.Adjust()
|
||||
r, err := relay.NewRelay(c, cmgr.NewCmgr(cmgr.DummyConfig))
|
||||
if err != nil {
|
||||
zap.S().Fatal(err)
|
||||
}
|
||||
go r.ListenAndServe()
|
||||
go r.ListenAndServe(context.TODO())
|
||||
servers = append(servers, r)
|
||||
}
|
||||
|
||||
@@ -194,16 +147,14 @@ func TestRelay(t *testing.T) {
|
||||
{"Raw", RAW_LISTEN, "raw"},
|
||||
{"WS", WS_LISTEN, "ws"},
|
||||
{"WSS", WSS_LISTEN, "wss"},
|
||||
{"MWSS", MWSS_LISTEN, "mwss"},
|
||||
{"MTCP", MTCP_LISTEN, "mtcp"},
|
||||
{"MWS", MWS_LISTEN, "mws"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testRelayCommon(t, tc.address, tc.protocol, false)
|
||||
testTCPRelay(t, tc.address, tc.protocol, false)
|
||||
testUDPRelay(t, tc.address, false)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -214,21 +165,22 @@ func TestRelayConcurrent(t *testing.T) {
|
||||
address string
|
||||
concurrency int
|
||||
}{
|
||||
{"MWSS", MWSS_LISTEN, 10},
|
||||
{"MTCP", MTCP_LISTEN, 10},
|
||||
{"MWS", MWS_LISTEN, 10},
|
||||
{"Raw", RAW_LISTEN, 10},
|
||||
{"WS", WS_LISTEN, 10},
|
||||
{"WSS", WSS_LISTEN, 10},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc // capture range variable
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testRelayCommon(t, tc.address, tc.name, true, tc.concurrency)
|
||||
testTCPRelay(t, tc.address, tc.name, true, tc.concurrency)
|
||||
testUDPRelay(t, tc.address, true, tc.concurrency)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) {
|
||||
func testTCPRelay(t *testing.T, address, protocol string, concurrent bool, concurrency ...int) {
|
||||
t.Helper()
|
||||
msg := []byte("hello")
|
||||
|
||||
@@ -264,40 +216,42 @@ func testRelayCommon(t *testing.T, address, protocol string, concurrent bool, co
|
||||
t.Logf("Test TCP over %s done!", protocol)
|
||||
}
|
||||
|
||||
func TestRelayWithMaxConnectionCount(t *testing.T) {
|
||||
msg := []byte("hello")
|
||||
func testUDPRelay(t *testing.T, address string, concurrent bool, concurrency ...int) {
|
||||
t.Helper()
|
||||
msg := []byte("hello udp")
|
||||
|
||||
// First connection will be accepted
|
||||
go func() {
|
||||
err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION)
|
||||
require.NoError(t, err, "First connection should be accepted")
|
||||
}()
|
||||
runTest := func() error {
|
||||
res := echo.SendUdpMsg(msg, address)
|
||||
if !bytes.Equal(msg, res) {
|
||||
return fmt.Errorf("response mismatch: got %s, want %s", res, msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait for first connection
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Second connection should be rejected
|
||||
err := echo.EchoTcpMsgLong(msg, time.Second, RAW_LISTEN_WITH_MAX_CONNECTION)
|
||||
require.Error(t, err, "Second connection should be rejected")
|
||||
if concurrent {
|
||||
n := 10
|
||||
if len(concurrency) > 0 {
|
||||
n = concurrency[0]
|
||||
}
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
for i := 0; i < n; i++ {
|
||||
g.Go(func() error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
return runTest()
|
||||
}
|
||||
})
|
||||
}
|
||||
require.NoError(t, g.Wait(), "Concurrent test failed")
|
||||
} else {
|
||||
require.NoError(t, runTest(), "Single test failed")
|
||||
}
|
||||
t.Logf("Test UDP over %s done!", address)
|
||||
}
|
||||
|
||||
func TestRelayWithDeadline(t *testing.T) {
|
||||
logger, _ := zap.NewDevelopment()
|
||||
msg := []byte("hello")
|
||||
conn, err := net.Dial("tcp", RAW_LISTEN)
|
||||
if err != nil {
|
||||
logger.Sugar().Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, err := conn.Write(msg); err != nil {
|
||||
logger.Sugar().Fatal(err)
|
||||
}
|
||||
|
||||
buf := make([]byte, len(msg))
|
||||
constant.IdleTimeOut = time.Second // change for test
|
||||
time.Sleep(constant.IdleTimeOut)
|
||||
_, err = conn.Read(buf)
|
||||
if err != nil {
|
||||
logger.Sugar().Fatal("need error here")
|
||||
}
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
err := echo.EchoTcpMsgLong([]byte("hello"), time.Second, RAW_LISTEN)
|
||||
require.Error(t, err, "Connection should be rejected")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user