From a4d0eacbbf783e27ef168ff918da46bde91b602b Mon Sep 17 00:00:00 2001 From: charlesbao Date: Sat, 25 Jan 2020 20:21:18 +0800 Subject: [PATCH] init --- .gitignore | 23 + README.md | 6 + client/control.go | 342 +++++++ client/event/event.go | 28 + client/health/health.go | 171 ++++ client/proxy/proxy.go | 595 +++++++++++ client/proxy/proxy_manager.go | 146 +++ client/proxy/proxy_wrapper.go | 250 +++++ client/service.go | 252 +++++ client/visitor.go | 330 +++++++ client/visitor_manager.go | 129 +++ conf/frpc.ini | 11 + conf/frpc_full.ini | 266 +++++ conf/frps.ini | 13 + conf/frps_full.ini | 83 ++ conf/systemd/frpc.service | 14 + conf/systemd/frpc@.service | 14 + conf/systemd/frps.service | 13 + conf/systemd/frps@.service | 13 + example/config.ini | 0 example/main.go | 8 + frpc.go | 136 +++ frpc_test.go | 9 + go.mod | 32 + go.sum | 57 ++ models/config/client_common.go | 319 ++++++ models/config/proxy.go | 1042 ++++++++++++++++++++ models/config/server_common.go | 403 ++++++++ models/config/types.go | 112 +++ models/config/types_test.go | 40 + models/config/value.go | 64 ++ models/config/visitor.go | 213 ++++ models/consts/consts.go | 32 + models/errors/errors.go | 24 + models/msg/ctl.go | 46 + models/msg/msg.go | 185 ++++ models/nathole/nathole.go | 212 ++++ models/plugin/client/http2https.go | 111 +++ models/plugin/client/http_proxy.go | 243 +++++ models/plugin/client/https2http.go | 133 +++ models/plugin/client/plugin.go | 92 ++ models/plugin/client/socks5.go | 69 ++ models/plugin/client/static_file.go | 89 ++ models/plugin/client/unix_domain_socket.go | 72 ++ models/plugin/server/http.go | 104 ++ models/plugin/server/manager.go | 105 ++ models/plugin/server/plugin.go | 32 + models/plugin/server/tracer.go | 34 + models/plugin/server/types.go | 46 + models/proto/udp/udp.go | 137 +++ models/proto/udp/udp_test.go | 18 + utils/limit/reader.go | 51 + utils/limit/writer.go | 60 ++ utils/log/log.go | 93 ++ utils/metric/counter.go | 60 ++ utils/metric/counter_test.go | 23 + utils/metric/date_counter.go | 134 +++ utils/metric/date_counter_test.go | 27 + utils/net/conn.go | 242 +++++ utils/net/http.go | 115 +++ utils/net/kcp.go | 99 ++ utils/net/listener.go | 69 ++ utils/net/tls.go | 52 + utils/net/udp.go | 256 +++++ utils/net/websocket.go | 103 ++ utils/util/util.go | 103 ++ utils/util/util_test.go | 48 + utils/version/version.go | 82 ++ utils/version/version_test.go | 65 ++ utils/vhost/http.go | 216 ++++ utils/vhost/https.go | 194 ++++ utils/vhost/resource.go | 122 +++ utils/vhost/reverseproxy.go | 563 +++++++++++ utils/vhost/router.go | 119 +++ utils/vhost/vhost.go | 227 +++++ utils/xlog/ctx.go | 42 + utils/xlog/xlog.go | 73 ++ 77 files changed, 10156 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 client/control.go create mode 100644 client/event/event.go create mode 100644 client/health/health.go create mode 100644 client/proxy/proxy.go create mode 100644 client/proxy/proxy_manager.go create mode 100644 client/proxy/proxy_wrapper.go create mode 100644 client/service.go create mode 100644 client/visitor.go create mode 100644 client/visitor_manager.go create mode 100644 conf/frpc.ini create mode 100644 conf/frpc_full.ini create mode 100644 conf/frps.ini create mode 100644 conf/frps_full.ini create mode 100644 conf/systemd/frpc.service create mode 100644 conf/systemd/frpc@.service create mode 100644 conf/systemd/frps.service create mode 100644 conf/systemd/frps@.service create mode 100644 example/config.ini create mode 100644 example/main.go create mode 100644 frpc.go create mode 100644 frpc_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 models/config/client_common.go create mode 100644 models/config/proxy.go create mode 100644 models/config/server_common.go create mode 100644 models/config/types.go create mode 100644 models/config/types_test.go create mode 100644 models/config/value.go create mode 100644 models/config/visitor.go create mode 100644 models/consts/consts.go create mode 100644 models/errors/errors.go create mode 100644 models/msg/ctl.go create mode 100644 models/msg/msg.go create mode 100644 models/nathole/nathole.go create mode 100644 models/plugin/client/http2https.go create mode 100644 models/plugin/client/http_proxy.go create mode 100644 models/plugin/client/https2http.go create mode 100644 models/plugin/client/plugin.go create mode 100644 models/plugin/client/socks5.go create mode 100644 models/plugin/client/static_file.go create mode 100644 models/plugin/client/unix_domain_socket.go create mode 100644 models/plugin/server/http.go create mode 100644 models/plugin/server/manager.go create mode 100644 models/plugin/server/plugin.go create mode 100644 models/plugin/server/tracer.go create mode 100644 models/plugin/server/types.go create mode 100644 models/proto/udp/udp.go create mode 100644 models/proto/udp/udp_test.go create mode 100644 utils/limit/reader.go create mode 100644 utils/limit/writer.go create mode 100644 utils/log/log.go create mode 100644 utils/metric/counter.go create mode 100644 utils/metric/counter_test.go create mode 100644 utils/metric/date_counter.go create mode 100644 utils/metric/date_counter_test.go create mode 100644 utils/net/conn.go create mode 100644 utils/net/http.go create mode 100644 utils/net/kcp.go create mode 100644 utils/net/listener.go create mode 100644 utils/net/tls.go create mode 100644 utils/net/udp.go create mode 100644 utils/net/websocket.go create mode 100644 utils/util/util.go create mode 100644 utils/util/util_test.go create mode 100644 utils/version/version.go create mode 100644 utils/version/version_test.go create mode 100644 utils/vhost/http.go create mode 100644 utils/vhost/https.go create mode 100644 utils/vhost/resource.go create mode 100644 utils/vhost/reverseproxy.go create mode 100644 utils/vhost/router.go create mode 100644 utils/vhost/vhost.go create mode 100644 utils/xlog/ctx.go create mode 100644 utils/xlog/xlog.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..04a3bb1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +_testmain.go + +*.exe +*.test +*.prof + +# Self +bin/ +packages/ +test/bin/ + +# Cache +*.swp +*.zip diff --git a/README.md b/README.md new file mode 100644 index 0000000..75dbd35 --- /dev/null +++ b/README.md @@ -0,0 +1,6 @@ +# frpc + +## 调用方法 +``` +frpc.Run("config.ini") +``` \ No newline at end of file diff --git a/client/control.go b/client/control.go new file mode 100644 index 0000000..34566a8 --- /dev/null +++ b/client/control.go @@ -0,0 +1,342 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "crypto/tls" + "fmt" + "io" + "net" + "runtime/debug" + "sync" + "time" + + "github.com/charlesbao/frpc/client/proxy" + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + frpNet "github.com/charlesbao/frpc/utils/net" + "github.com/charlesbao/frpc/utils/xlog" + + "github.com/fatedier/golib/control/shutdown" + "github.com/fatedier/golib/crypto" + fmux "github.com/hashicorp/yamux" +) + +type Control struct { + // uniq id got from frps, attach it in loginMsg + runId string + + // manage all proxies + pxyCfgs map[string]config.ProxyConf + pm *proxy.ProxyManager + + // manage all visitors + vm *VisitorManager + + // control connection + conn net.Conn + + // tcp stream multiplexing, if enabled + session *fmux.Session + + // put a message in this channel to send it over control connection to server + sendCh chan (msg.Message) + + // read from this channel to get the next message sent by server + readCh chan (msg.Message) + + // goroutines can block by reading from this channel, it will be closed only in reader() when control connection is closed + closedCh chan struct{} + + closedDoneCh chan struct{} + + // last time got the Pong message + lastPong time.Time + + // The client configuration + clientCfg config.ClientCommonConf + + readerShutdown *shutdown.Shutdown + writerShutdown *shutdown.Shutdown + msgHandlerShutdown *shutdown.Shutdown + + // The UDP port that the server is listening on + serverUDPPort int + + mu sync.RWMutex + + xl *xlog.Logger + + // service context + ctx context.Context +} + +func NewControl(ctx context.Context, runId string, conn net.Conn, session *fmux.Session, + clientCfg config.ClientCommonConf, + pxyCfgs map[string]config.ProxyConf, + visitorCfgs map[string]config.VisitorConf, + serverUDPPort int) *Control { + + // new xlog instance + ctl := &Control{ + runId: runId, + conn: conn, + session: session, + pxyCfgs: pxyCfgs, + sendCh: make(chan msg.Message, 100), + readCh: make(chan msg.Message, 100), + closedCh: make(chan struct{}), + closedDoneCh: make(chan struct{}), + clientCfg: clientCfg, + readerShutdown: shutdown.New(), + writerShutdown: shutdown.New(), + msgHandlerShutdown: shutdown.New(), + serverUDPPort: serverUDPPort, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, + } + ctl.pm = proxy.NewProxyManager(ctl.ctx, ctl.sendCh, clientCfg, serverUDPPort) + + ctl.vm = NewVisitorManager(ctl.ctx, ctl) + ctl.vm.Reload(visitorCfgs) + return ctl +} + +func (ctl *Control) Run() { + go ctl.worker() + + // start all proxies + ctl.pm.Reload(ctl.pxyCfgs) + + // start all visitors + go ctl.vm.Run() + return +} + +func (ctl *Control) HandleReqWorkConn(inMsg *msg.ReqWorkConn) { + xl := ctl.xl + workConn, err := ctl.connectServer() + if err != nil { + return + } + + m := &msg.NewWorkConn{ + RunId: ctl.runId, + } + if err = msg.WriteMsg(workConn, m); err != nil { + xl.Warn("work connection write to server error: %v", err) + workConn.Close() + return + } + + var startMsg msg.StartWorkConn + if err = msg.ReadMsgInto(workConn, &startMsg); err != nil { + xl.Error("work connection closed before response StartWorkConn message: %v", err) + workConn.Close() + return + } + + // dispatch this work connection to related proxy + ctl.pm.HandleWorkConn(startMsg.ProxyName, workConn, &startMsg) +} + +func (ctl *Control) HandleNewProxyResp(inMsg *msg.NewProxyResp) { + xl := ctl.xl + // Server will return NewProxyResp message to each NewProxy message. + // Start a new proxy handler if no error got + err := ctl.pm.StartProxy(inMsg.ProxyName, inMsg.RemoteAddr, inMsg.Error) + if err != nil { + xl.Warn("[%s] start error: %v", inMsg.ProxyName, err) + } else { + xl.Info("[%s] start proxy success", inMsg.ProxyName) + } +} + +func (ctl *Control) Close() error { + ctl.pm.Close() + ctl.conn.Close() + if ctl.session != nil { + ctl.session.Close() + } + return nil +} + +// ClosedDoneCh returns a channel which will be closed after all resources are released +func (ctl *Control) ClosedDoneCh() <-chan struct{} { + return ctl.closedDoneCh +} + +// connectServer return a new connection to frps +func (ctl *Control) connectServer() (conn net.Conn, err error) { + xl := ctl.xl + if ctl.clientCfg.TcpMux { + stream, errRet := ctl.session.OpenStream() + if errRet != nil { + err = errRet + xl.Warn("start new connection to server error: %v", err) + return + } + conn = stream + } else { + var tlsConfig *tls.Config + if ctl.clientCfg.TLSEnable { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + conn, err = frpNet.ConnectServerByProxyWithTLS(ctl.clientCfg.HttpProxy, ctl.clientCfg.Protocol, + fmt.Sprintf("%s:%d", ctl.clientCfg.ServerAddr, ctl.clientCfg.ServerPort), tlsConfig) + if err != nil { + xl.Warn("start new connection to server error: %v", err) + return + } + } + return +} + +// reader read all messages from frps and send to readCh +func (ctl *Control) reader() { + xl := ctl.xl + defer func() { + if err := recover(); err != nil { + xl.Error("panic error: %v", err) + xl.Error(string(debug.Stack())) + } + }() + defer ctl.readerShutdown.Done() + defer close(ctl.closedCh) + + encReader := crypto.NewReader(ctl.conn, []byte(ctl.clientCfg.Token)) + for { + if m, err := msg.ReadMsg(encReader); err != nil { + if err == io.EOF { + xl.Debug("read from control connection EOF") + return + } else { + xl.Warn("read error: %v", err) + ctl.conn.Close() + return + } + } else { + ctl.readCh <- m + } + } +} + +// writer writes messages got from sendCh to frps +func (ctl *Control) writer() { + xl := ctl.xl + defer ctl.writerShutdown.Done() + encWriter, err := crypto.NewWriter(ctl.conn, []byte(ctl.clientCfg.Token)) + if err != nil { + xl.Error("crypto new writer error: %v", err) + ctl.conn.Close() + return + } + for { + if m, ok := <-ctl.sendCh; !ok { + xl.Info("control writer is closing") + return + } else { + if err := msg.WriteMsg(encWriter, m); err != nil { + xl.Warn("write message to control connection error: %v", err) + return + } + } + } +} + +// msgHandler handles all channel events and do corresponding operations. +func (ctl *Control) msgHandler() { + xl := ctl.xl + defer func() { + if err := recover(); err != nil { + xl.Error("panic error: %v", err) + xl.Error(string(debug.Stack())) + } + }() + defer ctl.msgHandlerShutdown.Done() + + hbSend := time.NewTicker(time.Duration(ctl.clientCfg.HeartBeatInterval) * time.Second) + defer hbSend.Stop() + hbCheck := time.NewTicker(time.Second) + defer hbCheck.Stop() + + ctl.lastPong = time.Now() + + for { + select { + case <-hbSend.C: + // send heartbeat to server + xl.Debug("send heartbeat to server") + ctl.sendCh <- &msg.Ping{} + case <-hbCheck.C: + if time.Since(ctl.lastPong) > time.Duration(ctl.clientCfg.HeartBeatTimeout)*time.Second { + xl.Warn("heartbeat timeout") + // let reader() stop + ctl.conn.Close() + return + } + case rawMsg, ok := <-ctl.readCh: + if !ok { + return + } + + switch m := rawMsg.(type) { + case *msg.ReqWorkConn: + go ctl.HandleReqWorkConn(m) + case *msg.NewProxyResp: + ctl.HandleNewProxyResp(m) + case *msg.Pong: + ctl.lastPong = time.Now() + xl.Debug("receive heartbeat from server") + } + } + } +} + +// If controler is notified by closedCh, reader and writer and handler will exit +func (ctl *Control) worker() { + go ctl.msgHandler() + go ctl.reader() + go ctl.writer() + + select { + case <-ctl.closedCh: + // close related channels and wait until other goroutines done + close(ctl.readCh) + ctl.readerShutdown.WaitDone() + ctl.msgHandlerShutdown.WaitDone() + + close(ctl.sendCh) + ctl.writerShutdown.WaitDone() + + ctl.pm.Close() + ctl.vm.Close() + + close(ctl.closedDoneCh) + if ctl.session != nil { + ctl.session.Close() + } + return + } +} + +func (ctl *Control) ReloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf) error { + ctl.vm.Reload(visitorCfgs) + ctl.pm.Reload(pxyCfgs) + return nil +} diff --git a/client/event/event.go b/client/event/event.go new file mode 100644 index 0000000..6e22672 --- /dev/null +++ b/client/event/event.go @@ -0,0 +1,28 @@ +package event + +import ( + "errors" + + "github.com/charlesbao/frpc/models/msg" +) + +type EventType int + +const ( + EvStartProxy EventType = iota + EvCloseProxy +) + +var ( + ErrPayloadType = errors.New("error payload type") +) + +type EventHandler func(evType EventType, payload interface{}) error + +type StartProxyPayload struct { + NewProxyMsg *msg.NewProxy +} + +type CloseProxyPayload struct { + CloseProxyMsg *msg.CloseProxy +} diff --git a/client/health/health.go b/client/health/health.go new file mode 100644 index 0000000..59a001c --- /dev/null +++ b/client/health/health.go @@ -0,0 +1,171 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package health + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "time" + + "github.com/charlesbao/frpc/utils/xlog" +) + +var ( + ErrHealthCheckType = errors.New("error health check type") +) + +type HealthCheckMonitor struct { + checkType string + interval time.Duration + timeout time.Duration + maxFailedTimes int + + // For tcp + addr string + + // For http + url string + + failedTimes uint64 + statusOK bool + statusNormalFn func() + statusFailedFn func() + + ctx context.Context + cancel context.CancelFunc +} + +func NewHealthCheckMonitor(ctx context.Context, checkType string, + intervalS int, timeoutS int, maxFailedTimes int, + addr string, url string, + statusNormalFn func(), statusFailedFn func()) *HealthCheckMonitor { + + if intervalS <= 0 { + intervalS = 10 + } + if timeoutS <= 0 { + timeoutS = 3 + } + if maxFailedTimes <= 0 { + maxFailedTimes = 1 + } + newctx, cancel := context.WithCancel(ctx) + return &HealthCheckMonitor{ + checkType: checkType, + interval: time.Duration(intervalS) * time.Second, + timeout: time.Duration(timeoutS) * time.Second, + maxFailedTimes: maxFailedTimes, + addr: addr, + url: url, + statusOK: false, + statusNormalFn: statusNormalFn, + statusFailedFn: statusFailedFn, + ctx: newctx, + cancel: cancel, + } +} + +func (monitor *HealthCheckMonitor) Start() { + go monitor.checkWorker() +} + +func (monitor *HealthCheckMonitor) Stop() { + monitor.cancel() +} + +func (monitor *HealthCheckMonitor) checkWorker() { + xl := xlog.FromContextSafe(monitor.ctx) + for { + doCtx, cancel := context.WithDeadline(monitor.ctx, time.Now().Add(monitor.timeout)) + err := monitor.doCheck(doCtx) + + // check if this monitor has been closed + select { + case <-monitor.ctx.Done(): + cancel() + return + default: + cancel() + } + + if err == nil { + xl.Trace("do one health check success") + if !monitor.statusOK && monitor.statusNormalFn != nil { + xl.Info("health check status change to success") + monitor.statusOK = true + monitor.statusNormalFn() + } + } else { + xl.Warn("do one health check failed: %v", err) + monitor.failedTimes++ + if monitor.statusOK && int(monitor.failedTimes) >= monitor.maxFailedTimes && monitor.statusFailedFn != nil { + xl.Warn("health check status change to failed") + monitor.statusOK = false + monitor.statusFailedFn() + } + } + + time.Sleep(monitor.interval) + } +} + +func (monitor *HealthCheckMonitor) doCheck(ctx context.Context) error { + switch monitor.checkType { + case "tcp": + return monitor.doTcpCheck(ctx) + case "http": + return monitor.doHttpCheck(ctx) + default: + return ErrHealthCheckType + } +} + +func (monitor *HealthCheckMonitor) doTcpCheck(ctx context.Context) error { + // if tcp address is not specified, always return nil + if monitor.addr == "" { + return nil + } + + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", monitor.addr) + if err != nil { + return err + } + conn.Close() + return nil +} + +func (monitor *HealthCheckMonitor) doHttpCheck(ctx context.Context) error { + req, err := http.NewRequest("GET", monitor.url, nil) + if err != nil { + return err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + io.Copy(ioutil.Discard, resp.Body) + + if resp.StatusCode/100 != 2 { + return fmt.Errorf("do http health check, StatusCode is [%d] not 2xx", resp.StatusCode) + } + return nil +} diff --git a/client/proxy/proxy.go b/client/proxy/proxy.go new file mode 100644 index 0000000..9534ca5 --- /dev/null +++ b/client/proxy/proxy.go @@ -0,0 +1,595 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package proxy + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net" + "strconv" + "strings" + "sync" + "time" + + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + plugin "github.com/charlesbao/frpc/models/plugin/client" + "github.com/charlesbao/frpc/models/proto/udp" + "github.com/charlesbao/frpc/utils/limit" + frpNet "github.com/charlesbao/frpc/utils/net" + "github.com/charlesbao/frpc/utils/xlog" + + "github.com/fatedier/golib/errors" + frpIo "github.com/fatedier/golib/io" + "github.com/fatedier/golib/pool" + fmux "github.com/hashicorp/yamux" + pp "github.com/pires/go-proxyproto" + "golang.org/x/time/rate" +) + +// Proxy defines how to handle work connections for different proxy type. +type Proxy interface { + Run() error + + // InWorkConn accept work connections registered to server. + InWorkConn(net.Conn, *msg.StartWorkConn) + + Close() +} + +func NewProxy(ctx context.Context, pxyConf config.ProxyConf, clientCfg config.ClientCommonConf, serverUDPPort int) (pxy Proxy) { + var limiter *rate.Limiter + limitBytes := pxyConf.GetBaseInfo().BandwidthLimit.Bytes() + if limitBytes > 0 { + limiter = rate.NewLimiter(rate.Limit(float64(limitBytes)), int(limitBytes)) + } + + baseProxy := BaseProxy{ + clientCfg: clientCfg, + serverUDPPort: serverUDPPort, + limiter: limiter, + xl: xlog.FromContextSafe(ctx), + ctx: ctx, + } + switch cfg := pxyConf.(type) { + case *config.TcpProxyConf: + pxy = &TcpProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + case *config.UdpProxyConf: + pxy = &UdpProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + case *config.HttpProxyConf: + pxy = &HttpProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + case *config.HttpsProxyConf: + pxy = &HttpsProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + case *config.StcpProxyConf: + pxy = &StcpProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + case *config.XtcpProxyConf: + pxy = &XtcpProxy{ + BaseProxy: &baseProxy, + cfg: cfg, + } + } + return +} + +type BaseProxy struct { + closed bool + clientCfg config.ClientCommonConf + serverUDPPort int + limiter *rate.Limiter + + mu sync.RWMutex + xl *xlog.Logger + ctx context.Context +} + +// TCP +type TcpProxy struct { + *BaseProxy + + cfg *config.TcpProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *TcpProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *TcpProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *TcpProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + HandleTcpWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, pxy.limiter, + conn, []byte(pxy.clientCfg.Token), m) +} + +// HTTP +type HttpProxy struct { + *BaseProxy + + cfg *config.HttpProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *HttpProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *HttpProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *HttpProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + HandleTcpWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, pxy.limiter, + conn, []byte(pxy.clientCfg.Token), m) +} + +// HTTPS +type HttpsProxy struct { + *BaseProxy + + cfg *config.HttpsProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *HttpsProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *HttpsProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *HttpsProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + HandleTcpWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, pxy.limiter, + conn, []byte(pxy.clientCfg.Token), m) +} + +// STCP +type StcpProxy struct { + *BaseProxy + + cfg *config.StcpProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *StcpProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *StcpProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *StcpProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + HandleTcpWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, pxy.limiter, + conn, []byte(pxy.clientCfg.Token), m) +} + +// XTCP +type XtcpProxy struct { + *BaseProxy + + cfg *config.XtcpProxyConf + proxyPlugin plugin.Plugin +} + +func (pxy *XtcpProxy) Run() (err error) { + if pxy.cfg.Plugin != "" { + pxy.proxyPlugin, err = plugin.Create(pxy.cfg.Plugin, pxy.cfg.PluginParams) + if err != nil { + return + } + } + return +} + +func (pxy *XtcpProxy) Close() { + if pxy.proxyPlugin != nil { + pxy.proxyPlugin.Close() + } +} + +func (pxy *XtcpProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + xl := pxy.xl + defer conn.Close() + var natHoleSidMsg msg.NatHoleSid + err := msg.ReadMsgInto(conn, &natHoleSidMsg) + if err != nil { + xl.Error("xtcp read from workConn error: %v", err) + return + } + + natHoleClientMsg := &msg.NatHoleClient{ + ProxyName: pxy.cfg.ProxyName, + Sid: natHoleSidMsg.Sid, + } + raddr, _ := net.ResolveUDPAddr("udp", + fmt.Sprintf("%s:%d", pxy.clientCfg.ServerAddr, pxy.serverUDPPort)) + clientConn, err := net.DialUDP("udp", nil, raddr) + defer clientConn.Close() + + err = msg.WriteMsg(clientConn, natHoleClientMsg) + if err != nil { + xl.Error("send natHoleClientMsg to server error: %v", err) + return + } + + // Wait for client address at most 5 seconds. + var natHoleRespMsg msg.NatHoleResp + clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + + buf := pool.GetBuf(1024) + n, err := clientConn.Read(buf) + if err != nil { + xl.Error("get natHoleRespMsg error: %v", err) + return + } + err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) + if err != nil { + xl.Error("get natHoleRespMsg error: %v", err) + return + } + clientConn.SetReadDeadline(time.Time{}) + clientConn.Close() + + if natHoleRespMsg.Error != "" { + xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) + return + } + + xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) + + // Send detect message + array := strings.Split(natHoleRespMsg.VisitorAddr, ":") + if len(array) <= 1 { + xl.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) + } + laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String()) + /* + for i := 1000; i < 65000; i++ { + pxy.sendDetectMsg(array[0], int64(i), laddr, "a") + } + */ + port, err := strconv.ParseInt(array[1], 10, 64) + if err != nil { + xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr) + return + } + pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid)) + xl.Trace("send all detect msg done") + + msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{}) + + // Listen for clientConn's address and wait for visitor connection + lConn, err := net.ListenUDP("udp", laddr) + if err != nil { + xl.Error("listen on visitorConn's local adress error: %v", err) + return + } + defer lConn.Close() + + lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) + sidBuf := pool.GetBuf(1024) + var uAddr *net.UDPAddr + n, uAddr, err = lConn.ReadFromUDP(sidBuf) + if err != nil { + xl.Warn("get sid from visitor error: %v", err) + return + } + lConn.SetReadDeadline(time.Time{}) + if string(sidBuf[:n]) != natHoleRespMsg.Sid { + xl.Warn("incorrect sid from visitor") + return + } + pool.PutBuf(sidBuf) + xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + + lConn.WriteToUDP(sidBuf[:n], uAddr) + + kcpConn, err := frpNet.NewKcpConnFromUdp(lConn, false, uAddr.String()) + if err != nil { + xl.Error("create kcp connection from udp connection error: %v", err) + return + } + + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = 5 * time.Second + fmuxCfg.LogOutput = ioutil.Discard + sess, err := fmux.Server(kcpConn, fmuxCfg) + if err != nil { + xl.Error("create yamux server from kcp connection error: %v", err) + return + } + defer sess.Close() + muxConn, err := sess.Accept() + if err != nil { + xl.Error("accept for yamux connection error: %v", err) + return + } + + HandleTcpWorkConnection(pxy.ctx, &pxy.cfg.LocalSvrConf, pxy.proxyPlugin, &pxy.cfg.BaseProxyConf, pxy.limiter, + muxConn, []byte(pxy.cfg.Sk), m) +} + +func (pxy *XtcpProxy) sendDetectMsg(addr string, port int, laddr *net.UDPAddr, content []byte) (err error) { + daddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", addr, port)) + if err != nil { + return err + } + + tConn, err := net.DialUDP("udp", laddr, daddr) + if err != nil { + return err + } + + //uConn := ipv4.NewConn(tConn) + //uConn.SetTTL(3) + + tConn.Write(content) + tConn.Close() + return nil +} + +// UDP +type UdpProxy struct { + *BaseProxy + + cfg *config.UdpProxyConf + + localAddr *net.UDPAddr + readCh chan *msg.UdpPacket + + // include msg.UdpPacket and msg.Ping + sendCh chan msg.Message + workConn net.Conn +} + +func (pxy *UdpProxy) Run() (err error) { + pxy.localAddr, err = net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pxy.cfg.LocalIp, pxy.cfg.LocalPort)) + if err != nil { + return + } + return +} + +func (pxy *UdpProxy) Close() { + pxy.mu.Lock() + defer pxy.mu.Unlock() + + if !pxy.closed { + pxy.closed = true + if pxy.workConn != nil { + pxy.workConn.Close() + } + if pxy.readCh != nil { + close(pxy.readCh) + } + if pxy.sendCh != nil { + close(pxy.sendCh) + } + } +} + +func (pxy *UdpProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) { + xl := pxy.xl + xl.Info("incoming a new work connection for udp proxy, %s", conn.RemoteAddr().String()) + // close resources releated with old workConn + pxy.Close() + + if pxy.limiter != nil { + rwc := frpIo.WrapReadWriteCloser(limit.NewReader(conn, pxy.limiter), limit.NewWriter(conn, pxy.limiter), func() error { + return conn.Close() + }) + conn = frpNet.WrapReadWriteCloserToConn(rwc, conn) + } + + pxy.mu.Lock() + pxy.workConn = conn + pxy.readCh = make(chan *msg.UdpPacket, 1024) + pxy.sendCh = make(chan msg.Message, 1024) + pxy.closed = false + pxy.mu.Unlock() + + workConnReaderFn := func(conn net.Conn, readCh chan *msg.UdpPacket) { + for { + var udpMsg msg.UdpPacket + if errRet := msg.ReadMsgInto(conn, &udpMsg); errRet != nil { + xl.Warn("read from workConn for udp error: %v", errRet) + return + } + if errRet := errors.PanicToError(func() { + xl.Trace("get udp package from workConn: %s", udpMsg.Content) + readCh <- &udpMsg + }); errRet != nil { + xl.Info("reader goroutine for udp work connection closed: %v", errRet) + return + } + } + } + workConnSenderFn := func(conn net.Conn, sendCh chan msg.Message) { + defer func() { + xl.Info("writer goroutine for udp work connection closed") + }() + var errRet error + for rawMsg := range sendCh { + switch m := rawMsg.(type) { + case *msg.UdpPacket: + xl.Trace("send udp package to workConn: %s", m.Content) + case *msg.Ping: + xl.Trace("send ping message to udp workConn") + } + if errRet = msg.WriteMsg(conn, rawMsg); errRet != nil { + xl.Error("udp work write error: %v", errRet) + return + } + } + } + heartbeatFn := func(conn net.Conn, sendCh chan msg.Message) { + var errRet error + for { + time.Sleep(time.Duration(30) * time.Second) + if errRet = errors.PanicToError(func() { + sendCh <- &msg.Ping{} + }); errRet != nil { + xl.Trace("heartbeat goroutine for udp work connection closed") + break + } + } + } + + go workConnSenderFn(pxy.workConn, pxy.sendCh) + go workConnReaderFn(pxy.workConn, pxy.readCh) + go heartbeatFn(pxy.workConn, pxy.sendCh) + udp.Forwarder(pxy.localAddr, pxy.readCh, pxy.sendCh) +} + +// Common handler for tcp work connections. +func HandleTcpWorkConnection(ctx context.Context, localInfo *config.LocalSvrConf, proxyPlugin plugin.Plugin, + baseInfo *config.BaseProxyConf, limiter *rate.Limiter, workConn net.Conn, encKey []byte, m *msg.StartWorkConn) { + xl := xlog.FromContextSafe(ctx) + var ( + remote io.ReadWriteCloser + err error + ) + remote = workConn + if limiter != nil { + remote = frpIo.WrapReadWriteCloser(limit.NewReader(workConn, limiter), limit.NewWriter(workConn, limiter), func() error { + return workConn.Close() + }) + } + + xl.Trace("handle tcp work connection, use_encryption: %t, use_compression: %t", + baseInfo.UseEncryption, baseInfo.UseCompression) + if baseInfo.UseEncryption { + remote, err = frpIo.WithEncryption(remote, encKey) + if err != nil { + workConn.Close() + xl.Error("create encryption stream error: %v", err) + return + } + } + if baseInfo.UseCompression { + remote = frpIo.WithCompression(remote) + } + + // check if we need to send proxy protocol info + var extraInfo []byte + if baseInfo.ProxyProtocolVersion != "" { + if m.SrcAddr != "" && m.SrcPort != 0 { + if m.DstAddr == "" { + m.DstAddr = "127.0.0.1" + } + h := &pp.Header{ + Command: pp.PROXY, + SourceAddress: net.ParseIP(m.SrcAddr), + SourcePort: m.SrcPort, + DestinationAddress: net.ParseIP(m.DstAddr), + DestinationPort: m.DstPort, + } + + if strings.Contains(m.SrcAddr, ".") { + h.TransportProtocol = pp.TCPv4 + } else { + h.TransportProtocol = pp.TCPv6 + } + + if baseInfo.ProxyProtocolVersion == "v1" { + h.Version = 1 + } else if baseInfo.ProxyProtocolVersion == "v2" { + h.Version = 2 + } + + buf := bytes.NewBuffer(nil) + h.WriteTo(buf) + extraInfo = buf.Bytes() + } + } + + if proxyPlugin != nil { + // if plugin is set, let plugin handle connections first + xl.Debug("handle by plugin: %s", proxyPlugin.Name()) + proxyPlugin.Handle(remote, workConn, extraInfo) + xl.Debug("handle by plugin finished") + return + } else { + localConn, err := frpNet.ConnectServer("tcp", fmt.Sprintf("%s:%d", localInfo.LocalIp, localInfo.LocalPort)) + if err != nil { + workConn.Close() + xl.Error("connect to local service [%s:%d] error: %v", localInfo.LocalIp, localInfo.LocalPort, err) + return + } + + xl.Debug("join connections, localConn(l[%s] r[%s]) workConn(l[%s] r[%s])", localConn.LocalAddr().String(), + localConn.RemoteAddr().String(), workConn.LocalAddr().String(), workConn.RemoteAddr().String()) + + if len(extraInfo) > 0 { + localConn.Write(extraInfo) + } + + frpIo.Join(localConn, remote) + xl.Debug("join connections closed") + } +} diff --git a/client/proxy/proxy_manager.go b/client/proxy/proxy_manager.go new file mode 100644 index 0000000..f8eda8a --- /dev/null +++ b/client/proxy/proxy_manager.go @@ -0,0 +1,146 @@ +package proxy + +import ( + "context" + "fmt" + "net" + "sync" + + "github.com/charlesbao/frpc/client/event" + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + "github.com/charlesbao/frpc/utils/xlog" + + "github.com/fatedier/golib/errors" +) + +type ProxyManager struct { + sendCh chan (msg.Message) + proxies map[string]*ProxyWrapper + + closed bool + mu sync.RWMutex + + clientCfg config.ClientCommonConf + + // The UDP port that the server is listening on + serverUDPPort int + + ctx context.Context +} + +func NewProxyManager(ctx context.Context, msgSendCh chan (msg.Message), clientCfg config.ClientCommonConf, serverUDPPort int) *ProxyManager { + return &ProxyManager{ + sendCh: msgSendCh, + proxies: make(map[string]*ProxyWrapper), + closed: false, + clientCfg: clientCfg, + serverUDPPort: serverUDPPort, + ctx: ctx, + } +} + +func (pm *ProxyManager) StartProxy(name string, remoteAddr string, serverRespErr string) error { + pm.mu.RLock() + pxy, ok := pm.proxies[name] + pm.mu.RUnlock() + if !ok { + return fmt.Errorf("proxy [%s] not found", name) + } + + err := pxy.SetRunningStatus(remoteAddr, serverRespErr) + if err != nil { + return err + } + return nil +} + +func (pm *ProxyManager) Close() { + pm.mu.Lock() + defer pm.mu.Unlock() + for _, pxy := range pm.proxies { + pxy.Stop() + } + pm.proxies = make(map[string]*ProxyWrapper) +} + +func (pm *ProxyManager) HandleWorkConn(name string, workConn net.Conn, m *msg.StartWorkConn) { + pm.mu.RLock() + pw, ok := pm.proxies[name] + pm.mu.RUnlock() + if ok { + pw.InWorkConn(workConn, m) + } else { + workConn.Close() + } +} + +func (pm *ProxyManager) HandleEvent(evType event.EventType, payload interface{}) error { + var m msg.Message + switch e := payload.(type) { + case *event.StartProxyPayload: + m = e.NewProxyMsg + case *event.CloseProxyPayload: + m = e.CloseProxyMsg + default: + return event.ErrPayloadType + } + + err := errors.PanicToError(func() { + pm.sendCh <- m + }) + return err +} + +func (pm *ProxyManager) GetAllProxyStatus() []*ProxyStatus { + ps := make([]*ProxyStatus, 0) + pm.mu.RLock() + defer pm.mu.RUnlock() + for _, pxy := range pm.proxies { + ps = append(ps, pxy.GetStatus()) + } + return ps +} + +func (pm *ProxyManager) Reload(pxyCfgs map[string]config.ProxyConf) { + xl := xlog.FromContextSafe(pm.ctx) + pm.mu.Lock() + defer pm.mu.Unlock() + + delPxyNames := make([]string, 0) + for name, pxy := range pm.proxies { + del := false + cfg, ok := pxyCfgs[name] + if !ok { + del = true + } else { + if !pxy.Cfg.Compare(cfg) { + del = true + } + } + + if del { + delPxyNames = append(delPxyNames, name) + delete(pm.proxies, name) + + pxy.Stop() + } + } + if len(delPxyNames) > 0 { + xl.Info("proxy removed: %v", delPxyNames) + } + + addPxyNames := make([]string, 0) + for name, cfg := range pxyCfgs { + if _, ok := pm.proxies[name]; !ok { + pxy := NewProxyWrapper(pm.ctx, cfg, pm.clientCfg, pm.HandleEvent, pm.serverUDPPort) + pm.proxies[name] = pxy + addPxyNames = append(addPxyNames, name) + + pxy.Start() + } + } + if len(addPxyNames) > 0 { + xl.Info("proxy added: %v", addPxyNames) + } +} diff --git a/client/proxy/proxy_wrapper.go b/client/proxy/proxy_wrapper.go new file mode 100644 index 0000000..6ac181b --- /dev/null +++ b/client/proxy/proxy_wrapper.go @@ -0,0 +1,250 @@ +package proxy + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/charlesbao/frpc/client/event" + "github.com/charlesbao/frpc/client/health" + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + "github.com/charlesbao/frpc/utils/xlog" + + "github.com/fatedier/golib/errors" +) + +const ( + ProxyStatusNew = "new" + ProxyStatusWaitStart = "wait start" + ProxyStatusStartErr = "start error" + ProxyStatusRunning = "running" + ProxyStatusCheckFailed = "check failed" + ProxyStatusClosed = "closed" +) + +var ( + statusCheckInterval time.Duration = 3 * time.Second + waitResponseTimeout = 20 * time.Second + startErrTimeout = 30 * time.Second +) + +type ProxyStatus struct { + Name string `json:"name"` + Type string `json:"type"` + Status string `json:"status"` + Err string `json:"err"` + Cfg config.ProxyConf `json:"cfg"` + + // Got from server. + RemoteAddr string `json:"remote_addr"` +} + +type ProxyWrapper struct { + ProxyStatus + + // underlying proxy + pxy Proxy + + // if ProxyConf has healcheck config + // monitor will watch if it is alive + monitor *health.HealthCheckMonitor + + // event handler + handler event.EventHandler + + health uint32 + lastSendStartMsg time.Time + lastStartErr time.Time + closeCh chan struct{} + healthNotifyCh chan struct{} + mu sync.RWMutex + + xl *xlog.Logger + ctx context.Context +} + +func NewProxyWrapper(ctx context.Context, cfg config.ProxyConf, clientCfg config.ClientCommonConf, eventHandler event.EventHandler, serverUDPPort int) *ProxyWrapper { + baseInfo := cfg.GetBaseInfo() + xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(baseInfo.ProxyName) + pw := &ProxyWrapper{ + ProxyStatus: ProxyStatus{ + Name: baseInfo.ProxyName, + Type: baseInfo.ProxyType, + Status: ProxyStatusNew, + Cfg: cfg, + }, + closeCh: make(chan struct{}), + healthNotifyCh: make(chan struct{}), + handler: eventHandler, + xl: xl, + ctx: xlog.NewContext(ctx, xl), + } + + if baseInfo.HealthCheckType != "" { + pw.health = 1 // means failed + pw.monitor = health.NewHealthCheckMonitor(pw.ctx, baseInfo.HealthCheckType, baseInfo.HealthCheckIntervalS, + baseInfo.HealthCheckTimeoutS, baseInfo.HealthCheckMaxFailed, baseInfo.HealthCheckAddr, + baseInfo.HealthCheckUrl, pw.statusNormalCallback, pw.statusFailedCallback) + xl.Trace("enable health check monitor") + } + + pw.pxy = NewProxy(pw.ctx, pw.Cfg, clientCfg, serverUDPPort) + return pw +} + +func (pw *ProxyWrapper) SetRunningStatus(remoteAddr string, respErr string) error { + pw.mu.Lock() + defer pw.mu.Unlock() + if pw.Status != ProxyStatusWaitStart { + return fmt.Errorf("status not wait start, ignore start message") + } + + pw.RemoteAddr = remoteAddr + if respErr != "" { + pw.Status = ProxyStatusStartErr + pw.Err = respErr + pw.lastStartErr = time.Now() + return fmt.Errorf(pw.Err) + } + + if err := pw.pxy.Run(); err != nil { + pw.Status = ProxyStatusStartErr + pw.Err = err.Error() + pw.lastStartErr = time.Now() + return err + } + + pw.Status = ProxyStatusRunning + pw.Err = "" + return nil +} + +func (pw *ProxyWrapper) Start() { + go pw.checkWorker() + if pw.monitor != nil { + go pw.monitor.Start() + } +} + +func (pw *ProxyWrapper) Stop() { + pw.mu.Lock() + defer pw.mu.Unlock() + close(pw.closeCh) + close(pw.healthNotifyCh) + pw.pxy.Close() + if pw.monitor != nil { + pw.monitor.Stop() + } + pw.Status = ProxyStatusClosed + + pw.handler(event.EvCloseProxy, &event.CloseProxyPayload{ + CloseProxyMsg: &msg.CloseProxy{ + ProxyName: pw.Name, + }, + }) +} + +func (pw *ProxyWrapper) checkWorker() { + xl := pw.xl + if pw.monitor != nil { + // let monitor do check request first + time.Sleep(500 * time.Millisecond) + } + for { + // check proxy status + now := time.Now() + if atomic.LoadUint32(&pw.health) == 0 { + pw.mu.Lock() + if pw.Status == ProxyStatusNew || + pw.Status == ProxyStatusCheckFailed || + (pw.Status == ProxyStatusWaitStart && now.After(pw.lastSendStartMsg.Add(waitResponseTimeout))) || + (pw.Status == ProxyStatusStartErr && now.After(pw.lastStartErr.Add(startErrTimeout))) { + + xl.Trace("change status from [%s] to [%s]", pw.Status, ProxyStatusWaitStart) + pw.Status = ProxyStatusWaitStart + + var newProxyMsg msg.NewProxy + pw.Cfg.MarshalToMsg(&newProxyMsg) + pw.lastSendStartMsg = now + pw.handler(event.EvStartProxy, &event.StartProxyPayload{ + NewProxyMsg: &newProxyMsg, + }) + } + pw.mu.Unlock() + } else { + pw.mu.Lock() + if pw.Status == ProxyStatusRunning || pw.Status == ProxyStatusWaitStart { + pw.handler(event.EvCloseProxy, &event.CloseProxyPayload{ + CloseProxyMsg: &msg.CloseProxy{ + ProxyName: pw.Name, + }, + }) + xl.Trace("change status from [%s] to [%s]", pw.Status, ProxyStatusCheckFailed) + pw.Status = ProxyStatusCheckFailed + } + pw.mu.Unlock() + } + + select { + case <-pw.closeCh: + return + case <-time.After(statusCheckInterval): + case <-pw.healthNotifyCh: + } + } +} + +func (pw *ProxyWrapper) statusNormalCallback() { + xl := pw.xl + atomic.StoreUint32(&pw.health, 0) + errors.PanicToError(func() { + select { + case pw.healthNotifyCh <- struct{}{}: + default: + } + }) + xl.Info("health check success") +} + +func (pw *ProxyWrapper) statusFailedCallback() { + xl := pw.xl + atomic.StoreUint32(&pw.health, 1) + errors.PanicToError(func() { + select { + case pw.healthNotifyCh <- struct{}{}: + default: + } + }) + xl.Info("health check failed") +} + +func (pw *ProxyWrapper) InWorkConn(workConn net.Conn, m *msg.StartWorkConn) { + xl := pw.xl + pw.mu.RLock() + pxy := pw.pxy + pw.mu.RUnlock() + if pxy != nil { + xl.Debug("start a new work connection, localAddr: %s remoteAddr: %s", workConn.LocalAddr().String(), workConn.RemoteAddr().String()) + go pxy.InWorkConn(workConn, m) + } else { + workConn.Close() + } +} + +func (pw *ProxyWrapper) GetStatus() *ProxyStatus { + pw.mu.RLock() + defer pw.mu.RUnlock() + ps := &ProxyStatus{ + Name: pw.Name, + Type: pw.Type, + Status: pw.Status, + Err: pw.Err, + Cfg: pw.Cfg, + RemoteAddr: pw.RemoteAddr, + } + return ps +} diff --git a/client/service.go b/client/service.go new file mode 100644 index 0000000..be4c4e5 --- /dev/null +++ b/client/service.go @@ -0,0 +1,252 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "crypto/tls" + "fmt" + "io/ioutil" + "net" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + frpNet "github.com/charlesbao/frpc/utils/net" + "github.com/charlesbao/frpc/utils/util" + "github.com/charlesbao/frpc/utils/version" + "github.com/charlesbao/frpc/utils/xlog" + + fmux "github.com/hashicorp/yamux" +) + +// Service is a client service. +type Service struct { + // uniq id got from frps, attach it in loginMsg + runId string + + // manager control connection with server + ctl *Control + ctlMu sync.RWMutex + + cfg config.ClientCommonConf + pxyCfgs map[string]config.ProxyConf + visitorCfgs map[string]config.VisitorConf + cfgMu sync.RWMutex + + // The configuration file used to initialize this client, or an empty + // string if no configuration file was used. + cfgFile string + + // This is configured by the login response from frps + serverUDPPort int + + exit uint32 // 0 means not exit + + // service context + ctx context.Context + // call cancel to stop service + cancel context.CancelFunc +} + +func NewService(cfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf, cfgFile string) (svr *Service, err error) { + + ctx, cancel := context.WithCancel(context.Background()) + svr = &Service{ + cfg: cfg, + cfgFile: cfgFile, + pxyCfgs: pxyCfgs, + visitorCfgs: visitorCfgs, + exit: 0, + ctx: xlog.NewContext(ctx, xlog.New()), + cancel: cancel, + } + return +} + +func (svr *Service) GetController() *Control { + svr.ctlMu.RLock() + defer svr.ctlMu.RUnlock() + return svr.ctl +} + +func (svr *Service) Run() error { + xl := xlog.FromContextSafe(svr.ctx) + + // login to frps + for { + conn, session, err := svr.login() + if err != nil { + xl.Warn("login to server failed: %v", err) + + // if login_fail_exit is true, just exit this program + // otherwise sleep a while and try again to connect to server + if svr.cfg.LoginFailExit { + return err + } else { + time.Sleep(10 * time.Second) + } + } else { + // login success + ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort) + ctl.Run() + svr.ctlMu.Lock() + svr.ctl = ctl + svr.ctlMu.Unlock() + break + } + } + + go svr.keepControllerWorking() + + <-svr.ctx.Done() + return nil +} + +func (svr *Service) keepControllerWorking() { + xl := xlog.FromContextSafe(svr.ctx) + maxDelayTime := 20 * time.Second + delayTime := time.Second + + for { + <-svr.ctl.ClosedDoneCh() + if atomic.LoadUint32(&svr.exit) != 0 { + return + } + + for { + xl.Info("try to reconnect to server...") + conn, session, err := svr.login() + if err != nil { + xl.Warn("reconnect to server error: %v", err) + time.Sleep(delayTime) + delayTime = delayTime * 2 + if delayTime > maxDelayTime { + delayTime = maxDelayTime + } + continue + } + // reconnect success, init delayTime + delayTime = time.Second + + ctl := NewControl(svr.ctx, svr.runId, conn, session, svr.cfg, svr.pxyCfgs, svr.visitorCfgs, svr.serverUDPPort) + ctl.Run() + svr.ctlMu.Lock() + svr.ctl = ctl + svr.ctlMu.Unlock() + break + } + } +} + +// login creates a connection to frps and registers it self as a client +// conn: control connection +// session: if it's not nil, using tcp mux +func (svr *Service) login() (conn net.Conn, session *fmux.Session, err error) { + xl := xlog.FromContextSafe(svr.ctx) + var tlsConfig *tls.Config + if svr.cfg.TLSEnable { + tlsConfig = &tls.Config{ + InsecureSkipVerify: true, + } + } + conn, err = frpNet.ConnectServerByProxyWithTLS(svr.cfg.HttpProxy, svr.cfg.Protocol, + fmt.Sprintf("%s:%d", svr.cfg.ServerAddr, svr.cfg.ServerPort), tlsConfig) + if err != nil { + return + } + + defer func() { + if err != nil { + conn.Close() + if session != nil { + session.Close() + } + } + }() + + if svr.cfg.TcpMux { + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = 20 * time.Second + fmuxCfg.LogOutput = ioutil.Discard + session, err = fmux.Client(conn, fmuxCfg) + if err != nil { + return + } + stream, errRet := session.OpenStream() + if errRet != nil { + session.Close() + err = errRet + return + } + conn = stream + } + + now := time.Now().Unix() + loginMsg := &msg.Login{ + Arch: runtime.GOARCH, + Os: runtime.GOOS, + PoolCount: svr.cfg.PoolCount, + User: svr.cfg.User, + Version: version.Full(), + PrivilegeKey: util.GetAuthKey(svr.cfg.Token, now), + Timestamp: now, + RunId: svr.runId, + Metas: svr.cfg.Metas, + } + + if err = msg.WriteMsg(conn, loginMsg); err != nil { + return + } + + var loginRespMsg msg.LoginResp + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + if err = msg.ReadMsgInto(conn, &loginRespMsg); err != nil { + return + } + conn.SetReadDeadline(time.Time{}) + + if loginRespMsg.Error != "" { + err = fmt.Errorf("%s", loginRespMsg.Error) + xl.Error("%s", loginRespMsg.Error) + return + } + + svr.runId = loginRespMsg.RunId + xl.ResetPrefixes() + xl.AppendPrefix(svr.runId) + + svr.serverUDPPort = loginRespMsg.ServerUdpPort + xl.Info("login to server success, get run id [%s], server udp port [%d]", loginRespMsg.RunId, loginRespMsg.ServerUdpPort) + return +} + +func (svr *Service) ReloadConf(pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf) error { + svr.cfgMu.Lock() + svr.pxyCfgs = pxyCfgs + svr.visitorCfgs = visitorCfgs + svr.cfgMu.Unlock() + + return svr.ctl.ReloadConf(pxyCfgs, visitorCfgs) +} + +func (svr *Service) Close() { + atomic.StoreUint32(&svr.exit, 1) + svr.ctl.Close() + svr.cancel() +} diff --git a/client/visitor.go b/client/visitor.go new file mode 100644 index 0000000..dc73068 --- /dev/null +++ b/client/visitor.go @@ -0,0 +1,330 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "bytes" + "context" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "time" + + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/models/msg" + frpNet "github.com/charlesbao/frpc/utils/net" + "github.com/charlesbao/frpc/utils/util" + "github.com/charlesbao/frpc/utils/xlog" + + frpIo "github.com/fatedier/golib/io" + "github.com/fatedier/golib/pool" + fmux "github.com/hashicorp/yamux" +) + +// Visitor is used for forward traffics from local port tot remote service. +type Visitor interface { + Run() error + Close() +} + +func NewVisitor(ctx context.Context, ctl *Control, cfg config.VisitorConf) (visitor Visitor) { + xl := xlog.FromContextSafe(ctx).Spawn().AppendPrefix(cfg.GetBaseInfo().ProxyName) + baseVisitor := BaseVisitor{ + ctl: ctl, + ctx: xlog.NewContext(ctx, xl), + } + switch cfg := cfg.(type) { + case *config.StcpVisitorConf: + visitor = &StcpVisitor{ + BaseVisitor: &baseVisitor, + cfg: cfg, + } + case *config.XtcpVisitorConf: + visitor = &XtcpVisitor{ + BaseVisitor: &baseVisitor, + cfg: cfg, + } + } + return +} + +type BaseVisitor struct { + ctl *Control + l net.Listener + closed bool + + mu sync.RWMutex + ctx context.Context +} + +type StcpVisitor struct { + *BaseVisitor + + cfg *config.StcpVisitorConf +} + +func (sv *StcpVisitor) Run() (err error) { + sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort)) + if err != nil { + return + } + + go sv.worker() + return +} + +func (sv *StcpVisitor) Close() { + sv.l.Close() +} + +func (sv *StcpVisitor) worker() { + xl := xlog.FromContextSafe(sv.ctx) + for { + conn, err := sv.l.Accept() + if err != nil { + xl.Warn("stcp local listener closed") + return + } + + go sv.handleConn(conn) + } +} + +func (sv *StcpVisitor) handleConn(userConn net.Conn) { + xl := xlog.FromContextSafe(sv.ctx) + defer userConn.Close() + + xl.Debug("get a new stcp user connection") + visitorConn, err := sv.ctl.connectServer() + if err != nil { + return + } + defer visitorConn.Close() + + now := time.Now().Unix() + newVisitorConnMsg := &msg.NewVisitorConn{ + ProxyName: sv.cfg.ServerName, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + UseEncryption: sv.cfg.UseEncryption, + UseCompression: sv.cfg.UseCompression, + } + err = msg.WriteMsg(visitorConn, newVisitorConnMsg) + if err != nil { + xl.Warn("send newVisitorConnMsg to server error: %v", err) + return + } + + var newVisitorConnRespMsg msg.NewVisitorConnResp + visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + err = msg.ReadMsgInto(visitorConn, &newVisitorConnRespMsg) + if err != nil { + xl.Warn("get newVisitorConnRespMsg error: %v", err) + return + } + visitorConn.SetReadDeadline(time.Time{}) + + if newVisitorConnRespMsg.Error != "" { + xl.Warn("start new visitor connection error: %s", newVisitorConnRespMsg.Error) + return + } + + var remote io.ReadWriteCloser + remote = visitorConn + if sv.cfg.UseEncryption { + remote, err = frpIo.WithEncryption(remote, []byte(sv.cfg.Sk)) + if err != nil { + xl.Error("create encryption stream error: %v", err) + return + } + } + + if sv.cfg.UseCompression { + remote = frpIo.WithCompression(remote) + } + + frpIo.Join(userConn, remote) +} + +type XtcpVisitor struct { + *BaseVisitor + + cfg *config.XtcpVisitorConf +} + +func (sv *XtcpVisitor) Run() (err error) { + sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort)) + if err != nil { + return + } + + go sv.worker() + return +} + +func (sv *XtcpVisitor) Close() { + sv.l.Close() +} + +func (sv *XtcpVisitor) worker() { + xl := xlog.FromContextSafe(sv.ctx) + for { + conn, err := sv.l.Accept() + if err != nil { + xl.Warn("xtcp local listener closed") + return + } + + go sv.handleConn(conn) + } +} + +func (sv *XtcpVisitor) handleConn(userConn net.Conn) { + xl := xlog.FromContextSafe(sv.ctx) + defer userConn.Close() + + xl.Debug("get a new xtcp user connection") + if sv.ctl.serverUDPPort == 0 { + xl.Error("xtcp is not supported by server") + return + } + + raddr, err := net.ResolveUDPAddr("udp", + fmt.Sprintf("%s:%d", sv.ctl.clientCfg.ServerAddr, sv.ctl.serverUDPPort)) + if err != nil { + xl.Error("resolve server UDP addr error") + return + } + + visitorConn, err := net.DialUDP("udp", nil, raddr) + if err != nil { + xl.Warn("dial server udp addr error: %v", err) + return + } + defer visitorConn.Close() + + now := time.Now().Unix() + natHoleVisitorMsg := &msg.NatHoleVisitor{ + ProxyName: sv.cfg.ServerName, + SignKey: util.GetAuthKey(sv.cfg.Sk, now), + Timestamp: now, + } + err = msg.WriteMsg(visitorConn, natHoleVisitorMsg) + if err != nil { + xl.Warn("send natHoleVisitorMsg to server error: %v", err) + return + } + + // Wait for client address at most 10 seconds. + var natHoleRespMsg msg.NatHoleResp + visitorConn.SetReadDeadline(time.Now().Add(10 * time.Second)) + buf := pool.GetBuf(1024) + n, err := visitorConn.Read(buf) + if err != nil { + xl.Warn("get natHoleRespMsg error: %v", err) + return + } + + err = msg.ReadMsgInto(bytes.NewReader(buf[:n]), &natHoleRespMsg) + if err != nil { + xl.Warn("get natHoleRespMsg error: %v", err) + return + } + visitorConn.SetReadDeadline(time.Time{}) + pool.PutBuf(buf) + + if natHoleRespMsg.Error != "" { + xl.Error("natHoleRespMsg get error info: %s", natHoleRespMsg.Error) + return + } + + xl.Trace("get natHoleRespMsg, sid [%s], client address [%s], visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr) + + // Close visitorConn, so we can use it's local address. + visitorConn.Close() + + // send sid message to client + laddr, _ := net.ResolveUDPAddr("udp", visitorConn.LocalAddr().String()) + daddr, err := net.ResolveUDPAddr("udp", natHoleRespMsg.ClientAddr) + if err != nil { + xl.Error("resolve client udp address error: %v", err) + return + } + lConn, err := net.DialUDP("udp", laddr, daddr) + if err != nil { + xl.Error("dial client udp address error: %v", err) + return + } + defer lConn.Close() + + lConn.Write([]byte(natHoleRespMsg.Sid)) + + // read ack sid from client + sidBuf := pool.GetBuf(1024) + lConn.SetReadDeadline(time.Now().Add(8 * time.Second)) + n, err = lConn.Read(sidBuf) + if err != nil { + xl.Warn("get sid from client error: %v", err) + return + } + lConn.SetReadDeadline(time.Time{}) + if string(sidBuf[:n]) != natHoleRespMsg.Sid { + xl.Warn("incorrect sid from client") + return + } + pool.PutBuf(sidBuf) + + xl.Info("nat hole connection make success, sid [%s]", natHoleRespMsg.Sid) + + // wrap kcp connection + var remote io.ReadWriteCloser + remote, err = frpNet.NewKcpConnFromUdp(lConn, true, natHoleRespMsg.ClientAddr) + if err != nil { + xl.Error("create kcp connection from udp connection error: %v", err) + return + } + + fmuxCfg := fmux.DefaultConfig() + fmuxCfg.KeepAliveInterval = 5 * time.Second + fmuxCfg.LogOutput = ioutil.Discard + sess, err := fmux.Client(remote, fmuxCfg) + if err != nil { + xl.Error("create yamux session error: %v", err) + return + } + defer sess.Close() + muxConn, err := sess.Open() + if err != nil { + xl.Error("open yamux stream error: %v", err) + return + } + + var muxConnRWCloser io.ReadWriteCloser = muxConn + if sv.cfg.UseEncryption { + muxConnRWCloser, err = frpIo.WithEncryption(muxConnRWCloser, []byte(sv.cfg.Sk)) + if err != nil { + xl.Error("create encryption stream error: %v", err) + return + } + } + if sv.cfg.UseCompression { + muxConnRWCloser = frpIo.WithCompression(muxConnRWCloser) + } + + frpIo.Join(userConn, muxConnRWCloser) + xl.Debug("join connections closed") +} diff --git a/client/visitor_manager.go b/client/visitor_manager.go new file mode 100644 index 0000000..10510ac --- /dev/null +++ b/client/visitor_manager.go @@ -0,0 +1,129 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "context" + "sync" + "time" + + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/utils/xlog" +) + +type VisitorManager struct { + ctl *Control + + cfgs map[string]config.VisitorConf + visitors map[string]Visitor + + checkInterval time.Duration + + mu sync.Mutex + ctx context.Context +} + +func NewVisitorManager(ctx context.Context, ctl *Control) *VisitorManager { + return &VisitorManager{ + ctl: ctl, + cfgs: make(map[string]config.VisitorConf), + visitors: make(map[string]Visitor), + checkInterval: 10 * time.Second, + ctx: ctx, + } +} + +func (vm *VisitorManager) Run() { + xl := xlog.FromContextSafe(vm.ctx) + for { + time.Sleep(vm.checkInterval) + vm.mu.Lock() + for _, cfg := range vm.cfgs { + name := cfg.GetBaseInfo().ProxyName + if _, exist := vm.visitors[name]; !exist { + xl.Info("try to start visitor [%s]", name) + vm.startVisitor(cfg) + } + } + vm.mu.Unlock() + } +} + +// Hold lock before calling this function. +func (vm *VisitorManager) startVisitor(cfg config.VisitorConf) (err error) { + xl := xlog.FromContextSafe(vm.ctx) + name := cfg.GetBaseInfo().ProxyName + visitor := NewVisitor(vm.ctx, vm.ctl, cfg) + err = visitor.Run() + if err != nil { + xl.Warn("start error: %v", err) + } else { + vm.visitors[name] = visitor + xl.Info("start visitor success") + } + return +} + +func (vm *VisitorManager) Reload(cfgs map[string]config.VisitorConf) { + xl := xlog.FromContextSafe(vm.ctx) + vm.mu.Lock() + defer vm.mu.Unlock() + + delNames := make([]string, 0) + for name, oldCfg := range vm.cfgs { + del := false + cfg, ok := cfgs[name] + if !ok { + del = true + } else { + if !oldCfg.Compare(cfg) { + del = true + } + } + + if del { + delNames = append(delNames, name) + delete(vm.cfgs, name) + if visitor, ok := vm.visitors[name]; ok { + visitor.Close() + } + delete(vm.visitors, name) + } + } + if len(delNames) > 0 { + xl.Info("visitor removed: %v", delNames) + } + + addNames := make([]string, 0) + for name, cfg := range cfgs { + if _, ok := vm.cfgs[name]; !ok { + vm.cfgs[name] = cfg + addNames = append(addNames, name) + vm.startVisitor(cfg) + } + } + if len(addNames) > 0 { + xl.Info("visitor added: %v", addNames) + } + return +} + +func (vm *VisitorManager) Close() { + vm.mu.Lock() + defer vm.mu.Unlock() + for _, v := range vm.visitors { + v.Close() + } +} diff --git a/conf/frpc.ini b/conf/frpc.ini new file mode 100644 index 0000000..38aabc5 --- /dev/null +++ b/conf/frpc.ini @@ -0,0 +1,11 @@ +[common] +server_addr = 167.179.84.169 +server_port = 2333 +log_level = info +token = japanese + +[web] +type = tcp +local_ip = 127.0.0.1 +local_port = 8080 +remote_port = 33000 diff --git a/conf/frpc_full.ini b/conf/frpc_full.ini new file mode 100644 index 0000000..8c86acb --- /dev/null +++ b/conf/frpc_full.ini @@ -0,0 +1,266 @@ +# [common] is integral section +[common] +# A literal address or host name for IPv6 must be enclosed +# in square brackets, as in "[::1]:80", "[ipv6-host]:http" or "[ipv6-host%zone]:80" +server_addr = 0.0.0.0 +server_port = 7000 + +# if you want to connect frps by http proxy or socks5 proxy, you can set http_proxy here or in global environment variables +# it only works when protocol is tcp +# http_proxy = http://user:passwd@192.168.1.128:8080 +# http_proxy = socks5://user:passwd@192.168.1.128:1080 + +# console or real logFile path like ./frpc.log +log_file = ./frpc.log + +# trace, debug, info, warn, error +log_level = info + +log_max_days = 3 + +# disable log colors when log_file is console, default is false +disable_log_color = false + +# for authentication +token = 12345678 + +# set admin address for control frpc's action by http api such as reload +admin_addr = 127.0.0.1 +admin_port = 7400 +admin_user = admin +admin_pwd = admin +# Admin assets directory. By default, these assets are bundled with frpc. +# assets_dir = ./static + +# connections will be established in advance, default value is zero +pool_count = 5 + +# if tcp stream multiplexing is used, default is true, it must be same with frps +tcp_mux = true + +# your proxy name will be changed to {user}.{proxy} +user = your_name + +# decide if exit program when first login failed, otherwise continuous relogin to frps +# default is true +login_fail_exit = true + +# communication protocol used to connect to server +# now it supports tcp and kcp and websocket, default is tcp +protocol = tcp + +# if tls_enable is true, frpc will connect frps by tls +tls_enable = true + +# specify a dns server, so frpc will use this instead of default one +# dns_server = 8.8.8.8 + +# proxy names you want to start seperated by ',' +# default is empty, means all proxies +# start = ssh,dns + +# heartbeat configure, it's not recommended to modify the default value +# the default value of heartbeat_interval is 10 and heartbeat_timeout is 90 +# heartbeat_interval = 30 +# heartbeat_timeout = 90 + +# additional meta info for client +meta_var1 = 123 +meta_var2 = 234 + +# 'ssh' is the unique proxy name +# if user in [common] section is not empty, it will be changed to {user}.{proxy} such as 'your_name.ssh' +[ssh] +# tcp | udp | http | https | stcp | xtcp, default is tcp +type = tcp +local_ip = 127.0.0.1 +local_port = 22 +# limit bandwith for this proxy, unit is KB and MB +bandwith_limit = 1MB +# true or false, if true, messages between frps and frpc will be encrypted, default is false +use_encryption = false +# if true, message will be compressed +use_compression = false +# remote port listen by frps +remote_port = 6001 +# frps will load balancing connections for proxies in same group +group = test_group +# group should have same group key +group_key = 123456 +# enable health check for the backend service, it support 'tcp' and 'http' now +# frpc will connect local service's port to detect it's healthy status +health_check_type = tcp +# health check connection timeout +health_check_timeout_s = 3 +# if continuous failed in 3 times, the proxy will be removed from frps +health_check_max_failed = 3 +# every 10 seconds will do a health check +health_check_interval_s = 10 +# additional meta info for each proxy +meta_var1 = 123 +meta_var2 = 234 + +[ssh_random] +type = tcp +local_ip = 127.0.0.1 +local_port = 22 +# if remote_port is 0, frps will assign a random port for you +remote_port = 0 + +# if you want to expose multiple ports, add 'range:' prefix to the section name +# frpc will generate multiple proxies such as 'tcp_port_6010', 'tcp_port_6011' and so on. +[range:tcp_port] +type = tcp +local_ip = 127.0.0.1 +local_port = 6010-6020,6022,6024-6028 +remote_port = 6010-6020,6022,6024-6028 +use_encryption = false +use_compression = false + +[dns] +type = udp +local_ip = 114.114.114.114 +local_port = 53 +remote_port = 6002 +use_encryption = false +use_compression = false + +[range:udp_port] +type = udp +local_ip = 127.0.0.1 +local_port = 6010-6020 +remote_port = 6010-6020 +use_encryption = false +use_compression = false + +# Resolve your domain names to [server_addr] so you can use http://web01.yourdomain.com to browse web01 and http://web02.yourdomain.com to browse web02 +[web01] +type = http +local_ip = 127.0.0.1 +local_port = 80 +use_encryption = false +use_compression = true +# http username and password are safety certification for http protocol +# if not set, you can access this custom_domains without certification +http_user = admin +http_pwd = admin +# if domain for frps is frps.com, then you can access [web01] proxy by URL http://test.frps.com +subdomain = web01 +custom_domains = web02.yourdomain.com +# locations is only available for http type +locations = /,/pic +host_header_rewrite = example.com +# params with prefix "header_" will be used to update http request headers +header_X-From-Where = frp +health_check_type = http +# frpc will send a GET http request '/status' to local http service +# http service is alive when it return 2xx http response code +health_check_url = /status +health_check_interval_s = 10 +health_check_max_failed = 3 +health_check_timeout_s = 3 + +[web02] +type = https +local_ip = 127.0.0.1 +local_port = 8000 +use_encryption = false +use_compression = false +subdomain = web01 +custom_domains = web02.yourdomain.com +# if not empty, frpc will use proxy protocol to transfer connection info to your local service +# v1 or v2 or empty +proxy_protocol_version = v2 + +[plugin_unix_domain_socket] +type = tcp +remote_port = 6003 +# if plugin is defined, local_ip and local_port is useless +# plugin will handle connections got from frps +plugin = unix_domain_socket +# params with prefix "plugin_" that plugin needed +plugin_unix_path = /var/run/docker.sock + +[plugin_http_proxy] +type = tcp +remote_port = 6004 +plugin = http_proxy +plugin_http_user = abc +plugin_http_passwd = abc + +[plugin_socks5] +type = tcp +remote_port = 6005 +plugin = socks5 +plugin_user = abc +plugin_passwd = abc + +[plugin_static_file] +type = tcp +remote_port = 6006 +plugin = static_file +plugin_local_path = /var/www/blog +plugin_strip_prefix = static +plugin_http_user = abc +plugin_http_passwd = abc + +[plugin_https2http] +type = https +custom_domains = test.yourdomain.com +plugin = https2http +plugin_local_addr = 127.0.0.1:80 +plugin_crt_path = ./server.crt +plugin_key_path = ./server.key +plugin_host_header_rewrite = 127.0.0.1 +plugin_header_X-From-Where = frp + +[plugin_http2https] +type = http +custom_domains = test.yourdomain.com +plugin = http2https +plugin_local_addr = 127.0.0.1:443 +plugin_host_header_rewrite = 127.0.0.1 +plugin_header_X-From-Where = frp + +[secret_tcp] +# If the type is secret tcp, remote_port is useless +# Who want to connect local port should deploy another frpc with stcp proxy and role is visitor +type = stcp +# sk used for authentication for visitors +sk = abcdefg +local_ip = 127.0.0.1 +local_port = 22 +use_encryption = false +use_compression = false + +# user of frpc should be same in both stcp server and stcp visitor +[secret_tcp_visitor] +# frpc role visitor -> frps -> frpc role server +role = visitor +type = stcp +# the server name you want to visitor +server_name = secret_tcp +sk = abcdefg +# connect this address to visitor stcp server +bind_addr = 127.0.0.1 +bind_port = 9000 +use_encryption = false +use_compression = false + +[p2p_tcp] +type = xtcp +sk = abcdefg +local_ip = 127.0.0.1 +local_port = 22 +use_encryption = false +use_compression = false + +[p2p_tcp_visitor] +role = visitor +type = xtcp +server_name = p2p_tcp +sk = abcdefg +bind_addr = 127.0.0.1 +bind_port = 9001 +use_encryption = false +use_compression = false diff --git a/conf/frps.ini b/conf/frps.ini new file mode 100644 index 0000000..8be24fc --- /dev/null +++ b/conf/frps.ini @@ -0,0 +1,13 @@ +[common] +bind_addr = 0.0.0.0 +bind_port = 2333 +token = japanese +log_level = info + + +dashboard_addr = 0.0.0.0 +dashboard_port = 7589 +dashboard_user = admin +dashboard_pwd = gameOver + +allow_ports = 2000-3000,3001,3003,4000-50000 \ No newline at end of file diff --git a/conf/frps_full.ini b/conf/frps_full.ini new file mode 100644 index 0000000..030a3b3 --- /dev/null +++ b/conf/frps_full.ini @@ -0,0 +1,83 @@ +# [common] is integral section +[common] +# A literal address or host name for IPv6 must be enclosed +# in square brackets, as in "[::1]:80", "[ipv6-host]:http" or "[ipv6-host%zone]:80" +bind_addr = 0.0.0.0 +bind_port = 7000 + +# udp port to help make udp hole to penetrate nat +bind_udp_port = 7001 + +# udp port used for kcp protocol, it can be same with 'bind_port' +# if not set, kcp is disabled in frps +kcp_bind_port = 7000 + +# specify which address proxy will listen for, default value is same with bind_addr +# proxy_bind_addr = 127.0.0.1 + +# if you want to support virtual host, you must set the http port for listening (optional) +# Note: http port and https port can be same with bind_port +vhost_http_port = 80 +vhost_https_port = 443 + +# response header timeout(seconds) for vhost http server, default is 60s +# vhost_http_timeout = 60 + +# set dashboard_addr and dashboard_port to view dashboard of frps +# dashboard_addr's default value is same with bind_addr +# dashboard is available only if dashboard_port is set +dashboard_addr = 0.0.0.0 +dashboard_port = 7500 + +# dashboard user and passwd for basic auth protect, if not set, both default value is admin +dashboard_user = admin +dashboard_pwd = admin + +# dashboard assets directory(only for debug mode) +# assets_dir = ./static +# console or real logFile path like ./frps.log +log_file = ./frps.log + +# trace, debug, info, warn, error +log_level = info + +log_max_days = 3 + +# disable log colors when log_file is console, default is false +disable_log_color = false + +# auth token +token = 12345678 + +# heartbeat configure, it's not recommended to modify the default value +# the default value of heartbeat_timeout is 90 +# heartbeat_timeout = 90 + +# only allow frpc to bind ports you list, if you set nothing, there won't be any limit +allow_ports = 2000-3000,3001,3003,4000-50000 + +# pool_count in each proxy will change to max_pool_count if they exceed the maximum value +max_pool_count = 5 + +# max ports can be used for each client, default value is 0 means no limit +max_ports_per_client = 0 + +# if subdomain_host is not empty, you can set subdomain when type is http or https in frpc's configure file +# when subdomain is test, the host used by routing is test.frps.com +subdomain_host = frps.com + +# if tcp stream multiplexing is used, default is true +tcp_mux = true + +# custom 404 page for HTTP requests +# custom_404_page = /path/to/404.html + +[plugin.user-manager] +addr = 127.0.0.1:9000 +path = /handler +ops = Login + +[plugin.port-manager] +addr = 127.0.0.1:9001 +path = /handler +ops = NewProxy diff --git a/conf/systemd/frpc.service b/conf/systemd/frpc.service new file mode 100644 index 0000000..dd88ce0 --- /dev/null +++ b/conf/systemd/frpc.service @@ -0,0 +1,14 @@ +[Unit] +Description=Frp Client Service +After=network.target + +[Service] +Type=simple +User=nobody +Restart=on-failure +RestartSec=5s +ExecStart=/usr/bin/frpc -c /etc/frp/frpc.ini +ExecReload=/usr/bin/frpc reload -c /etc/frp/frpc.ini + +[Install] +WantedBy=multi-user.target diff --git a/conf/systemd/frpc@.service b/conf/systemd/frpc@.service new file mode 100644 index 0000000..46251ed --- /dev/null +++ b/conf/systemd/frpc@.service @@ -0,0 +1,14 @@ +[Unit] +Description=Frp Client Service +After=network.target + +[Service] +Type=idle +User=nobody +Restart=on-failure +RestartSec=5s +ExecStart=/usr/bin/frpc -c /etc/frp/%i.ini +ExecReload=/usr/bin/frpc reload -c /etc/frp/%i.ini + +[Install] +WantedBy=multi-user.target diff --git a/conf/systemd/frps.service b/conf/systemd/frps.service new file mode 100644 index 0000000..1daa267 --- /dev/null +++ b/conf/systemd/frps.service @@ -0,0 +1,13 @@ +[Unit] +Description=Frp Server Service +After=network.target + +[Service] +Type=simple +User=nobody +Restart=on-failure +RestartSec=5s +ExecStart=/usr/bin/frps -c /etc/frp/frps.ini + +[Install] +WantedBy=multi-user.target diff --git a/conf/systemd/frps@.service b/conf/systemd/frps@.service new file mode 100644 index 0000000..8b625ca --- /dev/null +++ b/conf/systemd/frps@.service @@ -0,0 +1,13 @@ +[Unit] +Description=Frp Server Service +After=network.target + +[Service] +Type=simple +User=nobody +Restart=on-failure +RestartSec=5s +ExecStart=/usr/bin/frps -c /etc/frp/%i.ini + +[Install] +WantedBy=multi-user.target diff --git a/example/config.ini b/example/config.ini new file mode 100644 index 0000000..e69de29 diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..27a228e --- /dev/null +++ b/example/main.go @@ -0,0 +1,8 @@ +package main + +import "github.com/charlesbao/frpc" + +func main() { + err := frpc.RunClient("config.ini") + panic(err) +} diff --git a/frpc.go b/frpc.go new file mode 100644 index 0000000..a6f6084 --- /dev/null +++ b/frpc.go @@ -0,0 +1,136 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package frpc + +import ( + "context" + "net" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/charlesbao/frpc/client" + "github.com/charlesbao/frpc/models/config" + "github.com/charlesbao/frpc/utils/log" +) + +var ( + cfgFile string + showVersion bool + + serverAddr string + user string + protocol string + token string + logLevel string + logFile string + logMaxDays int + disableLogColor bool + + proxyName string + localIp string + localPort int + remotePort int + useEncryption bool + useCompression bool + customDomains string + subDomain string + httpUser string + httpPwd string + locations string + hostHeaderRewrite string + role string + sk string + serverName string + bindAddr string + bindPort int + + kcpDoneCh chan struct{} +) + +func handleSignal(svr *client.Service) { + ch := make(chan os.Signal) + signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM) + <-ch + svr.Close() + time.Sleep(250 * time.Millisecond) + close(kcpDoneCh) +} + +func parseClientCommonCfgFromIni(content string) (config.ClientCommonConf, error) { + cfg, err := config.UnmarshalClientConfFromIni(content) + if err != nil { + return config.ClientCommonConf{}, err + } + return cfg, err +} + +func Run(cfgFilePath string) (err error) { + var content string + content, err = config.GetRenderedConfFromFile(cfgFilePath) + if err != nil { + return + } + + cfg, err := parseClientCommonCfgFromIni(content) + if err != nil { + return + } + + pxyCfgs, visitorCfgs, err := config.LoadAllConfFromIni(cfg.User, content, cfg.Start) + if err != nil { + return err + } + + err = startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) + return +} + +func startService(cfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf, visitorCfgs map[string]config.VisitorConf, cfgFile string) (err error) { + log.InitLog(cfg.LogWay, cfg.LogFile, cfg.LogLevel, + cfg.LogMaxDays, cfg.DisableLogColor) + + if cfg.DnsServer != "" { + s := cfg.DnsServer + if !strings.Contains(s, ":") { + s += ":53" + } + // Change default dns server for frpc + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + return net.Dial("udp", s) + }, + } + } + svr, errRet := client.NewService(cfg, pxyCfgs, visitorCfgs, cfgFile) + if errRet != nil { + err = errRet + return + } + + // Capture the exit signal if we use kcp. + if cfg.Protocol == "kcp" { + go handleSignal(svr) + } + + err = svr.Run() + if cfg.Protocol == "kcp" { + <-kcpDoneCh + } + return +} diff --git a/frpc_test.go b/frpc_test.go new file mode 100644 index 0000000..9ea0e05 --- /dev/null +++ b/frpc_test.go @@ -0,0 +1,9 @@ +package frpc + +import ( + "testing" +) + +func TestFrpc(t *testing.T) { + Run("example/config.ini") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..17245a5 --- /dev/null +++ b/go.mod @@ -0,0 +1,32 @@ +module github.com/charlesbao/frpc + +go 1.12 + +require ( + github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 + github.com/davecgh/go-spew v1.1.0 + github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb + github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 + github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible + github.com/golang/snappy v0.0.0-20170215233205-553a64147049 + github.com/gorilla/mux v1.7.3 + github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d + github.com/inconshreveable/mousetrap v1.0.0 + github.com/klauspost/cpuid v1.2.0 + github.com/klauspost/reedsolomon v1.9.1 + github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc + github.com/pkg/errors v0.8.0 + github.com/pmezard/go-difflib v1.0.0 + github.com/spf13/pflag v1.0.1 + github.com/stretchr/testify v1.3.0 + github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 + github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 + github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8 + github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec + github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae // indirect + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 + golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 + golang.org/x/sys v0.0.0-20200117145432-59e60aa80a0c + golang.org/x/text v0.3.2 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..aa08912 --- /dev/null +++ b/go.sum @@ -0,0 +1,57 @@ +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= +github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb h1:wCrNShQidLmvVWn/0PikGmpdP0vtQmnvyRg3ZBEhczw= +github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb/go.mod h1:wx3gB6dbIfBRcucp94PI9Bt3I0F2c/MyNEWuhzpWiwk= +github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 h1:teH578mf2ii42NHhIp3PhgvjU5bv+NFMq9fSQR8NaG8= +github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049/go.mod h1:DqIrnl0rp3Zybg9zbJmozTy1n8fYJoX+QoAj9slIkKM= +github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible h1:ssXat9YXFvigNge/IkkZvFMn8yeYKFX+uI6wn2mLJ74= +github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible/go.mod h1:YpCOaxj7vvMThhIQ9AfTOPW2sfztQR5WDfs7AflSy4s= +github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk= +github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw= +github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d h1:kJCB4vdITiW1eC1vq2e6IsrXKrZit1bv/TDYFGMp4BQ= +github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= +github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/klauspost/cpuid v1.2.0 h1:NMpwD2G9JSFOE1/TJjGSo5zG7Yb2bTe7eq1jH+irmeE= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/klauspost/reedsolomon v1.9.1 h1:kYrT1MlR4JH6PqOpC+okdb9CDTcwEC/BqpzK4WFyXL8= +github.com/klauspost/reedsolomon v1.9.1/go.mod h1:CwCi+NUr9pqSVktrkN+Ondf06rkhYZ/pcNv7fu+8Un4= +github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc h1:lNOt1SMsgHXTdpuGw+RpnJtzUcCb/oRKZP65pBy9pr8= +github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc/go.mod h1:6/gX3+E/IYGa0wMORlSMla999awQFdbaeQCHjSMKIzY= +github.com/pkg/errors v0.8.0 h1:WdK/asTD0HN+q6hsWO3/vpuAkAr+tw6aNJNDFFf0+qw= +github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4= +github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 h1:K+jtWCOuZgCra7eXZ/VWn2FbJmrA/D058mTXhh2rq+8= +github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU= +github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 h1:pexgSe+JCFuxG+uoMZLO+ce8KHtdHGhst4cs6rw3gmk= +github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554/go.mod h1:5XA7W9S6mni3h5uvOC75dA3m9CCCaS83lltmc0ukdi4= +github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8 h1:6CNSDqI1wiE+JqyOy5Qt/yo/DoNI2/QmmOZeiCid2Nw= +github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc= +github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec h1:DGmKwyZwEB8dI7tbLt/I/gQuP559o/0FrAkHKlQM/Ks= +github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec/go.mod h1:owBmyHYMLkxyrugmfwE/DLJyW8Ro9mkphwuVErQ0iUw= +github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= +github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 h1:Ao/3l156eZf2AW5wK8a7/smtodRU+gha3+BeqJ69lRk= +golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20200117145432-59e60aa80a0c h1:gUYreENmqtjZb2brVfUas1sC6UivSY8XwKwPo8tloLs= +golang.org/x/sys v0.0.0-20200117145432-59e60aa80a0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/models/config/client_common.go b/models/config/client_common.go new file mode 100644 index 0000000..2b5006b --- /dev/null +++ b/models/config/client_common.go @@ -0,0 +1,319 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "strconv" + "strings" + + ini "github.com/vaughan0/go-ini" +) + +// ClientCommonConf contains information for a client service. It is +// recommended to use GetDefaultClientConf instead of creating this object +// directly, so that all unspecified fields have reasonable default values. +type ClientCommonConf struct { + // ServerAddr specifies the address of the server to connect to. By + // default, this value is "0.0.0.0". + ServerAddr string `json:"server_addr"` + // ServerPort specifies the port to connect to the server on. By default, + // this value is 7000. + ServerPort int `json:"server_port"` + // HttpProxy specifies a proxy address to connect to the server through. If + // this value is "", the server will be connected to directly. By default, + // this value is read from the "http_proxy" environment variable. + HttpProxy string `json:"http_proxy"` + // LogFile specifies a file where logs will be written to. This value will + // only be used if LogWay is set appropriately. By default, this value is + // "console". + LogFile string `json:"log_file"` + // LogWay specifies the way logging is managed. Valid values are "console" + // or "file". If "console" is used, logs will be printed to stdout. If + // "file" is used, logs will be printed to LogFile. By default, this value + // is "console". + LogWay string `json:"log_way"` + // LogLevel specifies the minimum log level. Valid values are "trace", + // "debug", "info", "warn", and "error". By default, this value is "info". + LogLevel string `json:"log_level"` + // LogMaxDays specifies the maximum number of days to store log information + // before deletion. This is only used if LogWay == "file". By default, this + // value is 0. + LogMaxDays int64 `json:"log_max_days"` + // DisableLogColor disables log colors when LogWay == "console" when set to + // true. By default, this value is false. + DisableLogColor bool `json:"disable_log_color"` + // Token specifies the authorization token used to create keys to be sent + // to the server. The server must have a matching token for authorization + // to succeed. By default, this value is "". + Token string `json:"token"` + // AdminAddr specifies the address that the admin server binds to. By + // default, this value is "127.0.0.1". + AdminAddr string `json:"admin_addr"` + // AdminPort specifies the port for the admin server to listen on. If this + // value is 0, the admin server will not be started. By default, this value + // is 0. + AdminPort int `json:"admin_port"` + // AdminUser specifies the username that the admin server will use for + // login. By default, this value is "admin". + AdminUser string `json:"admin_user"` + // AdminPwd specifies the password that the admin server will use for + // login. By default, this value is "admin". + AdminPwd string `json:"admin_pwd"` + // AssetsDir specifies the local directory that the admin server will load + // resources from. If this value is "", assets will be loaded from the + // bundled executable using statik. By default, this value is "". + AssetsDir string `json:"assets_dir"` + // PoolCount specifies the number of connections the client will make to + // the server in advance. By default, this value is 0. + PoolCount int `json:"pool_count"` + // TcpMux toggles TCP stream multiplexing. This allows multiple requests + // from a client to share a single TCP connection. If this value is true, + // the server must have TCP multiplexing enabled as well. By default, this + // value is true. + TcpMux bool `json:"tcp_mux"` + // User specifies a prefix for proxy names to distinguish them from other + // clients. If this value is not "", proxy names will automatically be + // changed to "{user}.{proxy_name}". By default, this value is "". + User string `json:"user"` + // DnsServer specifies a DNS server address for FRPC to use. If this value + // is "", the default DNS will be used. By default, this value is "". + DnsServer string `json:"dns_server"` + // LoginFailExit controls whether or not the client should exit after a + // failed login attempt. If false, the client will retry until a login + // attempt succeeds. By default, this value is true. + LoginFailExit bool `json:"login_fail_exit"` + // Start specifies a set of enabled proxies by name. If this set is empty, + // all supplied proxies are enabled. By default, this value is an empty + // set. + Start map[string]struct{} `json:"start"` + // Protocol specifies the protocol to use when interacting with the server. + // Valid values are "tcp", "kcp", and "websocket". By default, this value + // is "tcp". + Protocol string `json:"protocol"` + // TLSEnable specifies whether or not TLS should be used when communicating + // with the server. + TLSEnable bool `json:"tls_enable"` + // HeartBeatInterval specifies at what interval heartbeats are sent to the + // server, in seconds. It is not recommended to change this value. By + // default, this value is 30. + HeartBeatInterval int64 `json:"heartbeat_interval"` + // HeartBeatTimeout specifies the maximum allowed heartbeat response delay + // before the connection is terminated, in seconds. It is not recommended + // to change this value. By default, this value is 90. + HeartBeatTimeout int64 `json:"heartbeat_timeout"` + // Client meta info + Metas map[string]string `json:"metas"` +} + +// GetDefaultClientConf returns a client configuration with default values. +func GetDefaultClientConf() ClientCommonConf { + return ClientCommonConf{ + ServerAddr: "0.0.0.0", + ServerPort: 7000, + HttpProxy: os.Getenv("http_proxy"), + LogFile: "console", + LogWay: "console", + LogLevel: "info", + LogMaxDays: 3, + DisableLogColor: false, + Token: "", + AdminAddr: "127.0.0.1", + AdminPort: 0, + AdminUser: "", + AdminPwd: "", + AssetsDir: "", + PoolCount: 1, + TcpMux: true, + User: "", + DnsServer: "", + LoginFailExit: true, + Start: make(map[string]struct{}), + Protocol: "tcp", + TLSEnable: false, + HeartBeatInterval: 30, + HeartBeatTimeout: 90, + Metas: make(map[string]string), + } +} + +func UnmarshalClientConfFromIni(content string) (cfg ClientCommonConf, err error) { + cfg = GetDefaultClientConf() + + conf, err := ini.Load(strings.NewReader(content)) + if err != nil { + return ClientCommonConf{}, fmt.Errorf("parse ini conf file error: %v", err) + } + + var ( + tmpStr string + ok bool + v int64 + ) + if tmpStr, ok = conf.Get("common", "server_addr"); ok { + cfg.ServerAddr = tmpStr + } + + if tmpStr, ok = conf.Get("common", "server_port"); ok { + v, err = strconv.ParseInt(tmpStr, 10, 64) + if err != nil { + err = fmt.Errorf("Parse conf error: invalid server_port") + return + } + cfg.ServerPort = int(v) + } + + if tmpStr, ok = conf.Get("common", "disable_log_color"); ok && tmpStr == "true" { + cfg.DisableLogColor = true + } + + if tmpStr, ok = conf.Get("common", "http_proxy"); ok { + cfg.HttpProxy = tmpStr + } + + if tmpStr, ok = conf.Get("common", "log_file"); ok { + cfg.LogFile = tmpStr + if cfg.LogFile == "console" { + cfg.LogWay = "console" + } else { + cfg.LogWay = "file" + } + } + + if tmpStr, ok = conf.Get("common", "log_level"); ok { + cfg.LogLevel = tmpStr + } + + if tmpStr, ok = conf.Get("common", "log_max_days"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.LogMaxDays = v + } + } + + if tmpStr, ok = conf.Get("common", "token"); ok { + cfg.Token = tmpStr + } + + if tmpStr, ok = conf.Get("common", "admin_addr"); ok { + cfg.AdminAddr = tmpStr + } + + if tmpStr, ok = conf.Get("common", "admin_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.AdminPort = int(v) + } else { + err = fmt.Errorf("Parse conf error: invalid admin_port") + return + } + } + + if tmpStr, ok = conf.Get("common", "admin_user"); ok { + cfg.AdminUser = tmpStr + } + + if tmpStr, ok = conf.Get("common", "admin_pwd"); ok { + cfg.AdminPwd = tmpStr + } + + if tmpStr, ok = conf.Get("common", "assets_dir"); ok { + cfg.AssetsDir = tmpStr + } + + if tmpStr, ok = conf.Get("common", "pool_count"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err == nil { + cfg.PoolCount = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "tcp_mux"); ok && tmpStr == "false" { + cfg.TcpMux = false + } else { + cfg.TcpMux = true + } + + if tmpStr, ok = conf.Get("common", "user"); ok { + cfg.User = tmpStr + } + + if tmpStr, ok = conf.Get("common", "dns_server"); ok { + cfg.DnsServer = tmpStr + } + + if tmpStr, ok = conf.Get("common", "start"); ok { + proxyNames := strings.Split(tmpStr, ",") + for _, name := range proxyNames { + cfg.Start[strings.TrimSpace(name)] = struct{}{} + } + } + + if tmpStr, ok = conf.Get("common", "login_fail_exit"); ok && tmpStr == "false" { + cfg.LoginFailExit = false + } else { + cfg.LoginFailExit = true + } + + if tmpStr, ok = conf.Get("common", "protocol"); ok { + // Now it only support tcp and kcp and websocket. + if tmpStr != "tcp" && tmpStr != "kcp" && tmpStr != "websocket" { + err = fmt.Errorf("Parse conf error: invalid protocol") + return + } + cfg.Protocol = tmpStr + } + + if tmpStr, ok = conf.Get("common", "tls_enable"); ok && tmpStr == "true" { + cfg.TLSEnable = true + } else { + cfg.TLSEnable = false + } + + if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout") + return + } else { + cfg.HeartBeatTimeout = v + } + } + + if tmpStr, ok = conf.Get("common", "heartbeat_interval"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") + return + } else { + cfg.HeartBeatInterval = v + } + } + for k, v := range conf.Section("common") { + if strings.HasPrefix(k, "meta_") { + cfg.Metas[strings.TrimPrefix(k, "meta_")] = v + } + } + return +} + +func (cfg *ClientCommonConf) Check() (err error) { + if cfg.HeartBeatInterval <= 0 { + err = fmt.Errorf("Parse conf error: invalid heartbeat_interval") + return + } + + if cfg.HeartBeatTimeout < cfg.HeartBeatInterval { + err = fmt.Errorf("Parse conf error: invalid heartbeat_timeout, heartbeat_timeout is less than heartbeat_interval") + return + } + return +} diff --git a/models/config/proxy.go b/models/config/proxy.go new file mode 100644 index 0000000..8b633bf --- /dev/null +++ b/models/config/proxy.go @@ -0,0 +1,1042 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/charlesbao/frpc/models/consts" + "github.com/charlesbao/frpc/models/msg" + "github.com/charlesbao/frpc/utils/util" + + ini "github.com/vaughan0/go-ini" +) + +var ( + proxyConfTypeMap map[string]reflect.Type +) + +func init() { + proxyConfTypeMap = make(map[string]reflect.Type) + proxyConfTypeMap[consts.TcpProxy] = reflect.TypeOf(TcpProxyConf{}) + proxyConfTypeMap[consts.UdpProxy] = reflect.TypeOf(UdpProxyConf{}) + proxyConfTypeMap[consts.HttpProxy] = reflect.TypeOf(HttpProxyConf{}) + proxyConfTypeMap[consts.HttpsProxy] = reflect.TypeOf(HttpsProxyConf{}) + proxyConfTypeMap[consts.StcpProxy] = reflect.TypeOf(StcpProxyConf{}) + proxyConfTypeMap[consts.XtcpProxy] = reflect.TypeOf(XtcpProxyConf{}) +} + +// NewConfByType creates a empty ProxyConf object by proxyType. +// If proxyType isn't exist, return nil. +func NewConfByType(proxyType string) ProxyConf { + v, ok := proxyConfTypeMap[proxyType] + if !ok { + return nil + } + cfg := reflect.New(v).Interface().(ProxyConf) + return cfg +} + +type ProxyConf interface { + GetBaseInfo() *BaseProxyConf + UnmarshalFromMsg(pMsg *msg.NewProxy) + UnmarshalFromIni(prefix string, name string, conf ini.Section) error + MarshalToMsg(pMsg *msg.NewProxy) + CheckForCli() error + CheckForSvr(serverCfg ServerCommonConf) error + Compare(conf ProxyConf) bool +} + +func NewProxyConfFromMsg(pMsg *msg.NewProxy, serverCfg ServerCommonConf) (cfg ProxyConf, err error) { + if pMsg.ProxyType == "" { + pMsg.ProxyType = consts.TcpProxy + } + + cfg = NewConfByType(pMsg.ProxyType) + if cfg == nil { + err = fmt.Errorf("proxy [%s] type [%s] error", pMsg.ProxyName, pMsg.ProxyType) + return + } + cfg.UnmarshalFromMsg(pMsg) + err = cfg.CheckForSvr(serverCfg) + return +} + +func NewProxyConfFromIni(prefix string, name string, section ini.Section) (cfg ProxyConf, err error) { + proxyType := section["type"] + if proxyType == "" { + proxyType = consts.TcpProxy + section["type"] = consts.TcpProxy + } + cfg = NewConfByType(proxyType) + if cfg == nil { + err = fmt.Errorf("proxy [%s] type [%s] error", name, proxyType) + return + } + if err = cfg.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.CheckForCli(); err != nil { + return + } + return +} + +// BaseProxyConf provides configuration info that is common to all proxy types. +type BaseProxyConf struct { + // ProxyName is the name of this proxy. + ProxyName string `json:"proxy_name"` + // ProxyType specifies the type of this proxy. Valid values include "tcp", + // "udp", "http", "https", "stcp", and "xtcp". By default, this value is + // "tcp". + ProxyType string `json:"proxy_type"` + + // UseEncryption controls whether or not communication with the server will + // be encrypted. Encryption is done using the tokens supplied in the server + // and client configuration. By default, this value is false. + UseEncryption bool `json:"use_encryption"` + // UseCompression controls whether or not communication with the server + // will be compressed. By default, this value is false. + UseCompression bool `json:"use_compression"` + // Group specifies which group the proxy is a part of. The server will use + // this information to load balance proxies in the same group. If the value + // is "", this proxy will not be in a group. By default, this value is "". + Group string `json:"group"` + // GroupKey specifies a group key, which should be the same among proxies + // of the same group. By default, this value is "". + GroupKey string `json:"group_key"` + + // ProxyProtocolVersion specifies which protocol version to use. Valid + // values include "v1", "v2", and "". If the value is "", a protocol + // version will be automatically selected. By default, this value is "". + ProxyProtocolVersion string `json:"proxy_protocol_version"` + + // BandwidthLimit limit the proxy bandwidth + // 0 means no limit + BandwidthLimit BandwidthQuantity `json:"bandwidth_limit"` + + // meta info for each proxy + Metas map[string]string `json:"metas"` + + LocalSvrConf + HealthCheckConf +} + +func (cfg *BaseProxyConf) GetBaseInfo() *BaseProxyConf { + return cfg +} + +func (cfg *BaseProxyConf) compare(cmp *BaseProxyConf) bool { + if cfg.ProxyName != cmp.ProxyName || + cfg.ProxyType != cmp.ProxyType || + cfg.UseEncryption != cmp.UseEncryption || + cfg.UseCompression != cmp.UseCompression || + cfg.Group != cmp.Group || + cfg.GroupKey != cmp.GroupKey || + cfg.ProxyProtocolVersion != cmp.ProxyProtocolVersion || + cfg.BandwidthLimit.Equal(&cmp.BandwidthLimit) || + !reflect.DeepEqual(cfg.Metas, cmp.Metas) { + return false + } + if !cfg.LocalSvrConf.compare(&cmp.LocalSvrConf) { + return false + } + if !cfg.HealthCheckConf.compare(&cmp.HealthCheckConf) { + return false + } + return true +} + +func (cfg *BaseProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.ProxyName = pMsg.ProxyName + cfg.ProxyType = pMsg.ProxyType + cfg.UseEncryption = pMsg.UseEncryption + cfg.UseCompression = pMsg.UseCompression + cfg.Group = pMsg.Group + cfg.GroupKey = pMsg.GroupKey + cfg.Metas = pMsg.Metas +} + +func (cfg *BaseProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) error { + var ( + tmpStr string + ok bool + err error + ) + cfg.ProxyName = prefix + name + cfg.ProxyType = section["type"] + + tmpStr, ok = section["use_encryption"] + if ok && tmpStr == "true" { + cfg.UseEncryption = true + } + + tmpStr, ok = section["use_compression"] + if ok && tmpStr == "true" { + cfg.UseCompression = true + } + + cfg.Group = section["group"] + cfg.GroupKey = section["group_key"] + cfg.ProxyProtocolVersion = section["proxy_protocol_version"] + + if cfg.BandwidthLimit, err = NewBandwidthQuantity(section["bandwidth_limit"]); err != nil { + return err + } + + if err = cfg.LocalSvrConf.UnmarshalFromIni(prefix, name, section); err != nil { + return err + } + + if err = cfg.HealthCheckConf.UnmarshalFromIni(prefix, name, section); err != nil { + return err + } + + if cfg.HealthCheckType == "tcp" && cfg.Plugin == "" { + cfg.HealthCheckAddr = cfg.LocalIp + fmt.Sprintf(":%d", cfg.LocalPort) + } + if cfg.HealthCheckType == "http" && cfg.Plugin == "" && cfg.HealthCheckUrl != "" { + s := fmt.Sprintf("http://%s:%d", cfg.LocalIp, cfg.LocalPort) + if !strings.HasPrefix(cfg.HealthCheckUrl, "/") { + s += "/" + } + cfg.HealthCheckUrl = s + cfg.HealthCheckUrl + } + + cfg.Metas = make(map[string]string) + for k, v := range section { + if strings.HasPrefix(k, "meta_") { + cfg.Metas[strings.TrimPrefix(k, "meta_")] = v + } + } + return nil +} + +func (cfg *BaseProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + pMsg.ProxyName = cfg.ProxyName + pMsg.ProxyType = cfg.ProxyType + pMsg.UseEncryption = cfg.UseEncryption + pMsg.UseCompression = cfg.UseCompression + pMsg.Group = cfg.Group + pMsg.GroupKey = cfg.GroupKey + pMsg.Metas = cfg.Metas +} + +func (cfg *BaseProxyConf) checkForCli() (err error) { + if cfg.ProxyProtocolVersion != "" { + if cfg.ProxyProtocolVersion != "v1" && cfg.ProxyProtocolVersion != "v2" { + return fmt.Errorf("no support proxy protocol version: %s", cfg.ProxyProtocolVersion) + } + } + + if err = cfg.LocalSvrConf.checkForCli(); err != nil { + return + } + if err = cfg.HealthCheckConf.checkForCli(); err != nil { + return + } + return nil +} + +// Bind info +type BindInfoConf struct { + RemotePort int `json:"remote_port"` +} + +func (cfg *BindInfoConf) compare(cmp *BindInfoConf) bool { + if cfg.RemotePort != cmp.RemotePort { + return false + } + return true +} + +func (cfg *BindInfoConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.RemotePort = pMsg.RemotePort +} + +func (cfg *BindInfoConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + var ( + tmpStr string + ok bool + v int64 + ) + if tmpStr, ok = section["remote_port"]; ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] remote_port error", name) + } else { + cfg.RemotePort = int(v) + } + } else { + return fmt.Errorf("Parse conf error: proxy [%s] remote_port not found", name) + } + return nil +} + +func (cfg *BindInfoConf) MarshalToMsg(pMsg *msg.NewProxy) { + pMsg.RemotePort = cfg.RemotePort +} + +// Domain info +type DomainConf struct { + CustomDomains []string `json:"custom_domains"` + SubDomain string `json:"sub_domain"` +} + +func (cfg *DomainConf) compare(cmp *DomainConf) bool { + if strings.Join(cfg.CustomDomains, " ") != strings.Join(cmp.CustomDomains, " ") || + cfg.SubDomain != cmp.SubDomain { + return false + } + return true +} + +func (cfg *DomainConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.CustomDomains = pMsg.CustomDomains + cfg.SubDomain = pMsg.SubDomain +} + +func (cfg *DomainConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + var ( + tmpStr string + ok bool + ) + if tmpStr, ok = section["custom_domains"]; ok { + cfg.CustomDomains = strings.Split(tmpStr, ",") + for i, domain := range cfg.CustomDomains { + cfg.CustomDomains[i] = strings.ToLower(strings.TrimSpace(domain)) + } + } + + if tmpStr, ok = section["subdomain"]; ok { + cfg.SubDomain = tmpStr + } + return +} + +func (cfg *DomainConf) MarshalToMsg(pMsg *msg.NewProxy) { + pMsg.CustomDomains = cfg.CustomDomains + pMsg.SubDomain = cfg.SubDomain +} + +func (cfg *DomainConf) check() (err error) { + if len(cfg.CustomDomains) == 0 && cfg.SubDomain == "" { + err = fmt.Errorf("custom_domains and subdomain should set at least one of them") + return + } + return +} + +func (cfg *DomainConf) checkForCli() (err error) { + if err = cfg.check(); err != nil { + return + } + return +} + +func (cfg *DomainConf) checkForSvr(serverCfg ServerCommonConf) (err error) { + if err = cfg.check(); err != nil { + return + } + + for _, domain := range cfg.CustomDomains { + if serverCfg.SubDomainHost != "" && len(strings.Split(serverCfg.SubDomainHost, ".")) < len(strings.Split(domain, ".")) { + if strings.Contains(domain, serverCfg.SubDomainHost) { + return fmt.Errorf("custom domain [%s] should not belong to subdomain_host [%s]", domain, serverCfg.SubDomainHost) + } + } + } + + if cfg.SubDomain != "" { + if serverCfg.SubDomainHost == "" { + return fmt.Errorf("subdomain is not supported because this feature is not enabled in remote frps") + } + if strings.Contains(cfg.SubDomain, ".") || strings.Contains(cfg.SubDomain, "*") { + return fmt.Errorf("'.' and '*' is not supported in subdomain") + } + } + return +} + +// LocalSvrConf configures what location the client will proxy to, or what +// plugin will be used. +type LocalSvrConf struct { + // LocalIp specifies the IP address or host name to proxy to. + LocalIp string `json:"local_ip"` + // LocalPort specifies the port to proxy to. + LocalPort int `json:"local_port"` + + // Plugin specifies what plugin should be used for proxying. If this value + // is set, the LocalIp and LocalPort values will be ignored. By default, + // this value is "". + Plugin string `json:"plugin"` + // PluginParams specify parameters to be passed to the plugin, if one is + // being used. By default, this value is an empty map. + PluginParams map[string]string `json:"plugin_params"` +} + +func (cfg *LocalSvrConf) compare(cmp *LocalSvrConf) bool { + if cfg.LocalIp != cmp.LocalIp || + cfg.LocalPort != cmp.LocalPort { + return false + } + if cfg.Plugin != cmp.Plugin || + len(cfg.PluginParams) != len(cmp.PluginParams) { + return false + } + for k, v := range cfg.PluginParams { + value, ok := cmp.PluginParams[k] + if !ok || v != value { + return false + } + } + return true +} + +func (cfg *LocalSvrConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + cfg.Plugin = section["plugin"] + cfg.PluginParams = make(map[string]string) + if cfg.Plugin != "" { + // get params begin with "plugin_" + for k, v := range section { + if strings.HasPrefix(k, "plugin_") { + cfg.PluginParams[k] = v + } + } + } else { + if cfg.LocalIp = section["local_ip"]; cfg.LocalIp == "" { + cfg.LocalIp = "127.0.0.1" + } + + if tmpStr, ok := section["local_port"]; ok { + if cfg.LocalPort, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] local_port error", name) + } + } else { + return fmt.Errorf("Parse conf error: proxy [%s] local_port not found", name) + } + } + return +} + +func (cfg *LocalSvrConf) checkForCli() (err error) { + if cfg.Plugin == "" { + if cfg.LocalIp == "" { + err = fmt.Errorf("local ip or plugin is required") + return + } + if cfg.LocalPort <= 0 { + err = fmt.Errorf("error local_port") + return + } + } + return +} + +// HealthCheckConf configures health checking. This can be useful for load +// balancing purposes to detect and remove proxies to failing services. +type HealthCheckConf struct { + // HealthCheckType specifies what protocol to use for health checking. + // Valid values include "tcp", "http", and "". If this value is "", health + // checking will not be performed. By default, this value is "". + // + // If the type is "tcp", a connection will be attempted to the target + // server. If a connection cannot be established, the health check fails. + // + // If the type is "http", a GET request will be made to the endpoint + // specified by HealthCheckUrl. If the response is not a 200, the health + // check fails. + HealthCheckType string `json:"health_check_type"` // tcp | http + // HealthCheckTimeoutS specifies the number of seconds to wait for a health + // check attempt to connect. If the timeout is reached, this counts as a + // health check failure. By default, this value is 3. + HealthCheckTimeoutS int `json:"health_check_timeout_s"` + // HealthCheckMaxFailed specifies the number of allowed failures before the + // proxy is stopped. By default, this value is 1. + HealthCheckMaxFailed int `json:"health_check_max_failed"` + // HealthCheckIntervalS specifies the time in seconds between health + // checks. By default, this value is 10. + HealthCheckIntervalS int `json:"health_check_interval_s"` + // HealthCheckUrl specifies the address to send health checks to if the + // health check type is "http". + HealthCheckUrl string `json:"health_check_url"` + // HealthCheckAddr specifies the address to connect to if the health check + // type is "tcp". + HealthCheckAddr string `json:"-"` +} + +func (cfg *HealthCheckConf) compare(cmp *HealthCheckConf) bool { + if cfg.HealthCheckType != cmp.HealthCheckType || + cfg.HealthCheckTimeoutS != cmp.HealthCheckTimeoutS || + cfg.HealthCheckMaxFailed != cmp.HealthCheckMaxFailed || + cfg.HealthCheckIntervalS != cmp.HealthCheckIntervalS || + cfg.HealthCheckUrl != cmp.HealthCheckUrl { + return false + } + return true +} + +func (cfg *HealthCheckConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + cfg.HealthCheckType = section["health_check_type"] + cfg.HealthCheckUrl = section["health_check_url"] + + if tmpStr, ok := section["health_check_timeout_s"]; ok { + if cfg.HealthCheckTimeoutS, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] health_check_timeout_s error", name) + } + } + + if tmpStr, ok := section["health_check_max_failed"]; ok { + if cfg.HealthCheckMaxFailed, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] health_check_max_failed error", name) + } + } + + if tmpStr, ok := section["health_check_interval_s"]; ok { + if cfg.HealthCheckIntervalS, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] health_check_interval_s error", name) + } + } + return +} + +func (cfg *HealthCheckConf) checkForCli() error { + if cfg.HealthCheckType != "" && cfg.HealthCheckType != "tcp" && cfg.HealthCheckType != "http" { + return fmt.Errorf("unsupport health check type") + } + if cfg.HealthCheckType != "" { + if cfg.HealthCheckType == "http" && cfg.HealthCheckUrl == "" { + return fmt.Errorf("health_check_url is required for health check type 'http'") + } + } + return nil +} + +// TCP +type TcpProxyConf struct { + BaseProxyConf + BindInfoConf +} + +func (cfg *TcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*TcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) { + return false + } + return true +} + +func (cfg *TcpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.BindInfoConf.UnmarshalFromMsg(pMsg) +} + +func (cfg *TcpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.BindInfoConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *TcpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + cfg.BindInfoConf.MarshalToMsg(pMsg) +} + +func (cfg *TcpProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return err + } + return +} + +func (cfg *TcpProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { return nil } + +// UDP +type UdpProxyConf struct { + BaseProxyConf + BindInfoConf +} + +func (cfg *UdpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*UdpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.BindInfoConf.compare(&cmpConf.BindInfoConf) { + return false + } + return true +} + +func (cfg *UdpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.BindInfoConf.UnmarshalFromMsg(pMsg) +} + +func (cfg *UdpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.BindInfoConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *UdpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + cfg.BindInfoConf.MarshalToMsg(pMsg) +} + +func (cfg *UdpProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return + } + return +} + +func (cfg *UdpProxyConf) CheckForSvr(serverCfg ServerCommonConf) error { return nil } + +// HTTP +type HttpProxyConf struct { + BaseProxyConf + DomainConf + + Locations []string `json:"locations"` + HttpUser string `json:"http_user"` + HttpPwd string `json:"http_pwd"` + HostHeaderRewrite string `json:"host_header_rewrite"` + Headers map[string]string `json:"headers"` +} + +func (cfg *HttpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) || + strings.Join(cfg.Locations, " ") != strings.Join(cmpConf.Locations, " ") || + cfg.HostHeaderRewrite != cmpConf.HostHeaderRewrite || + cfg.HttpUser != cmpConf.HttpUser || + cfg.HttpPwd != cmpConf.HttpPwd || + len(cfg.Headers) != len(cmpConf.Headers) { + return false + } + + for k, v := range cfg.Headers { + if v2, ok := cmpConf.Headers[k]; !ok { + return false + } else { + if v != v2 { + return false + } + } + } + return true +} + +func (cfg *HttpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.DomainConf.UnmarshalFromMsg(pMsg) + + cfg.Locations = pMsg.Locations + cfg.HostHeaderRewrite = pMsg.HostHeaderRewrite + cfg.HttpUser = pMsg.HttpUser + cfg.HttpPwd = pMsg.HttpPwd + cfg.Headers = pMsg.Headers +} + +func (cfg *HttpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.DomainConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + + var ( + tmpStr string + ok bool + ) + if tmpStr, ok = section["locations"]; ok { + cfg.Locations = strings.Split(tmpStr, ",") + } else { + cfg.Locations = []string{""} + } + + cfg.HostHeaderRewrite = section["host_header_rewrite"] + cfg.HttpUser = section["http_user"] + cfg.HttpPwd = section["http_pwd"] + cfg.Headers = make(map[string]string) + + for k, v := range section { + if strings.HasPrefix(k, "header_") { + cfg.Headers[strings.TrimPrefix(k, "header_")] = v + } + } + return +} + +func (cfg *HttpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + cfg.DomainConf.MarshalToMsg(pMsg) + + pMsg.Locations = cfg.Locations + pMsg.HostHeaderRewrite = cfg.HostHeaderRewrite + pMsg.HttpUser = cfg.HttpUser + pMsg.HttpPwd = cfg.HttpPwd + pMsg.Headers = cfg.Headers +} + +func (cfg *HttpProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return + } + if err = cfg.DomainConf.checkForCli(); err != nil { + return + } + return +} + +func (cfg *HttpProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { + if serverCfg.VhostHttpPort == 0 { + return fmt.Errorf("type [http] not support when vhost_http_port is not set") + } + if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil { + err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) + return + } + return +} + +// HTTPS +type HttpsProxyConf struct { + BaseProxyConf + DomainConf +} + +func (cfg *HttpsProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*HttpsProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.DomainConf.compare(&cmpConf.DomainConf) { + return false + } + return true +} + +func (cfg *HttpsProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.DomainConf.UnmarshalFromMsg(pMsg) +} + +func (cfg *HttpsProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.DomainConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *HttpsProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + cfg.DomainConf.MarshalToMsg(pMsg) +} + +func (cfg *HttpsProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return + } + if err = cfg.DomainConf.checkForCli(); err != nil { + return + } + return +} + +func (cfg *HttpsProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { + if serverCfg.VhostHttpsPort == 0 { + return fmt.Errorf("type [https] not support when vhost_https_port is not set") + } + if err = cfg.DomainConf.checkForSvr(serverCfg); err != nil { + err = fmt.Errorf("proxy [%s] domain conf check error: %v", cfg.ProxyName, err) + return + } + return +} + +// STCP +type StcpProxyConf struct { + BaseProxyConf + + Role string `json:"role"` + Sk string `json:"sk"` +} + +func (cfg *StcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*StcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + cfg.Role != cmpConf.Role || + cfg.Sk != cmpConf.Sk { + return false + } + return true +} + +// Only for role server. +func (cfg *StcpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.Sk = pMsg.Sk +} + +func (cfg *StcpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + + cfg.Role = section["role"] + if cfg.Role != "server" { + return fmt.Errorf("Parse conf error: proxy [%s] incorrect role [%s]", name, cfg.Role) + } + + cfg.Sk = section["sk"] + + if err = cfg.LocalSvrConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *StcpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + pMsg.Sk = cfg.Sk +} + +func (cfg *StcpProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return + } + if cfg.Role != "server" { + err = fmt.Errorf("role should be 'server'") + return + } + return +} + +func (cfg *StcpProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { + return +} + +// XTCP +type XtcpProxyConf struct { + BaseProxyConf + + Role string `json:"role"` + Sk string `json:"sk"` +} + +func (cfg *XtcpProxyConf) Compare(cmp ProxyConf) bool { + cmpConf, ok := cmp.(*XtcpProxyConf) + if !ok { + return false + } + + if !cfg.BaseProxyConf.compare(&cmpConf.BaseProxyConf) || + !cfg.LocalSvrConf.compare(&cmpConf.LocalSvrConf) || + cfg.Role != cmpConf.Role || + cfg.Sk != cmpConf.Sk { + return false + } + return true +} + +// Only for role server. +func (cfg *XtcpProxyConf) UnmarshalFromMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.UnmarshalFromMsg(pMsg) + cfg.Sk = pMsg.Sk +} + +func (cfg *XtcpProxyConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseProxyConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + + cfg.Role = section["role"] + if cfg.Role != "server" { + return fmt.Errorf("Parse conf error: proxy [%s] incorrect role [%s]", name, cfg.Role) + } + + cfg.Sk = section["sk"] + + if err = cfg.LocalSvrConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *XtcpProxyConf) MarshalToMsg(pMsg *msg.NewProxy) { + cfg.BaseProxyConf.MarshalToMsg(pMsg) + pMsg.Sk = cfg.Sk +} + +func (cfg *XtcpProxyConf) CheckForCli() (err error) { + if err = cfg.BaseProxyConf.checkForCli(); err != nil { + return + } + if cfg.Role != "server" { + err = fmt.Errorf("role should be 'server'") + return + } + return +} + +func (cfg *XtcpProxyConf) CheckForSvr(serverCfg ServerCommonConf) (err error) { + return +} + +func ParseRangeSection(name string, section ini.Section) (sections map[string]ini.Section, err error) { + localPorts, errRet := util.ParseRangeNumbers(section["local_port"]) + if errRet != nil { + err = fmt.Errorf("Parse conf error: range section [%s] local_port invalid, %v", name, errRet) + return + } + + remotePorts, errRet := util.ParseRangeNumbers(section["remote_port"]) + if errRet != nil { + err = fmt.Errorf("Parse conf error: range section [%s] remote_port invalid, %v", name, errRet) + return + } + if len(localPorts) != len(remotePorts) { + err = fmt.Errorf("Parse conf error: range section [%s] local ports number should be same with remote ports number", name) + return + } + if len(localPorts) == 0 { + err = fmt.Errorf("Parse conf error: range section [%s] local_port and remote_port is necessary", name) + return + } + + sections = make(map[string]ini.Section) + for i, port := range localPorts { + subName := fmt.Sprintf("%s_%d", name, i) + subSection := copySection(section) + subSection["local_port"] = fmt.Sprintf("%d", port) + subSection["remote_port"] = fmt.Sprintf("%d", remotePorts[i]) + sections[subName] = subSection + } + return +} + +// if len(startProxy) is 0, start all +// otherwise just start proxies in startProxy map +func LoadAllConfFromIni(prefix string, content string, startProxy map[string]struct{}) ( + proxyConfs map[string]ProxyConf, visitorConfs map[string]VisitorConf, err error) { + + conf, errRet := ini.Load(strings.NewReader(content)) + if errRet != nil { + err = errRet + return + } + + if prefix != "" { + prefix += "." + } + + startAll := true + if len(startProxy) > 0 { + startAll = false + } + proxyConfs = make(map[string]ProxyConf) + visitorConfs = make(map[string]VisitorConf) + for name, section := range conf { + if name == "common" { + continue + } + + _, shouldStart := startProxy[name] + if !startAll && !shouldStart { + continue + } + + subSections := make(map[string]ini.Section) + + if strings.HasPrefix(name, "range:") { + // range section + rangePrefix := strings.TrimSpace(strings.TrimPrefix(name, "range:")) + subSections, err = ParseRangeSection(rangePrefix, section) + if err != nil { + return + } + } else { + subSections[name] = section + } + + for subName, subSection := range subSections { + if subSection["role"] == "" { + subSection["role"] = "server" + } + role := subSection["role"] + if role == "server" { + cfg, errRet := NewProxyConfFromIni(prefix, subName, subSection) + if errRet != nil { + err = errRet + return + } + proxyConfs[prefix+subName] = cfg + } else if role == "visitor" { + cfg, errRet := NewVisitorConfFromIni(prefix, subName, subSection) + if errRet != nil { + err = errRet + return + } + visitorConfs[prefix+subName] = cfg + } else { + err = fmt.Errorf("role should be 'server' or 'visitor'") + return + } + } + } + return +} + +func copySection(section ini.Section) (out ini.Section) { + out = make(ini.Section) + for k, v := range section { + out[k] = v + } + return +} diff --git a/models/config/server_common.go b/models/config/server_common.go new file mode 100644 index 0000000..256b990 --- /dev/null +++ b/models/config/server_common.go @@ -0,0 +1,403 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "strconv" + "strings" + + ini "github.com/vaughan0/go-ini" + + plugin "github.com/charlesbao/frpc/models/plugin/server" + "github.com/charlesbao/frpc/utils/util" +) + +// ServerCommonConf contains information for a server service. It is +// recommended to use GetDefaultServerConf instead of creating this object +// directly, so that all unspecified fields have reasonable default values. +type ServerCommonConf struct { + // BindAddr specifies the address that the server binds to. By default, + // this value is "0.0.0.0". + BindAddr string `json:"bind_addr"` + // BindPort specifies the port that the server listens on. By default, this + // value is 7000. + BindPort int `json:"bind_port"` + // BindUdpPort specifies the UDP port that the server listens on. If this + // value is 0, the server will not listen for UDP connections. By default, + // this value is 0 + BindUdpPort int `json:"bind_udp_port"` + // BindKcpPort specifies the KCP port that the server listens on. If this + // value is 0, the server will not listen for KCP connections. By default, + // this value is 0. + KcpBindPort int `json:"kcp_bind_port"` + // ProxyBindAddr specifies the address that the proxy binds to. This value + // may be the same as BindAddr. By default, this value is "0.0.0.0". + ProxyBindAddr string `json:"proxy_bind_addr"` + + // VhostHttpPort specifies the port that the server listens for HTTP Vhost + // requests. If this value is 0, the server will not listen for HTTP + // requests. By default, this value is 0. + VhostHttpPort int `json:"vhost_http_port"` + + // VhostHttpsPort specifies the port that the server listens for HTTPS + // Vhost requests. If this value is 0, the server will not listen for HTTPS + // requests. By default, this value is 0. + VhostHttpsPort int `json:"vhost_https_port"` + + // VhostHttpTimeout specifies the response header timeout for the Vhost + // HTTP server, in seconds. By default, this value is 60. + VhostHttpTimeout int64 `json:"vhost_http_timeout"` + + // DashboardAddr specifies the address that the dashboard binds to. By + // default, this value is "0.0.0.0". + DashboardAddr string `json:"dashboard_addr"` + + // DashboardPort specifies the port that the dashboard listens on. If this + // value is 0, the dashboard will not be started. By default, this value is + // 0. + DashboardPort int `json:"dashboard_port"` + // DashboardUser specifies the username that the dashboard will use for + // login. By default, this value is "admin". + DashboardUser string `json:"dashboard_user"` + // DashboardUser specifies the password that the dashboard will use for + // login. By default, this value is "admin". + DashboardPwd string `json:"dashboard_pwd"` + // AssetsDir specifies the local directory that the dashboard will load + // resources from. If this value is "", assets will be loaded from the + // bundled executable using statik. By default, this value is "". + AssetsDir string `json:"asserts_dir"` + // LogFile specifies a file where logs will be written to. This value will + // only be used if LogWay is set appropriately. By default, this value is + // "console". + LogFile string `json:"log_file"` + // LogWay specifies the way logging is managed. Valid values are "console" + // or "file". If "console" is used, logs will be printed to stdout. If + // "file" is used, logs will be printed to LogFile. By default, this value + // is "console". + LogWay string `json:"log_way"` + // LogLevel specifies the minimum log level. Valid values are "trace", + // "debug", "info", "warn", and "error". By default, this value is "info". + LogLevel string `json:"log_level"` + // LogMaxDays specifies the maximum number of days to store log information + // before deletion. This is only used if LogWay == "file". By default, this + // value is 0. + LogMaxDays int64 `json:"log_max_days"` + // DisableLogColor disables log colors when LogWay == "console" when set to + // true. By default, this value is false. + DisableLogColor bool `json:"disable_log_color"` + // Token specifies the authorization token used to authenticate keys + // received from clients. Clients must have a matching token to be + // authorized to use the server. By default, this value is "". + Token string `json:"token"` + // SubDomainHost specifies the domain that will be attached to sub-domains + // requested by the client when using Vhost proxying. For example, if this + // value is set to "frps.com" and the client requested the subdomain + // "test", the resulting URL would be "test.frps.com". By default, this + // value is "". + SubDomainHost string `json:"subdomain_host"` + // TcpMux toggles TCP stream multiplexing. This allows multiple requests + // from a client to share a single TCP connection. By default, this value + // is true. + TcpMux bool `json:"tcp_mux"` + // Custom404Page specifies a path to a custom 404 page to display. If this + // value is "", a default page will be displayed. By default, this value is + // "". + Custom404Page string `json:"custom_404_page"` + + // AllowPorts specifies a set of ports that clients are able to proxy to. + // If the length of this value is 0, all ports are allowed. By default, + // this value is an empty set. + AllowPorts map[int]struct{} + // MaxPoolCount specifies the maximum pool size for each proxy. By default, + // this value is 5. + MaxPoolCount int64 `json:"max_pool_count"` + // MaxPortsPerClient specifies the maximum number of ports a single client + // may proxy to. If this value is 0, no limit will be applied. By default, + // this value is 0. + MaxPortsPerClient int64 `json:"max_ports_per_client"` + // HeartBeatTimeout specifies the maximum time to wait for a heartbeat + // before terminating the connection. It is not recommended to change this + // value. By default, this value is 90. + HeartBeatTimeout int64 `json:"heart_beat_timeout"` + // UserConnTimeout specifies the maximum time to wait for a work + // connection. By default, this value is 10. + UserConnTimeout int64 `json:"user_conn_timeout"` + // HTTPPlugins specify the server plugins support HTTP protocol. + HTTPPlugins map[string]plugin.HTTPPluginOptions `json:"http_plugins"` +} + +// GetDefaultServerConf returns a server configuration with reasonable +// defaults. +func GetDefaultServerConf() ServerCommonConf { + AllowPorts := make(map[int]struct{}) + ports, _ := util.ParseRangeNumbers("20000-60000") + for _, port := range ports { + AllowPorts[int(port)] = struct{}{} + } + return ServerCommonConf{ + BindAddr: "0.0.0.0", + BindPort: 23333, + BindUdpPort: 0, + KcpBindPort: 0, + ProxyBindAddr: "0.0.0.0", + VhostHttpPort: 0, + VhostHttpsPort: 0, + VhostHttpTimeout: 60, + DashboardAddr: "0.0.0.0", + DashboardPort: 23334, + DashboardUser: "admin", + DashboardPwd: "chaos54319", + AssetsDir: "", + LogFile: "console", + LogWay: "console", + LogLevel: "warn", + LogMaxDays: 1, + DisableLogColor: false, + Token: "", + SubDomainHost: "", + TcpMux: true, + AllowPorts: AllowPorts, + MaxPoolCount: 5, + MaxPortsPerClient: 0, + HeartBeatTimeout: 90, + UserConnTimeout: 10, + Custom404Page: "", + HTTPPlugins: make(map[string]plugin.HTTPPluginOptions), + } +} + +// UnmarshalServerConfFromIni parses the contents of a server configuration ini +// file and returns the resulting server configuration. +func UnmarshalServerConfFromIni(content string) (cfg ServerCommonConf, err error) { + cfg = GetDefaultServerConf() + + conf, err := ini.Load(strings.NewReader(content)) + if err != nil { + err = fmt.Errorf("parse ini conf file error: %v", err) + return ServerCommonConf{}, err + } + + UnmarshalPluginsFromIni(conf, &cfg) + + var ( + tmpStr string + ok bool + v int64 + ) + if tmpStr, ok = conf.Get("common", "bind_addr"); ok { + cfg.BindAddr = tmpStr + } + + if tmpStr, ok = conf.Get("common", "bind_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_port") + return + } else { + cfg.BindPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "bind_udp_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid bind_udp_port") + return + } else { + cfg.BindUdpPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "kcp_bind_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid kcp_bind_port") + return + } else { + cfg.KcpBindPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "proxy_bind_addr"); ok { + cfg.ProxyBindAddr = tmpStr + } else { + cfg.ProxyBindAddr = cfg.BindAddr + } + + if tmpStr, ok = conf.Get("common", "vhost_http_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_http_port") + return + } else { + cfg.VhostHttpPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "vhost_https_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid vhost_https_port") + return + } else { + cfg.VhostHttpsPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "vhost_http_timeout"); ok { + v, errRet := strconv.ParseInt(tmpStr, 10, 64) + if errRet != nil || v < 0 { + err = fmt.Errorf("Parse conf error: invalid vhost_http_timeout") + return + } else { + cfg.VhostHttpTimeout = v + } + } + + if tmpStr, ok = conf.Get("common", "dashboard_addr"); ok { + cfg.DashboardAddr = tmpStr + } else { + cfg.DashboardAddr = cfg.BindAddr + } + + if tmpStr, ok = conf.Get("common", "dashboard_port"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid dashboard_port") + return + } else { + cfg.DashboardPort = int(v) + } + } + + if tmpStr, ok = conf.Get("common", "dashboard_user"); ok { + cfg.DashboardUser = tmpStr + } + + if tmpStr, ok = conf.Get("common", "dashboard_pwd"); ok { + cfg.DashboardPwd = tmpStr + } + + if tmpStr, ok = conf.Get("common", "assets_dir"); ok { + cfg.AssetsDir = tmpStr + } + + if tmpStr, ok = conf.Get("common", "log_file"); ok { + cfg.LogFile = tmpStr + if cfg.LogFile == "console" { + cfg.LogWay = "console" + } else { + cfg.LogWay = "file" + } + } + + if tmpStr, ok = conf.Get("common", "log_level"); ok { + cfg.LogLevel = tmpStr + } + + if tmpStr, ok = conf.Get("common", "log_max_days"); ok { + v, err = strconv.ParseInt(tmpStr, 10, 64) + if err == nil { + cfg.LogMaxDays = v + } + } + + if tmpStr, ok = conf.Get("common", "disable_log_color"); ok && tmpStr == "true" { + cfg.DisableLogColor = true + } + + cfg.Token, _ = conf.Get("common", "token") + + if allowPortsStr, ok := conf.Get("common", "allow_ports"); ok { + // e.g. 1000-2000,2001,2002,3000-4000 + ports, errRet := util.ParseRangeNumbers(allowPortsStr) + if errRet != nil { + err = fmt.Errorf("Parse conf error: allow_ports: %v", errRet) + return + } + + for _, port := range ports { + cfg.AllowPorts[int(port)] = struct{}{} + } + } + + if tmpStr, ok = conf.Get("common", "max_pool_count"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid max_pool_count") + return + } else { + if v < 0 { + err = fmt.Errorf("Parse conf error: invalid max_pool_count") + return + } + cfg.MaxPoolCount = v + } + } + + if tmpStr, ok = conf.Get("common", "max_ports_per_client"); ok { + if v, err = strconv.ParseInt(tmpStr, 10, 64); err != nil { + err = fmt.Errorf("Parse conf error: invalid max_ports_per_client") + return + } else { + if v < 0 { + err = fmt.Errorf("Parse conf error: invalid max_ports_per_client") + return + } + cfg.MaxPortsPerClient = v + } + } + + if tmpStr, ok = conf.Get("common", "subdomain_host"); ok { + cfg.SubDomainHost = strings.ToLower(strings.TrimSpace(tmpStr)) + } + + if tmpStr, ok = conf.Get("common", "tcp_mux"); ok && tmpStr == "false" { + cfg.TcpMux = false + } else { + cfg.TcpMux = true + } + + if tmpStr, ok = conf.Get("common", "custom_404_page"); ok { + cfg.Custom404Page = tmpStr + } + + if tmpStr, ok = conf.Get("common", "heartbeat_timeout"); ok { + v, errRet := strconv.ParseInt(tmpStr, 10, 64) + if errRet != nil { + err = fmt.Errorf("Parse conf error: heartbeat_timeout is incorrect") + return + } else { + cfg.HeartBeatTimeout = v + } + } + return +} + +func UnmarshalPluginsFromIni(sections ini.File, cfg *ServerCommonConf) { + for name, section := range sections { + if strings.HasPrefix(name, "plugin.") { + name = strings.TrimSpace(strings.TrimPrefix(name, "plugin.")) + options := plugin.HTTPPluginOptions{ + Name: name, + Addr: section["addr"], + Path: section["path"], + Ops: strings.Split(section["ops"], ","), + } + for i, _ := range options.Ops { + options.Ops[i] = strings.TrimSpace(options.Ops[i]) + } + cfg.HTTPPlugins[name] = options + } + } +} + +func (cfg *ServerCommonConf) Check() (err error) { + return +} diff --git a/models/config/types.go b/models/config/types.go new file mode 100644 index 0000000..87c240d --- /dev/null +++ b/models/config/types.go @@ -0,0 +1,112 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "errors" + "strconv" + "strings" +) + +const ( + MB = 1024 * 1024 + KB = 1024 +) + +type BandwidthQuantity struct { + s string // MB or KB + + i int64 // bytes +} + +func NewBandwidthQuantity(s string) (BandwidthQuantity, error) { + q := BandwidthQuantity{} + err := q.UnmarshalString(s) + if err != nil { + return q, err + } + return q, nil +} + +func (q *BandwidthQuantity) Equal(u *BandwidthQuantity) bool { + if q == nil && u == nil { + return true + } + if q != nil && u != nil { + return q.i == u.i + } + return false +} + +func (q *BandwidthQuantity) String() string { + return q.s +} + +func (q *BandwidthQuantity) UnmarshalString(s string) error { + s = strings.TrimSpace(s) + if s == "" { + return nil + } + + var ( + base int64 + f float64 + err error + ) + if strings.HasSuffix(s, "MB") { + base = MB + fstr := strings.TrimSuffix(s, "MB") + f, err = strconv.ParseFloat(fstr, 64) + if err != nil { + return err + } + } else if strings.HasSuffix(s, "KB") { + base = KB + fstr := strings.TrimSuffix(s, "KB") + f, err = strconv.ParseFloat(fstr, 64) + if err != nil { + return err + } + } else { + return errors.New("unit not support") + } + + q.s = s + q.i = int64(f * float64(base)) + return nil +} + +func (q *BandwidthQuantity) UnmarshalJSON(b []byte) error { + if len(b) == 4 && string(b) == "null" { + return nil + } + + var str string + err := json.Unmarshal(b, &str) + if err != nil { + return err + } + + return q.UnmarshalString(str) +} + +func (q *BandwidthQuantity) MarshalJSON() ([]byte, error) { + return []byte("\"" + q.s + "\""), nil +} + +func (q *BandwidthQuantity) Bytes() int64 { + return q.i +} diff --git a/models/config/types_test.go b/models/config/types_test.go new file mode 100644 index 0000000..ab03dfd --- /dev/null +++ b/models/config/types_test.go @@ -0,0 +1,40 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +type Wrap struct { + B BandwidthQuantity `json:"b"` + Int int `json:"int"` +} + +func TestBandwidthQuantity(t *testing.T) { + assert := assert.New(t) + + var w Wrap + err := json.Unmarshal([]byte(`{"b":"1KB","int":5}`), &w) + assert.NoError(err) + assert.EqualValues(1*KB, w.B.Bytes()) + + buf, err := json.Marshal(&w) + assert.NoError(err) + assert.Equal(`{"b":"1KB","int":5}`, string(buf)) +} diff --git a/models/config/value.go b/models/config/value.go new file mode 100644 index 0000000..3457024 --- /dev/null +++ b/models/config/value.go @@ -0,0 +1,64 @@ +package config + +import ( + "bytes" + "io/ioutil" + "os" + "strings" + "text/template" +) + +var ( + glbEnvs map[string]string +) + +func init() { + glbEnvs = make(map[string]string) + envs := os.Environ() + for _, env := range envs { + kv := strings.Split(env, "=") + if len(kv) != 2 { + continue + } + glbEnvs[kv[0]] = kv[1] + } +} + +type Values struct { + Envs map[string]string // environment vars +} + +func GetValues() *Values { + return &Values{ + Envs: glbEnvs, + } +} + +func RenderContent(in string) (out string, err error) { + tmpl, errRet := template.New("frp").Parse(in) + if errRet != nil { + err = errRet + return + } + + buffer := bytes.NewBufferString("") + v := GetValues() + err = tmpl.Execute(buffer, v) + if err != nil { + return + } + out = buffer.String() + return +} + +func GetRenderedConfFromFile(path string) (out string, err error) { + var b []byte + b, err = ioutil.ReadFile(path) + if err != nil { + return + } + content := string(b) + + out, err = RenderContent(content) + return +} diff --git a/models/config/visitor.go b/models/config/visitor.go new file mode 100644 index 0000000..1154faa --- /dev/null +++ b/models/config/visitor.go @@ -0,0 +1,213 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/charlesbao/frpc/models/consts" + + ini "github.com/vaughan0/go-ini" +) + +var ( + visitorConfTypeMap map[string]reflect.Type +) + +func init() { + visitorConfTypeMap = make(map[string]reflect.Type) + visitorConfTypeMap[consts.StcpProxy] = reflect.TypeOf(StcpVisitorConf{}) + visitorConfTypeMap[consts.XtcpProxy] = reflect.TypeOf(XtcpVisitorConf{}) +} + +type VisitorConf interface { + GetBaseInfo() *BaseVisitorConf + Compare(cmp VisitorConf) bool + UnmarshalFromIni(prefix string, name string, section ini.Section) error + Check() error +} + +func NewVisitorConfByType(cfgType string) VisitorConf { + v, ok := visitorConfTypeMap[cfgType] + if !ok { + return nil + } + cfg := reflect.New(v).Interface().(VisitorConf) + return cfg +} + +func NewVisitorConfFromIni(prefix string, name string, section ini.Section) (cfg VisitorConf, err error) { + cfgType := section["type"] + if cfgType == "" { + err = fmt.Errorf("visitor [%s] type shouldn't be empty", name) + return + } + cfg = NewVisitorConfByType(cfgType) + if cfg == nil { + err = fmt.Errorf("visitor [%s] type [%s] error", name, cfgType) + return + } + if err = cfg.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + if err = cfg.Check(); err != nil { + return + } + return +} + +type BaseVisitorConf struct { + ProxyName string `json:"proxy_name"` + ProxyType string `json:"proxy_type"` + UseEncryption bool `json:"use_encryption"` + UseCompression bool `json:"use_compression"` + Role string `json:"role"` + Sk string `json:"sk"` + ServerName string `json:"server_name"` + BindAddr string `json:"bind_addr"` + BindPort int `json:"bind_port"` +} + +func (cfg *BaseVisitorConf) GetBaseInfo() *BaseVisitorConf { + return cfg +} + +func (cfg *BaseVisitorConf) compare(cmp *BaseVisitorConf) bool { + if cfg.ProxyName != cmp.ProxyName || + cfg.ProxyType != cmp.ProxyType || + cfg.UseEncryption != cmp.UseEncryption || + cfg.UseCompression != cmp.UseCompression || + cfg.Role != cmp.Role || + cfg.Sk != cmp.Sk || + cfg.ServerName != cmp.ServerName || + cfg.BindAddr != cmp.BindAddr || + cfg.BindPort != cmp.BindPort { + return false + } + return true +} + +func (cfg *BaseVisitorConf) check() (err error) { + if cfg.Role != "visitor" { + err = fmt.Errorf("invalid role") + return + } + if cfg.BindAddr == "" { + err = fmt.Errorf("bind_addr shouldn't be empty") + return + } + if cfg.BindPort <= 0 { + err = fmt.Errorf("bind_port is required") + return + } + return +} + +func (cfg *BaseVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + var ( + tmpStr string + ok bool + ) + cfg.ProxyName = prefix + name + cfg.ProxyType = section["type"] + + if tmpStr, ok = section["use_encryption"]; ok && tmpStr == "true" { + cfg.UseEncryption = true + } + if tmpStr, ok = section["use_compression"]; ok && tmpStr == "true" { + cfg.UseCompression = true + } + + cfg.Role = section["role"] + if cfg.Role != "visitor" { + return fmt.Errorf("Parse conf error: proxy [%s] incorrect role [%s]", name, cfg.Role) + } + cfg.Sk = section["sk"] + cfg.ServerName = prefix + section["server_name"] + if cfg.BindAddr = section["bind_addr"]; cfg.BindAddr == "" { + cfg.BindAddr = "127.0.0.1" + } + + if tmpStr, ok = section["bind_port"]; ok { + if cfg.BindPort, err = strconv.Atoi(tmpStr); err != nil { + return fmt.Errorf("Parse conf error: proxy [%s] bind_port incorrect", name) + } + } else { + return fmt.Errorf("Parse conf error: proxy [%s] bind_port not found", name) + } + return nil +} + +type StcpVisitorConf struct { + BaseVisitorConf +} + +func (cfg *StcpVisitorConf) Compare(cmp VisitorConf) bool { + cmpConf, ok := cmp.(*StcpVisitorConf) + if !ok { + return false + } + + if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) { + return false + } + return true +} + +func (cfg *StcpVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseVisitorConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *StcpVisitorConf) Check() (err error) { + if err = cfg.BaseVisitorConf.check(); err != nil { + return + } + return +} + +type XtcpVisitorConf struct { + BaseVisitorConf +} + +func (cfg *XtcpVisitorConf) Compare(cmp VisitorConf) bool { + cmpConf, ok := cmp.(*XtcpVisitorConf) + if !ok { + return false + } + + if !cfg.BaseVisitorConf.compare(&cmpConf.BaseVisitorConf) { + return false + } + return true +} + +func (cfg *XtcpVisitorConf) UnmarshalFromIni(prefix string, name string, section ini.Section) (err error) { + if err = cfg.BaseVisitorConf.UnmarshalFromIni(prefix, name, section); err != nil { + return + } + return +} + +func (cfg *XtcpVisitorConf) Check() (err error) { + if err = cfg.BaseVisitorConf.check(); err != nil { + return + } + return +} diff --git a/models/consts/consts.go b/models/consts/consts.go new file mode 100644 index 0000000..9bf5880 --- /dev/null +++ b/models/consts/consts.go @@ -0,0 +1,32 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package consts + +var ( + // proxy status + Idle string = "idle" + Working string = "working" + Closed string = "closed" + Online string = "online" + Offline string = "offline" + + // proxy type + TcpProxy string = "tcp" + UdpProxy string = "udp" + HttpProxy string = "http" + HttpsProxy string = "https" + StcpProxy string = "stcp" + XtcpProxy string = "xtcp" +) diff --git a/models/errors/errors.go b/models/errors/errors.go new file mode 100644 index 0000000..e0f229e --- /dev/null +++ b/models/errors/errors.go @@ -0,0 +1,24 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package errors + +import ( + "errors" +) + +var ( + ErrMsgType = errors.New("message type error") + ErrCtlClosed = errors.New("control is closed") +) diff --git a/models/msg/ctl.go b/models/msg/ctl.go new file mode 100644 index 0000000..0eafdbc --- /dev/null +++ b/models/msg/ctl.go @@ -0,0 +1,46 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import ( + "io" + + jsonMsg "github.com/fatedier/golib/msg/json" +) + +type Message = jsonMsg.Message + +var ( + msgCtl *jsonMsg.MsgCtl +) + +func init() { + msgCtl = jsonMsg.NewMsgCtl() + for typeByte, msg := range msgTypeMap { + msgCtl.RegisterMsg(typeByte, msg) + } +} + +func ReadMsg(c io.Reader) (msg Message, err error) { + return msgCtl.ReadMsg(c) +} + +func ReadMsgInto(c io.Reader, msg Message) (err error) { + return msgCtl.ReadMsgInto(c, msg) +} + +func WriteMsg(c io.Writer, msg interface{}) (err error) { + return msgCtl.WriteMsg(c, msg) +} diff --git a/models/msg/msg.go b/models/msg/msg.go new file mode 100644 index 0000000..ce41c9e --- /dev/null +++ b/models/msg/msg.go @@ -0,0 +1,185 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package msg + +import "net" + +const ( + TypeLogin = 'o' + TypeLoginResp = '1' + TypeNewProxy = 'p' + TypeNewProxyResp = '2' + TypeCloseProxy = 'c' + TypeNewWorkConn = 'w' + TypeReqWorkConn = 'r' + TypeStartWorkConn = 's' + TypeNewVisitorConn = 'v' + TypeNewVisitorConnResp = '3' + TypePing = 'h' + TypePong = '4' + TypeUdpPacket = 'u' + TypeNatHoleVisitor = 'i' + TypeNatHoleClient = 'n' + TypeNatHoleResp = 'm' + TypeNatHoleClientDetectOK = 'd' + TypeNatHoleSid = '5' +) + +var ( + msgTypeMap = map[byte]interface{}{ + TypeLogin: Login{}, + TypeLoginResp: LoginResp{}, + TypeNewProxy: NewProxy{}, + TypeNewProxyResp: NewProxyResp{}, + TypeCloseProxy: CloseProxy{}, + TypeNewWorkConn: NewWorkConn{}, + TypeReqWorkConn: ReqWorkConn{}, + TypeStartWorkConn: StartWorkConn{}, + TypeNewVisitorConn: NewVisitorConn{}, + TypeNewVisitorConnResp: NewVisitorConnResp{}, + TypePing: Ping{}, + TypePong: Pong{}, + TypeUdpPacket: UdpPacket{}, + TypeNatHoleVisitor: NatHoleVisitor{}, + TypeNatHoleClient: NatHoleClient{}, + TypeNatHoleResp: NatHoleResp{}, + TypeNatHoleClientDetectOK: NatHoleClientDetectOK{}, + TypeNatHoleSid: NatHoleSid{}, + } +) + +// When frpc start, client send this message to login to server. +type Login struct { + Version string `json:"version"` + Hostname string `json:"hostname"` + Os string `json:"os"` + Arch string `json:"arch"` + User string `json:"user"` + PrivilegeKey string `json:"privilege_key"` + Timestamp int64 `json:"timestamp"` + RunId string `json:"run_id"` + Metas map[string]string `json:"metas"` + + // Some global configures. + PoolCount int `json:"pool_count"` +} + +type LoginResp struct { + Version string `json:"version"` + RunId string `json:"run_id"` + ServerUdpPort int `json:"server_udp_port"` + Error string `json:"error"` +} + +// When frpc login success, send this message to frps for running a new proxy. +type NewProxy struct { + ProxyName string `json:"proxy_name"` + ProxyType string `json:"proxy_type"` + UseEncryption bool `json:"use_encryption"` + UseCompression bool `json:"use_compression"` + Group string `json:"group"` + GroupKey string `json:"group_key"` + Metas map[string]string `json:"metas"` + + // tcp and udp only + RemotePort int `json:"remote_port"` + + // http and https only + CustomDomains []string `json:"custom_domains"` + SubDomain string `json:"subdomain"` + Locations []string `json:"locations"` + HttpUser string `json:"http_user"` + HttpPwd string `json:"http_pwd"` + HostHeaderRewrite string `json:"host_header_rewrite"` + Headers map[string]string `json:"headers"` + + // stcp + Sk string `json:"sk"` +} + +type NewProxyResp struct { + ProxyName string `json:"proxy_name"` + RemoteAddr string `json:"remote_addr"` + Error string `json:"error"` +} + +type CloseProxy struct { + ProxyName string `json:"proxy_name"` +} + +type NewWorkConn struct { + RunId string `json:"run_id"` +} + +type ReqWorkConn struct { +} + +type StartWorkConn struct { + ProxyName string `json:"proxy_name"` + SrcAddr string `json:"src_addr"` + DstAddr string `json:"dst_addr"` + SrcPort uint16 `json:"src_port"` + DstPort uint16 `json:"dst_port"` +} + +type NewVisitorConn struct { + ProxyName string `json:"proxy_name"` + SignKey string `json:"sign_key"` + Timestamp int64 `json:"timestamp"` + UseEncryption bool `json:"use_encryption"` + UseCompression bool `json:"use_compression"` +} + +type NewVisitorConnResp struct { + ProxyName string `json:"proxy_name"` + Error string `json:"error"` +} + +type Ping struct { +} + +type Pong struct { +} + +type UdpPacket struct { + Content string `json:"c"` + LocalAddr *net.UDPAddr `json:"l"` + RemoteAddr *net.UDPAddr `json:"r"` +} + +type NatHoleVisitor struct { + ProxyName string `json:"proxy_name"` + SignKey string `json:"sign_key"` + Timestamp int64 `json:"timestamp"` +} + +type NatHoleClient struct { + ProxyName string `json:"proxy_name"` + Sid string `json:"sid"` +} + +type NatHoleResp struct { + Sid string `json:"sid"` + VisitorAddr string `json:"visitor_addr"` + ClientAddr string `json:"client_addr"` + Error string `json:"error"` +} + +type NatHoleClientDetectOK struct { +} + +type NatHoleSid struct { + Sid string `json:"sid"` +} diff --git a/models/nathole/nathole.go b/models/nathole/nathole.go new file mode 100644 index 0000000..c1d3531 --- /dev/null +++ b/models/nathole/nathole.go @@ -0,0 +1,212 @@ +package nathole + +import ( + "bytes" + "fmt" + "net" + "sync" + "time" + + "github.com/charlesbao/frpc/models/msg" + "github.com/charlesbao/frpc/utils/log" + "github.com/charlesbao/frpc/utils/util" + + "github.com/fatedier/golib/errors" + "github.com/fatedier/golib/pool" +) + +// Timeout seconds. +var NatHoleTimeout int64 = 10 + +type SidRequest struct { + Sid string + NotifyCh chan struct{} +} + +type NatHoleController struct { + listener *net.UDPConn + + clientCfgs map[string]*NatHoleClientCfg + sessions map[string]*NatHoleSession + + mu sync.RWMutex +} + +func NewNatHoleController(udpBindAddr string) (nc *NatHoleController, err error) { + addr, err := net.ResolveUDPAddr("udp", udpBindAddr) + if err != nil { + return nil, err + } + lconn, err := net.ListenUDP("udp", addr) + if err != nil { + return nil, err + } + nc = &NatHoleController{ + listener: lconn, + clientCfgs: make(map[string]*NatHoleClientCfg), + sessions: make(map[string]*NatHoleSession), + } + return nc, nil +} + +func (nc *NatHoleController) ListenClient(name string, sk string) (sidCh chan *SidRequest) { + clientCfg := &NatHoleClientCfg{ + Name: name, + Sk: sk, + SidCh: make(chan *SidRequest), + } + nc.mu.Lock() + nc.clientCfgs[name] = clientCfg + nc.mu.Unlock() + return clientCfg.SidCh +} + +func (nc *NatHoleController) CloseClient(name string) { + nc.mu.Lock() + defer nc.mu.Unlock() + delete(nc.clientCfgs, name) +} + +func (nc *NatHoleController) Run() { + for { + buf := pool.GetBuf(1024) + n, raddr, err := nc.listener.ReadFromUDP(buf) + if err != nil { + log.Trace("nat hole listener read from udp error: %v", err) + return + } + + rd := bytes.NewReader(buf[:n]) + rawMsg, err := msg.ReadMsg(rd) + if err != nil { + log.Trace("read nat hole message error: %v", err) + continue + } + + switch m := rawMsg.(type) { + case *msg.NatHoleVisitor: + go nc.HandleVisitor(m, raddr) + case *msg.NatHoleClient: + go nc.HandleClient(m, raddr) + default: + log.Trace("error nat hole message type") + continue + } + pool.PutBuf(buf) + } +} + +func (nc *NatHoleController) GenSid() string { + t := time.Now().Unix() + id, _ := util.RandId() + return fmt.Sprintf("%d%s", t, id) +} + +func (nc *NatHoleController) HandleVisitor(m *msg.NatHoleVisitor, raddr *net.UDPAddr) { + sid := nc.GenSid() + session := &NatHoleSession{ + Sid: sid, + VisitorAddr: raddr, + NotifyCh: make(chan struct{}, 0), + } + nc.mu.Lock() + clientCfg, ok := nc.clientCfgs[m.ProxyName] + if !ok { + nc.mu.Unlock() + errInfo := fmt.Sprintf("xtcp server for [%s] doesn't exist", m.ProxyName) + log.Debug(errInfo) + nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr) + return + } + if m.SignKey != util.GetAuthKey(clientCfg.Sk, m.Timestamp) { + nc.mu.Unlock() + errInfo := fmt.Sprintf("xtcp connection of [%s] auth failed", m.ProxyName) + log.Debug(errInfo) + nc.listener.WriteToUDP(nc.GenNatHoleResponse(nil, errInfo), raddr) + return + } + + nc.sessions[sid] = session + nc.mu.Unlock() + log.Trace("handle visitor message, sid [%s]", sid) + + defer func() { + nc.mu.Lock() + delete(nc.sessions, sid) + nc.mu.Unlock() + }() + + err := errors.PanicToError(func() { + clientCfg.SidCh <- &SidRequest{ + Sid: sid, + NotifyCh: session.NotifyCh, + } + }) + if err != nil { + return + } + + // Wait client connections. + select { + case <-session.NotifyCh: + resp := nc.GenNatHoleResponse(session, "") + log.Trace("send nat hole response to visitor") + nc.listener.WriteToUDP(resp, raddr) + case <-time.After(time.Duration(NatHoleTimeout) * time.Second): + return + } +} + +func (nc *NatHoleController) HandleClient(m *msg.NatHoleClient, raddr *net.UDPAddr) { + nc.mu.RLock() + session, ok := nc.sessions[m.Sid] + nc.mu.RUnlock() + if !ok { + return + } + log.Trace("handle client message, sid [%s]", session.Sid) + session.ClientAddr = raddr + + resp := nc.GenNatHoleResponse(session, "") + log.Trace("send nat hole response to client") + nc.listener.WriteToUDP(resp, raddr) +} + +func (nc *NatHoleController) GenNatHoleResponse(session *NatHoleSession, errInfo string) []byte { + var ( + sid string + visitorAddr string + clientAddr string + ) + if session != nil { + sid = session.Sid + visitorAddr = session.VisitorAddr.String() + clientAddr = session.ClientAddr.String() + } + m := &msg.NatHoleResp{ + Sid: sid, + VisitorAddr: visitorAddr, + ClientAddr: clientAddr, + Error: errInfo, + } + b := bytes.NewBuffer(nil) + err := msg.WriteMsg(b, m) + if err != nil { + return []byte("") + } + return b.Bytes() +} + +type NatHoleSession struct { + Sid string + VisitorAddr *net.UDPAddr + ClientAddr *net.UDPAddr + + NotifyCh chan struct{} +} + +type NatHoleClientCfg struct { + Name string + Sk string + SidCh chan *SidRequest +} diff --git a/models/plugin/client/http2https.go b/models/plugin/client/http2https.go new file mode 100644 index 0000000..834182a --- /dev/null +++ b/models/plugin/client/http2https.go @@ -0,0 +1,111 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" + + frpNet "github.com/charlesbao/frpc/utils/net" +) + +const PluginHTTP2HTTPS = "http2https" + +func init() { + Register(PluginHTTP2HTTPS, NewHTTP2HTTPSPlugin) +} + +type HTTP2HTTPSPlugin struct { + hostHeaderRewrite string + localAddr string + headers map[string]string + + l *Listener + s *http.Server +} + +func NewHTTP2HTTPSPlugin(params map[string]string) (Plugin, error) { + localAddr := params["plugin_local_addr"] + hostHeaderRewrite := params["plugin_host_header_rewrite"] + headers := make(map[string]string) + for k, v := range params { + if !strings.HasPrefix(k, "plugin_header_") { + continue + } + if k = strings.TrimPrefix(k, "plugin_header_"); k != "" { + headers[k] = v + } + } + + if localAddr == "" { + return nil, fmt.Errorf("plugin_local_addr is required") + } + + listener := NewProxyListener() + + p := &HTTPS2HTTPPlugin{ + localAddr: localAddr, + hostHeaderRewrite: hostHeaderRewrite, + headers: headers, + l: listener, + } + + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + rp := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "https" + req.URL.Host = p.localAddr + if p.hostHeaderRewrite != "" { + req.Host = p.hostHeaderRewrite + } + for k, v := range p.headers { + req.Header.Set(k, v) + } + }, + Transport: tr, + } + + p.s = &http.Server{ + Handler: rp, + } + + go p.s.Serve(listener) + + return p, nil +} + +func (p *HTTP2HTTPSPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) + p.l.PutConn(wrapConn) +} + +func (p *HTTP2HTTPSPlugin) Name() string { + return PluginHTTP2HTTPS +} + +func (p *HTTP2HTTPSPlugin) Close() error { + if err := p.s.Close(); err != nil { + return err + } + return nil +} diff --git a/models/plugin/client/http_proxy.go b/models/plugin/client/http_proxy.go new file mode 100644 index 0000000..535d2a9 --- /dev/null +++ b/models/plugin/client/http_proxy.go @@ -0,0 +1,243 @@ +// Copyright 2017 frp team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "bufio" + "encoding/base64" + "io" + "net" + "net/http" + "strings" + + frpNet "github.com/charlesbao/frpc/utils/net" + + frpIo "github.com/fatedier/golib/io" + gnet "github.com/fatedier/golib/net" +) + +const PluginHttpProxy = "http_proxy" + +func init() { + Register(PluginHttpProxy, NewHttpProxyPlugin) +} + +type HttpProxy struct { + l *Listener + s *http.Server + AuthUser string + AuthPasswd string +} + +func NewHttpProxyPlugin(params map[string]string) (Plugin, error) { + user := params["plugin_http_user"] + passwd := params["plugin_http_passwd"] + listener := NewProxyListener() + + hp := &HttpProxy{ + l: listener, + AuthUser: user, + AuthPasswd: passwd, + } + + hp.s = &http.Server{ + Handler: hp, + } + + go hp.s.Serve(listener) + return hp, nil +} + +func (hp *HttpProxy) Name() string { + return PluginHttpProxy +} + +func (hp *HttpProxy) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) + + sc, rd := gnet.NewSharedConn(wrapConn) + firstBytes := make([]byte, 7) + _, err := rd.Read(firstBytes) + if err != nil { + wrapConn.Close() + return + } + + if strings.ToUpper(string(firstBytes)) == "CONNECT" { + bufRd := bufio.NewReader(sc) + request, err := http.ReadRequest(bufRd) + if err != nil { + wrapConn.Close() + return + } + hp.handleConnectReq(request, frpIo.WrapReadWriteCloser(bufRd, wrapConn, wrapConn.Close)) + return + } + + hp.l.PutConn(sc) + return +} + +func (hp *HttpProxy) Close() error { + hp.s.Close() + hp.l.Close() + return nil +} + +func (hp *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + if ok := hp.Auth(req); !ok { + rw.Header().Set("Proxy-Authenticate", "Basic") + rw.WriteHeader(http.StatusProxyAuthRequired) + return + } + + if req.Method == http.MethodConnect { + // deprecated + // Connect request is handled in Handle function. + hp.ConnectHandler(rw, req) + } else { + hp.HttpHandler(rw, req) + } +} + +func (hp *HttpProxy) HttpHandler(rw http.ResponseWriter, req *http.Request) { + removeProxyHeaders(req) + + resp, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + copyHeaders(rw.Header(), resp.Header) + rw.WriteHeader(resp.StatusCode) + + _, err = io.Copy(rw, resp.Body) + if err != nil && err != io.EOF { + return + } +} + +// deprecated +// Hijack needs to SetReadDeadline on the Conn of the request, but if we use stream compression here, +// we may always get i/o timeout error. +func (hp *HttpProxy) ConnectHandler(rw http.ResponseWriter, req *http.Request) { + hj, ok := rw.(http.Hijacker) + if !ok { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + client, _, err := hj.Hijack() + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + return + } + + remote, err := net.Dial("tcp", req.URL.Host) + if err != nil { + http.Error(rw, "Failed", http.StatusBadRequest) + client.Close() + return + } + client.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) + + go frpIo.Join(remote, client) +} + +func (hp *HttpProxy) Auth(req *http.Request) bool { + if hp.AuthUser == "" && hp.AuthPasswd == "" { + return true + } + + s := strings.SplitN(req.Header.Get("Proxy-Authorization"), " ", 2) + if len(s) != 2 { + return false + } + + b, err := base64.StdEncoding.DecodeString(s[1]) + if err != nil { + return false + } + + pair := strings.SplitN(string(b), ":", 2) + if len(pair) != 2 { + return false + } + + if pair[0] != hp.AuthUser || pair[1] != hp.AuthPasswd { + return false + } + return true +} + +func (hp *HttpProxy) handleConnectReq(req *http.Request, rwc io.ReadWriteCloser) { + defer rwc.Close() + if ok := hp.Auth(req); !ok { + res := getBadResponse() + res.Write(rwc) + return + } + + remote, err := net.Dial("tcp", req.URL.Host) + if err != nil { + res := &http.Response{ + StatusCode: 400, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + } + res.Write(rwc) + return + } + rwc.Write([]byte("HTTP/1.1 200 OK\r\n\r\n")) + + frpIo.Join(remote, rwc) +} + +func copyHeaders(dst, src http.Header) { + for key, values := range src { + for _, value := range values { + dst.Add(key, value) + } + } +} + +func removeProxyHeaders(req *http.Request) { + req.RequestURI = "" + req.Header.Del("Proxy-Connection") + req.Header.Del("Connection") + req.Header.Del("Proxy-Authenticate") + req.Header.Del("Proxy-Authorization") + req.Header.Del("TE") + req.Header.Del("Trailers") + req.Header.Del("Transfer-Encoding") + req.Header.Del("Upgrade") +} + +func getBadResponse() *http.Response { + header := make(map[string][]string) + header["Proxy-Authenticate"] = []string{"Basic"} + res := &http.Response{ + Status: "407 Not authorized", + StatusCode: 407, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: header, + } + return res +} diff --git a/models/plugin/client/https2http.go b/models/plugin/client/https2http.go new file mode 100644 index 0000000..c5c1675 --- /dev/null +++ b/models/plugin/client/https2http.go @@ -0,0 +1,133 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/http" + "net/http/httputil" + "strings" + + frpNet "github.com/charlesbao/frpc/utils/net" +) + +const PluginHTTPS2HTTP = "https2http" + +func init() { + Register(PluginHTTPS2HTTP, NewHTTPS2HTTPPlugin) +} + +type HTTPS2HTTPPlugin struct { + crtPath string + keyPath string + hostHeaderRewrite string + localAddr string + headers map[string]string + + l *Listener + s *http.Server +} + +func NewHTTPS2HTTPPlugin(params map[string]string) (Plugin, error) { + crtPath := params["plugin_crt_path"] + keyPath := params["plugin_key_path"] + localAddr := params["plugin_local_addr"] + hostHeaderRewrite := params["plugin_host_header_rewrite"] + headers := make(map[string]string) + for k, v := range params { + if !strings.HasPrefix(k, "plugin_header_") { + continue + } + if k = strings.TrimPrefix(k, "plugin_header_"); k != "" { + headers[k] = v + } + } + + if crtPath == "" { + return nil, fmt.Errorf("plugin_crt_path is required") + } + if keyPath == "" { + return nil, fmt.Errorf("plugin_key_path is required") + } + if localAddr == "" { + return nil, fmt.Errorf("plugin_local_addr is required") + } + + listener := NewProxyListener() + + p := &HTTPS2HTTPPlugin{ + crtPath: crtPath, + keyPath: keyPath, + localAddr: localAddr, + hostHeaderRewrite: hostHeaderRewrite, + headers: headers, + l: listener, + } + + rp := &httputil.ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "http" + req.URL.Host = p.localAddr + if p.hostHeaderRewrite != "" { + req.Host = p.hostHeaderRewrite + } + for k, v := range p.headers { + req.Header.Set(k, v) + } + }, + } + + p.s = &http.Server{ + Handler: rp, + } + + tlsConfig, err := p.genTLSConfig() + if err != nil { + return nil, fmt.Errorf("gen TLS config error: %v", err) + } + ln := tls.NewListener(listener, tlsConfig) + + go p.s.Serve(ln) + return p, nil +} + +func (p *HTTPS2HTTPPlugin) genTLSConfig() (*tls.Config, error) { + cert, err := tls.LoadX509KeyPair(p.crtPath, p.keyPath) + if err != nil { + return nil, err + } + + config := &tls.Config{Certificates: []tls.Certificate{cert}} + return config, nil +} + +func (p *HTTPS2HTTPPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) + p.l.PutConn(wrapConn) +} + +func (p *HTTPS2HTTPPlugin) Name() string { + return PluginHTTPS2HTTP +} + +func (p *HTTPS2HTTPPlugin) Close() error { + if err := p.s.Close(); err != nil { + return err + } + return nil +} diff --git a/models/plugin/client/plugin.go b/models/plugin/client/plugin.go new file mode 100644 index 0000000..6850919 --- /dev/null +++ b/models/plugin/client/plugin.go @@ -0,0 +1,92 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "fmt" + "io" + "net" + "sync" + + "github.com/fatedier/golib/errors" +) + +// Creators is used for create plugins to handle connections. +var creators = make(map[string]CreatorFn) + +// params has prefix "plugin_" +type CreatorFn func(params map[string]string) (Plugin, error) + +func Register(name string, fn CreatorFn) { + creators[name] = fn +} + +func Create(name string, params map[string]string) (p Plugin, err error) { + if fn, ok := creators[name]; ok { + p, err = fn(params) + } else { + err = fmt.Errorf("plugin [%s] is not registered", name) + } + return +} + +type Plugin interface { + Name() string + + // extraBufToLocal will send to local connection first, then join conn with local connection + Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) + Close() error +} + +type Listener struct { + conns chan net.Conn + closed bool + mu sync.Mutex +} + +func NewProxyListener() *Listener { + return &Listener{ + conns: make(chan net.Conn, 64), + } +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, ok := <-l.conns + if !ok { + return nil, fmt.Errorf("listener closed") + } + return conn, nil +} + +func (l *Listener) PutConn(conn net.Conn) error { + err := errors.PanicToError(func() { + l.conns <- conn + }) + return err +} + +func (l *Listener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if !l.closed { + close(l.conns) + l.closed = true + } + return nil +} + +func (l *Listener) Addr() net.Addr { + return (*net.TCPAddr)(nil) +} diff --git a/models/plugin/client/socks5.go b/models/plugin/client/socks5.go new file mode 100644 index 0000000..c80b642 --- /dev/null +++ b/models/plugin/client/socks5.go @@ -0,0 +1,69 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "io" + "io/ioutil" + "log" + "net" + + frpNet "github.com/charlesbao/frpc/utils/net" + + gosocks5 "github.com/armon/go-socks5" +) + +const PluginSocks5 = "socks5" + +func init() { + Register(PluginSocks5, NewSocks5Plugin) +} + +type Socks5Plugin struct { + Server *gosocks5.Server + + user string + passwd string +} + +func NewSocks5Plugin(params map[string]string) (p Plugin, err error) { + user := params["plugin_user"] + passwd := params["plugin_passwd"] + + cfg := &gosocks5.Config{ + Logger: log.New(ioutil.Discard, "", log.LstdFlags), + } + if user != "" || passwd != "" { + cfg.Credentials = gosocks5.StaticCredentials(map[string]string{user: passwd}) + } + sp := &Socks5Plugin{} + sp.Server, err = gosocks5.New(cfg) + p = sp + return +} + +func (sp *Socks5Plugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + defer conn.Close() + wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) + sp.Server.ServeConn(wrapConn) +} + +func (sp *Socks5Plugin) Name() string { + return PluginSocks5 +} + +func (sp *Socks5Plugin) Close() error { + return nil +} diff --git a/models/plugin/client/static_file.go b/models/plugin/client/static_file.go new file mode 100644 index 0000000..0a5626b --- /dev/null +++ b/models/plugin/client/static_file.go @@ -0,0 +1,89 @@ +// Copyright 2018 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "io" + "net" + "net/http" + + frpNet "github.com/charlesbao/frpc/utils/net" + + "github.com/gorilla/mux" +) + +const PluginStaticFile = "static_file" + +func init() { + Register(PluginStaticFile, NewStaticFilePlugin) +} + +type StaticFilePlugin struct { + localPath string + stripPrefix string + httpUser string + httpPasswd string + + l *Listener + s *http.Server +} + +func NewStaticFilePlugin(params map[string]string) (Plugin, error) { + localPath := params["plugin_local_path"] + stripPrefix := params["plugin_strip_prefix"] + httpUser := params["plugin_http_user"] + httpPasswd := params["plugin_http_passwd"] + + listener := NewProxyListener() + + sp := &StaticFilePlugin{ + localPath: localPath, + stripPrefix: stripPrefix, + httpUser: httpUser, + httpPasswd: httpPasswd, + + l: listener, + } + var prefix string + if stripPrefix != "" { + prefix = "/" + stripPrefix + "/" + } else { + prefix = "/" + } + + router := mux.NewRouter() + router.Use(frpNet.NewHttpAuthMiddleware(httpUser, httpPasswd).Middleware) + router.PathPrefix(prefix).Handler(frpNet.MakeHttpGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET") + sp.s = &http.Server{ + Handler: router, + } + go sp.s.Serve(listener) + return sp, nil +} + +func (sp *StaticFilePlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + wrapConn := frpNet.WrapReadWriteCloserToConn(conn, realConn) + sp.l.PutConn(wrapConn) +} + +func (sp *StaticFilePlugin) Name() string { + return PluginStaticFile +} + +func (sp *StaticFilePlugin) Close() error { + sp.s.Close() + sp.l.Close() + return nil +} diff --git a/models/plugin/client/unix_domain_socket.go b/models/plugin/client/unix_domain_socket.go new file mode 100644 index 0000000..a85ada7 --- /dev/null +++ b/models/plugin/client/unix_domain_socket.go @@ -0,0 +1,72 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "fmt" + "io" + "net" + + frpIo "github.com/fatedier/golib/io" +) + +const PluginUnixDomainSocket = "unix_domain_socket" + +func init() { + Register(PluginUnixDomainSocket, NewUnixDomainSocketPlugin) +} + +type UnixDomainSocketPlugin struct { + UnixAddr *net.UnixAddr +} + +func NewUnixDomainSocketPlugin(params map[string]string) (p Plugin, err error) { + unixPath, ok := params["plugin_unix_path"] + if !ok { + err = fmt.Errorf("plugin_unix_path not found") + return + } + + unixAddr, errRet := net.ResolveUnixAddr("unix", unixPath) + if errRet != nil { + err = errRet + return + } + + p = &UnixDomainSocketPlugin{ + UnixAddr: unixAddr, + } + return +} + +func (uds *UnixDomainSocketPlugin) Handle(conn io.ReadWriteCloser, realConn net.Conn, extraBufToLocal []byte) { + localConn, err := net.DialUnix("unix", nil, uds.UnixAddr) + if err != nil { + return + } + if len(extraBufToLocal) > 0 { + localConn.Write(extraBufToLocal) + } + + frpIo.Join(localConn, conn) +} + +func (uds *UnixDomainSocketPlugin) Name() string { + return PluginUnixDomainSocket +} + +func (uds *UnixDomainSocketPlugin) Close() error { + return nil +} diff --git a/models/plugin/server/http.go b/models/plugin/server/http.go new file mode 100644 index 0000000..155c470 --- /dev/null +++ b/models/plugin/server/http.go @@ -0,0 +1,104 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "reflect" +) + +type HTTPPluginOptions struct { + Name string + Addr string + Path string + Ops []string +} + +type httpPlugin struct { + options HTTPPluginOptions + + url string + client *http.Client +} + +func NewHTTPPluginOptions(options HTTPPluginOptions) Plugin { + return &httpPlugin{ + options: options, + url: fmt.Sprintf("http://%s%s", options.Addr, options.Path), + client: &http.Client{}, + } +} + +func (p *httpPlugin) Name() string { + return p.options.Name +} + +func (p *httpPlugin) IsSupport(op string) bool { + for _, v := range p.options.Ops { + if v == op { + return true + } + } + return false +} + +func (p *httpPlugin) Handle(ctx context.Context, op string, content interface{}) (*Response, interface{}, error) { + r := &Request{ + Version: APIVersion, + Op: op, + Content: content, + } + var res Response + res.Content = reflect.New(reflect.TypeOf(content)).Interface() + if err := p.do(ctx, r, &res); err != nil { + return nil, nil, err + } + return &res, res.Content, nil +} + +func (p *httpPlugin) do(ctx context.Context, r *Request, res *Response) error { + buf, err := json.Marshal(r) + if err != nil { + return err + } + req, err := http.NewRequest("POST", p.url, bytes.NewReader(buf)) + if err != nil { + return err + } + req = req.WithContext(ctx) + req.Header.Set("X-Frp-Reqid", GetReqidFromContext(ctx)) + resp, err := p.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("do http request error code: %d", resp.StatusCode) + } + buf, err = ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + if err = json.Unmarshal(buf, res); err != nil { + return err + } + return nil +} diff --git a/models/plugin/server/manager.go b/models/plugin/server/manager.go new file mode 100644 index 0000000..aee8f80 --- /dev/null +++ b/models/plugin/server/manager.go @@ -0,0 +1,105 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" + "errors" + "fmt" + + "github.com/charlesbao/frpc/utils/util" + "github.com/charlesbao/frpc/utils/xlog" +) + +type Manager struct { + loginPlugins []Plugin + newProxyPlugins []Plugin +} + +func NewManager() *Manager { + return &Manager{ + loginPlugins: make([]Plugin, 0), + newProxyPlugins: make([]Plugin, 0), + } +} + +func (m *Manager) Register(p Plugin) { + if p.IsSupport(OpLogin) { + m.loginPlugins = append(m.loginPlugins, p) + } + if p.IsSupport(OpNewProxy) { + m.newProxyPlugins = append(m.newProxyPlugins, p) + } +} + +func (m *Manager) Login(content *LoginContent) (*LoginContent, error) { + var ( + res = &Response{ + Reject: false, + Unchange: true, + } + retContent interface{} + err error + ) + reqid, _ := util.RandId() + xl := xlog.New().AppendPrefix("reqid: " + reqid) + ctx := xlog.NewContext(context.Background(), xl) + ctx = NewReqidContext(ctx, reqid) + + for _, p := range m.loginPlugins { + res, retContent, err = p.Handle(ctx, OpLogin, *content) + if err != nil { + xl.Warn("send Login request to plugin [%s] error: %v", p.Name(), err) + return nil, errors.New("send Login request to plugin error") + } + if res.Reject { + return nil, fmt.Errorf("%s", res.RejectReason) + } + if !res.Unchange { + content = retContent.(*LoginContent) + } + } + return content, nil +} + +func (m *Manager) NewProxy(content *NewProxyContent) (*NewProxyContent, error) { + var ( + res = &Response{ + Reject: false, + Unchange: true, + } + retContent interface{} + err error + ) + reqid, _ := util.RandId() + xl := xlog.New().AppendPrefix("reqid: " + reqid) + ctx := xlog.NewContext(context.Background(), xl) + ctx = NewReqidContext(ctx, reqid) + + for _, p := range m.newProxyPlugins { + res, retContent, err = p.Handle(ctx, OpNewProxy, *content) + if err != nil { + xl.Warn("send NewProxy request to plugin [%s] error: %v", p.Name(), err) + return nil, errors.New("send NewProxy request to plugin error") + } + if res.Reject { + return nil, fmt.Errorf("%s", res.RejectReason) + } + if !res.Unchange { + content = retContent.(*NewProxyContent) + } + } + return content, nil +} diff --git a/models/plugin/server/plugin.go b/models/plugin/server/plugin.go new file mode 100644 index 0000000..fd16b14 --- /dev/null +++ b/models/plugin/server/plugin.go @@ -0,0 +1,32 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" +) + +const ( + APIVersion = "0.1.0" + + OpLogin = "Login" + OpNewProxy = "NewProxy" +) + +type Plugin interface { + Name() string + IsSupport(op string) bool + Handle(ctx context.Context, op string, content interface{}) (res *Response, retContent interface{}, err error) +} diff --git a/models/plugin/server/tracer.go b/models/plugin/server/tracer.go new file mode 100644 index 0000000..2f4f2cc --- /dev/null +++ b/models/plugin/server/tracer.go @@ -0,0 +1,34 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "context" +) + +type key int + +const ( + reqidKey key = 0 +) + +func NewReqidContext(ctx context.Context, reqid string) context.Context { + return context.WithValue(ctx, reqidKey, reqid) +} + +func GetReqidFromContext(ctx context.Context) string { + ret, _ := ctx.Value(reqidKey).(string) + return ret +} diff --git a/models/plugin/server/types.go b/models/plugin/server/types.go new file mode 100644 index 0000000..120f53c --- /dev/null +++ b/models/plugin/server/types.go @@ -0,0 +1,46 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package plugin + +import ( + "github.com/charlesbao/frpc/models/msg" +) + +type Request struct { + Version string `json:"version"` + Op string `json:"op"` + Content interface{} `json:"content"` +} + +type Response struct { + Reject bool `json:"reject"` + RejectReason string `json:"reject_reason"` + Unchange bool `json:"unchange"` + Content interface{} `json:"content"` +} + +type LoginContent struct { + msg.Login +} + +type UserInfo struct { + User string `json:"user"` + Metas map[string]string `json:"metas"` +} + +type NewProxyContent struct { + User UserInfo `json:"user"` + msg.NewProxy +} diff --git a/models/proto/udp/udp.go b/models/proto/udp/udp.go new file mode 100644 index 0000000..9d45717 --- /dev/null +++ b/models/proto/udp/udp.go @@ -0,0 +1,137 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "encoding/base64" + "net" + "sync" + "time" + + "github.com/charlesbao/frpc/models/msg" + + "github.com/fatedier/golib/errors" + "github.com/fatedier/golib/pool" +) + +func NewUdpPacket(buf []byte, laddr, raddr *net.UDPAddr) *msg.UdpPacket { + return &msg.UdpPacket{ + Content: base64.StdEncoding.EncodeToString(buf), + LocalAddr: laddr, + RemoteAddr: raddr, + } +} + +func GetContent(m *msg.UdpPacket) (buf []byte, err error) { + buf, err = base64.StdEncoding.DecodeString(m.Content) + return +} + +func ForwardUserConn(udpConn *net.UDPConn, readCh <-chan *msg.UdpPacket, sendCh chan<- *msg.UdpPacket) { + // read + go func() { + for udpMsg := range readCh { + buf, err := GetContent(udpMsg) + if err != nil { + continue + } + udpConn.WriteToUDP(buf, udpMsg.RemoteAddr) + } + }() + + // write + buf := pool.GetBuf(1500) + defer pool.PutBuf(buf) + for { + n, remoteAddr, err := udpConn.ReadFromUDP(buf) + if err != nil { + udpConn.Close() + return + } + // buf[:n] will be encoded to string, so the bytes can be reused + udpMsg := NewUdpPacket(buf[:n], nil, remoteAddr) + select { + case sendCh <- udpMsg: + default: + } + } +} + +func Forwarder(dstAddr *net.UDPAddr, readCh <-chan *msg.UdpPacket, sendCh chan<- msg.Message) { + var ( + mu sync.RWMutex + ) + udpConnMap := make(map[string]*net.UDPConn) + + // read from dstAddr and write to sendCh + writerFn := func(raddr *net.UDPAddr, udpConn *net.UDPConn) { + addr := raddr.String() + defer func() { + mu.Lock() + delete(udpConnMap, addr) + mu.Unlock() + udpConn.Close() + }() + + buf := pool.GetBuf(1500) + for { + udpConn.SetReadDeadline(time.Now().Add(30 * time.Second)) + n, _, err := udpConn.ReadFromUDP(buf) + if err != nil { + return + } + + udpMsg := NewUdpPacket(buf[:n], nil, raddr) + if err = errors.PanicToError(func() { + select { + case sendCh <- udpMsg: + default: + } + }); err != nil { + return + } + } + } + + // read from readCh + go func() { + for udpMsg := range readCh { + buf, err := GetContent(udpMsg) + if err != nil { + continue + } + mu.Lock() + udpConn, ok := udpConnMap[udpMsg.RemoteAddr.String()] + if !ok { + udpConn, err = net.DialUDP("udp", nil, dstAddr) + if err != nil { + mu.Unlock() + continue + } + udpConnMap[udpMsg.RemoteAddr.String()] = udpConn + } + mu.Unlock() + + _, err = udpConn.Write(buf) + if err != nil { + udpConn.Close() + } + + if !ok { + go writerFn(udpMsg.RemoteAddr, udpConn) + } + } + }() +} diff --git a/models/proto/udp/udp_test.go b/models/proto/udp/udp_test.go new file mode 100644 index 0000000..6f364b8 --- /dev/null +++ b/models/proto/udp/udp_test.go @@ -0,0 +1,18 @@ +package udp + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestUdpPacket(t *testing.T) { + assert := assert.New(t) + + buf := []byte("hello world") + udpMsg := NewUdpPacket(buf, nil, nil) + + newBuf, err := GetContent(udpMsg) + assert.NoError(err) + assert.EqualValues(buf, newBuf) +} diff --git a/utils/limit/reader.go b/utils/limit/reader.go new file mode 100644 index 0000000..efa828f --- /dev/null +++ b/utils/limit/reader.go @@ -0,0 +1,51 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit + +import ( + "context" + "io" + + "golang.org/x/time/rate" +) + +type Reader struct { + r io.Reader + limiter *rate.Limiter +} + +func NewReader(r io.Reader, limiter *rate.Limiter) *Reader { + return &Reader{ + r: r, + limiter: limiter, + } +} + +func (r *Reader) Read(p []byte) (n int, err error) { + b := r.limiter.Burst() + if b < len(p) { + p = p[:b] + } + n, err = r.r.Read(p) + if err != nil { + return + } + + err = r.limiter.WaitN(context.Background(), n) + if err != nil { + return + } + return +} diff --git a/utils/limit/writer.go b/utils/limit/writer.go new file mode 100644 index 0000000..5256d1e --- /dev/null +++ b/utils/limit/writer.go @@ -0,0 +1,60 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package limit + +import ( + "context" + "io" + + "golang.org/x/time/rate" +) + +type Writer struct { + w io.Writer + limiter *rate.Limiter +} + +func NewWriter(w io.Writer, limiter *rate.Limiter) *Writer { + return &Writer{ + w: w, + limiter: limiter, + } +} + +func (w *Writer) Write(p []byte) (n int, err error) { + var nn int + b := w.limiter.Burst() + for { + end := len(p) + if end == 0 { + break + } + if b < len(p) { + end = b + } + err = w.limiter.WaitN(context.Background(), end) + if err != nil { + return + } + + nn, err = w.w.Write(p[:end]) + n += nn + if err != nil { + return + } + p = p[end:] + } + return +} diff --git a/utils/log/log.go b/utils/log/log.go new file mode 100644 index 0000000..1ddf4cd --- /dev/null +++ b/utils/log/log.go @@ -0,0 +1,93 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package log + +import ( + "fmt" + + "github.com/fatedier/beego/logs" +) + +// Log is the under log object +var Log *logs.BeeLogger + +func init() { + Log = logs.NewLogger(200) + Log.EnableFuncCallDepth(true) + Log.SetLogFuncCallDepth(Log.GetLogFuncCallDepth() + 1) +} + +func InitLog(logWay string, logFile string, logLevel string, maxdays int64, disableLogColor bool) { + SetLogFile(logWay, logFile, maxdays, disableLogColor) + SetLogLevel(logLevel) +} + +// SetLogFile to configure log params +// logWay: file or console +func SetLogFile(logWay string, logFile string, maxdays int64, disableLogColor bool) { + if logWay == "console" { + params := "" + if disableLogColor { + params = fmt.Sprintf(`{"color": false}`) + } + Log.SetLogger("console", params) + } else { + params := fmt.Sprintf(`{"filename": "%s", "maxdays": %d}`, logFile, maxdays) + Log.SetLogger("file", params) + } +} + +// SetLogLevel set log level, default is warning +// value: error, warning, info, debug, trace +func SetLogLevel(logLevel string) { + level := 4 // warning + switch logLevel { + case "error": + level = 3 + case "warn": + level = 4 + case "info": + level = 6 + case "debug": + level = 7 + case "trace": + level = 8 + default: + level = 4 + } + Log.SetLevel(level) +} + +// wrap log + +func Error(format string, v ...interface{}) { + Log.Error(format, v...) +} + +func Warn(format string, v ...interface{}) { + Log.Warn(format, v...) +} + +func Info(format string, v ...interface{}) { + Log.Info(format, v...) +} + +func Debug(format string, v ...interface{}) { + Log.Debug(format, v...) +} + +func Trace(format string, v ...interface{}) { + Log.Trace(format, v...) +} diff --git a/utils/metric/counter.go b/utils/metric/counter.go new file mode 100644 index 0000000..4b9c7a6 --- /dev/null +++ b/utils/metric/counter.go @@ -0,0 +1,60 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "sync/atomic" +) + +type Counter interface { + Count() int64 + Inc(int64) + Dec(int64) + Snapshot() Counter + Clear() +} + +func NewCounter() Counter { + return &StandardCounter{ + count: 0, + } +} + +type StandardCounter struct { + count int64 +} + +func (c *StandardCounter) Count() int64 { + return atomic.LoadInt64(&c.count) +} + +func (c *StandardCounter) Inc(count int64) { + atomic.AddInt64(&c.count, count) +} + +func (c *StandardCounter) Dec(count int64) { + atomic.AddInt64(&c.count, -count) +} + +func (c *StandardCounter) Snapshot() Counter { + tmp := &StandardCounter{ + count: atomic.LoadInt64(&c.count), + } + return tmp +} + +func (c *StandardCounter) Clear() { + atomic.StoreInt64(&c.count, 0) +} diff --git a/utils/metric/counter_test.go b/utils/metric/counter_test.go new file mode 100644 index 0000000..4925c25 --- /dev/null +++ b/utils/metric/counter_test.go @@ -0,0 +1,23 @@ +package metric + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCounter(t *testing.T) { + assert := assert.New(t) + c := NewCounter() + c.Inc(10) + assert.EqualValues(10, c.Count()) + + c.Dec(5) + assert.EqualValues(5, c.Count()) + + cTmp := c.Snapshot() + assert.EqualValues(5, cTmp.Count()) + + c.Clear() + assert.EqualValues(0, c.Count()) +} diff --git a/utils/metric/date_counter.go b/utils/metric/date_counter.go new file mode 100644 index 0000000..4524fec --- /dev/null +++ b/utils/metric/date_counter.go @@ -0,0 +1,134 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package metric + +import ( + "sync" + "time" +) + +type DateCounter interface { + TodayCount() int64 + GetLastDaysCount(lastdays int64) []int64 + Inc(int64) + Dec(int64) + Snapshot() DateCounter + Clear() +} + +func NewDateCounter(reserveDays int64) DateCounter { + if reserveDays <= 0 { + reserveDays = 1 + } + return newStandardDateCounter(reserveDays) +} + +type StandardDateCounter struct { + reserveDays int64 + counts []int64 + + lastUpdateDate time.Time + mu sync.Mutex +} + +func newStandardDateCounter(reserveDays int64) *StandardDateCounter { + now := time.Now() + now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + s := &StandardDateCounter{ + reserveDays: reserveDays, + counts: make([]int64, reserveDays), + lastUpdateDate: now, + } + return s +} + +func (c *StandardDateCounter) TodayCount() int64 { + c.mu.Lock() + defer c.mu.Unlock() + + c.rotate(time.Now()) + return c.counts[0] +} + +func (c *StandardDateCounter) GetLastDaysCount(lastdays int64) []int64 { + if lastdays > c.reserveDays { + lastdays = c.reserveDays + } + counts := make([]int64, lastdays) + + c.mu.Lock() + defer c.mu.Unlock() + c.rotate(time.Now()) + for i := 0; i < int(lastdays); i++ { + counts[i] = c.counts[i] + } + return counts +} + +func (c *StandardDateCounter) Inc(count int64) { + c.mu.Lock() + defer c.mu.Unlock() + c.rotate(time.Now()) + c.counts[0] += count +} + +func (c *StandardDateCounter) Dec(count int64) { + c.mu.Lock() + defer c.mu.Unlock() + c.rotate(time.Now()) + c.counts[0] -= count +} + +func (c *StandardDateCounter) Snapshot() DateCounter { + c.mu.Lock() + defer c.mu.Unlock() + tmp := newStandardDateCounter(c.reserveDays) + for i := 0; i < int(c.reserveDays); i++ { + tmp.counts[i] = c.counts[i] + } + return tmp +} + +func (c *StandardDateCounter) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + for i := 0; i < int(c.reserveDays); i++ { + c.counts[i] = 0 + } +} + +// rotate +// Must hold the lock before calling this function. +func (c *StandardDateCounter) rotate(now time.Time) { + now = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()) + days := int(now.Sub(c.lastUpdateDate).Hours() / 24) + + defer func() { + c.lastUpdateDate = now + }() + + if days <= 0 { + return + } else if days >= int(c.reserveDays) { + c.counts = make([]int64, c.reserveDays) + return + } + newCounts := make([]int64, c.reserveDays) + + for i := days; i < int(c.reserveDays); i++ { + newCounts[i] = c.counts[i-days] + } + c.counts = newCounts +} diff --git a/utils/metric/date_counter_test.go b/utils/metric/date_counter_test.go new file mode 100644 index 0000000..c9997c7 --- /dev/null +++ b/utils/metric/date_counter_test.go @@ -0,0 +1,27 @@ +package metric + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDateCounter(t *testing.T) { + assert := assert.New(t) + + dc := NewDateCounter(3) + dc.Inc(10) + assert.EqualValues(10, dc.TodayCount()) + + dc.Dec(5) + assert.EqualValues(5, dc.TodayCount()) + + counts := dc.GetLastDaysCount(3) + assert.EqualValues(3, len(counts)) + assert.EqualValues(5, counts[0]) + assert.EqualValues(0, counts[1]) + assert.EqualValues(0, counts[2]) + + dcTmp := dc.Snapshot() + assert.EqualValues(5, dcTmp.TodayCount()) +} diff --git a/utils/net/conn.go b/utils/net/conn.go new file mode 100644 index 0000000..51aa52f --- /dev/null +++ b/utils/net/conn.go @@ -0,0 +1,242 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "io" + "net" + "sync/atomic" + "time" + + "github.com/charlesbao/frpc/utils/xlog" + gnet "github.com/fatedier/golib/net" + kcp "github.com/fatedier/kcp-go" +) + +type ContextGetter interface { + Context() context.Context +} + +type ContextSetter interface { + WithContext(ctx context.Context) +} + +func NewLogFromConn(conn net.Conn) *xlog.Logger { + if c, ok := conn.(ContextGetter); ok { + return xlog.FromContextSafe(c.Context()) + } + return xlog.New() +} + +func NewContextFromConn(conn net.Conn) context.Context { + if c, ok := conn.(ContextGetter); ok { + return c.Context() + } + return context.Background() +} + +// ContextConn is the connection with context +type ContextConn struct { + net.Conn + + ctx context.Context +} + +func NewContextConn(c net.Conn, ctx context.Context) *ContextConn { + return &ContextConn{ + Conn: c, + ctx: ctx, + } +} + +func (c *ContextConn) WithContext(ctx context.Context) { + c.ctx = ctx +} + +func (c *ContextConn) Context() context.Context { + return c.ctx +} + +type WrapReadWriteCloserConn struct { + io.ReadWriteCloser + + underConn net.Conn +} + +func WrapReadWriteCloserToConn(rwc io.ReadWriteCloser, underConn net.Conn) net.Conn { + return &WrapReadWriteCloserConn{ + ReadWriteCloser: rwc, + underConn: underConn, + } +} + +func (conn *WrapReadWriteCloserConn) LocalAddr() net.Addr { + if conn.underConn != nil { + return conn.underConn.LocalAddr() + } + return (*net.TCPAddr)(nil) +} + +func (conn *WrapReadWriteCloserConn) RemoteAddr() net.Addr { + if conn.underConn != nil { + return conn.underConn.RemoteAddr() + } + return (*net.TCPAddr)(nil) +} + +func (conn *WrapReadWriteCloserConn) SetDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetDeadline(t) + } + return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (conn *WrapReadWriteCloserConn) SetReadDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetReadDeadline(t) + } + return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (conn *WrapReadWriteCloserConn) SetWriteDeadline(t time.Time) error { + if conn.underConn != nil { + return conn.underConn.SetWriteDeadline(t) + } + return &net.OpError{Op: "set", Net: "wrap", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +type CloseNotifyConn struct { + net.Conn + + // 1 means closed + closeFlag int32 + + closeFn func() +} + +// closeFn will be only called once +func WrapCloseNotifyConn(c net.Conn, closeFn func()) net.Conn { + return &CloseNotifyConn{ + Conn: c, + closeFn: closeFn, + } +} + +func (cc *CloseNotifyConn) Close() (err error) { + pflag := atomic.SwapInt32(&cc.closeFlag, 1) + if pflag == 0 { + err = cc.Close() + if cc.closeFn != nil { + cc.closeFn() + } + } + return +} + +type StatsConn struct { + net.Conn + + closed int64 // 1 means closed + totalRead int64 + totalWrite int64 + statsFunc func(totalRead, totalWrite int64) +} + +func WrapStatsConn(conn net.Conn, statsFunc func(total, totalWrite int64)) *StatsConn { + return &StatsConn{ + Conn: conn, + statsFunc: statsFunc, + } +} + +func (statsConn *StatsConn) Read(p []byte) (n int, err error) { + n, err = statsConn.Conn.Read(p) + statsConn.totalRead += int64(n) + return +} + +func (statsConn *StatsConn) Write(p []byte) (n int, err error) { + n, err = statsConn.Conn.Write(p) + statsConn.totalWrite += int64(n) + return +} + +func (statsConn *StatsConn) Close() (err error) { + old := atomic.SwapInt64(&statsConn.closed, 1) + if old != 1 { + err = statsConn.Conn.Close() + if statsConn.statsFunc != nil { + statsConn.statsFunc(statsConn.totalRead, statsConn.totalWrite) + } + } + return +} + +func ConnectServer(protocol string, addr string) (c net.Conn, err error) { + switch protocol { + case "tcp": + return net.Dial("tcp", addr) + case "kcp": + kcpConn, errRet := kcp.DialWithOptions(addr, nil, 10, 3) + if errRet != nil { + err = errRet + return + } + kcpConn.SetStreamMode(true) + kcpConn.SetWriteDelay(true) + kcpConn.SetNoDelay(1, 20, 2, 1) + kcpConn.SetWindowSize(128, 512) + kcpConn.SetMtu(1350) + kcpConn.SetACKNoDelay(false) + kcpConn.SetReadBuffer(4194304) + kcpConn.SetWriteBuffer(4194304) + c = kcpConn + return + default: + return nil, fmt.Errorf("unsupport protocol: %s", protocol) + } +} + +func ConnectServerByProxy(proxyURL string, protocol string, addr string) (c net.Conn, err error) { + switch protocol { + case "tcp": + return gnet.DialTcpByProxy(proxyURL, addr) + case "kcp": + // http proxy is not supported for kcp + return ConnectServer(protocol, addr) + case "websocket": + return ConnectWebsocketServer(addr) + default: + return nil, fmt.Errorf("unsupport protocol: %s", protocol) + } +} + +func ConnectServerByProxyWithTLS(proxyUrl string, protocol string, addr string, tlsConfig *tls.Config) (c net.Conn, err error) { + c, err = ConnectServerByProxy(proxyUrl, protocol, addr) + if err != nil { + return + } + + if tlsConfig == nil { + return + } + + c = WrapTLSClientConn(c, tlsConfig) + return +} diff --git a/utils/net/http.go b/utils/net/http.go new file mode 100644 index 0000000..f5df84d --- /dev/null +++ b/utils/net/http.go @@ -0,0 +1,115 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "compress/gzip" + "io" + "net/http" + "strings" +) + +type HttpAuthWraper struct { + h http.Handler + user string + passwd string +} + +func NewHttpBasicAuthWraper(h http.Handler, user, passwd string) http.Handler { + return &HttpAuthWraper{ + h: h, + user: user, + passwd: passwd, + } +} + +func (aw *HttpAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + user, passwd, hasAuth := r.BasicAuth() + if (aw.user == "" && aw.passwd == "") || (hasAuth && user == aw.user && passwd == aw.passwd) { + aw.h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } +} + +type HttpAuthMiddleware struct { + user string + passwd string +} + +func NewHttpAuthMiddleware(user, passwd string) *HttpAuthMiddleware { + return &HttpAuthMiddleware{ + user: user, + passwd: passwd, + } +} + +func (authMid *HttpAuthMiddleware) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (authMid.user == "" && authMid.passwd == "") || + (hasAuth && reqUser == authMid.user && reqPasswd == authMid.passwd) { + next.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + }) +} + +func HttpBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + reqUser, reqPasswd, hasAuth := r.BasicAuth() + if (user == "" && passwd == "") || + (hasAuth && reqUser == user && reqPasswd == passwd) { + h.ServeHTTP(w, r) + } else { + w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + } + } +} + +type HttpGzipWraper struct { + h http.Handler +} + +func (gw *HttpGzipWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + gw.h.ServeHTTP(w, r) + return + } + w.Header().Set("Content-Encoding", "gzip") + gz := gzip.NewWriter(w) + defer gz.Close() + gzr := gzipResponseWriter{Writer: gz, ResponseWriter: w} + gw.h.ServeHTTP(gzr, r) +} + +func MakeHttpGzipHandler(h http.Handler) http.Handler { + return &HttpGzipWraper{ + h: h, + } +} + +type gzipResponseWriter struct { + io.Writer + http.ResponseWriter +} + +func (w gzipResponseWriter) Write(b []byte) (int, error) { + return w.Writer.Write(b) +} diff --git a/utils/net/kcp.go b/utils/net/kcp.go new file mode 100644 index 0000000..39eb898 --- /dev/null +++ b/utils/net/kcp.go @@ -0,0 +1,99 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "fmt" + "net" + + kcp "github.com/fatedier/kcp-go" +) + +type KcpListener struct { + listener net.Listener + acceptCh chan net.Conn + closeFlag bool +} + +func ListenKcp(bindAddr string, bindPort int) (l *KcpListener, err error) { + listener, err := kcp.ListenWithOptions(fmt.Sprintf("%s:%d", bindAddr, bindPort), nil, 10, 3) + if err != nil { + return l, err + } + listener.SetReadBuffer(4194304) + listener.SetWriteBuffer(4194304) + + l = &KcpListener{ + listener: listener, + acceptCh: make(chan net.Conn), + closeFlag: false, + } + + go func() { + for { + conn, err := listener.AcceptKCP() + if err != nil { + if l.closeFlag { + close(l.acceptCh) + return + } + continue + } + conn.SetStreamMode(true) + conn.SetWriteDelay(true) + conn.SetNoDelay(1, 20, 2, 1) + conn.SetMtu(1350) + conn.SetWindowSize(1024, 1024) + conn.SetACKNoDelay(false) + + l.acceptCh <- conn + } + }() + return l, err +} + +func (l *KcpListener) Accept() (net.Conn, error) { + conn, ok := <-l.acceptCh + if !ok { + return conn, fmt.Errorf("channel for kcp listener closed") + } + return conn, nil +} + +func (l *KcpListener) Close() error { + if !l.closeFlag { + l.closeFlag = true + l.listener.Close() + } + return nil +} + +func (l *KcpListener) Addr() net.Addr { + return l.listener.Addr() +} + +func NewKcpConnFromUdp(conn *net.UDPConn, connected bool, raddr string) (net.Conn, error) { + kcpConn, err := kcp.NewConnEx(1, connected, raddr, nil, 10, 3, conn) + if err != nil { + return nil, err + } + kcpConn.SetStreamMode(true) + kcpConn.SetWriteDelay(true) + kcpConn.SetNoDelay(1, 20, 2, 1) + kcpConn.SetMtu(1350) + kcpConn.SetWindowSize(1024, 1024) + kcpConn.SetACKNoDelay(false) + return kcpConn, nil +} diff --git a/utils/net/listener.go b/utils/net/listener.go new file mode 100644 index 0000000..3b199c8 --- /dev/null +++ b/utils/net/listener.go @@ -0,0 +1,69 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "fmt" + "net" + "sync" + + "github.com/fatedier/golib/errors" +) + +// Custom listener +type CustomListener struct { + acceptCh chan net.Conn + closed bool + mu sync.Mutex +} + +func NewCustomListener() *CustomListener { + return &CustomListener{ + acceptCh: make(chan net.Conn, 64), + } +} + +func (l *CustomListener) Accept() (net.Conn, error) { + conn, ok := <-l.acceptCh + if !ok { + return nil, fmt.Errorf("listener closed") + } + return conn, nil +} + +func (l *CustomListener) PutConn(conn net.Conn) error { + err := errors.PanicToError(func() { + select { + case l.acceptCh <- conn: + default: + conn.Close() + } + }) + return err +} + +func (l *CustomListener) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if !l.closed { + close(l.acceptCh) + l.closed = true + } + return nil +} + +func (l *CustomListener) Addr() net.Addr { + return (*net.TCPAddr)(nil) +} diff --git a/utils/net/tls.go b/utils/net/tls.go new file mode 100644 index 0000000..b9fca31 --- /dev/null +++ b/utils/net/tls.go @@ -0,0 +1,52 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "crypto/tls" + "net" + "time" + + gnet "github.com/fatedier/golib/net" +) + +var ( + FRP_TLS_HEAD_BYTE = 0x17 +) + +func WrapTLSClientConn(c net.Conn, tlsConfig *tls.Config) (out net.Conn) { + c.Write([]byte{byte(FRP_TLS_HEAD_BYTE)}) + out = tls.Client(c, tlsConfig) + return +} + +func CheckAndEnableTLSServerConnWithTimeout(c net.Conn, tlsConfig *tls.Config, timeout time.Duration) (out net.Conn, err error) { + sc, r := gnet.NewSharedConnSize(c, 2) + buf := make([]byte, 1) + var n int + c.SetReadDeadline(time.Now().Add(timeout)) + n, err = r.Read(buf) + c.SetReadDeadline(time.Time{}) + if err != nil { + return + } + + if n == 1 && int(buf[0]) == FRP_TLS_HEAD_BYTE { + out = tls.Server(c, tlsConfig) + } else { + out = sc + } + return +} diff --git a/utils/net/udp.go b/utils/net/udp.go new file mode 100644 index 0000000..28a6813 --- /dev/null +++ b/utils/net/udp.go @@ -0,0 +1,256 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package net + +import ( + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/fatedier/golib/pool" +) + +type UdpPacket struct { + Buf []byte + LocalAddr net.Addr + RemoteAddr net.Addr +} + +type FakeUdpConn struct { + l *UdpListener + + localAddr net.Addr + remoteAddr net.Addr + packets chan []byte + closeFlag bool + + lastActive time.Time + mu sync.RWMutex +} + +func NewFakeUdpConn(l *UdpListener, laddr, raddr net.Addr) *FakeUdpConn { + fc := &FakeUdpConn{ + l: l, + localAddr: laddr, + remoteAddr: raddr, + packets: make(chan []byte, 20), + } + + go func() { + for { + time.Sleep(5 * time.Second) + fc.mu.RLock() + if time.Now().Sub(fc.lastActive) > 10*time.Second { + fc.mu.RUnlock() + fc.Close() + break + } + fc.mu.RUnlock() + } + }() + return fc +} + +func (c *FakeUdpConn) putPacket(content []byte) { + defer func() { + if err := recover(); err != nil { + } + }() + + select { + case c.packets <- content: + default: + } +} + +func (c *FakeUdpConn) Read(b []byte) (n int, err error) { + content, ok := <-c.packets + if !ok { + return 0, io.EOF + } + c.mu.Lock() + c.lastActive = time.Now() + c.mu.Unlock() + + if len(b) < len(content) { + n = len(b) + } else { + n = len(content) + } + copy(b, content) + return n, nil +} + +func (c *FakeUdpConn) Write(b []byte) (n int, err error) { + c.mu.RLock() + if c.closeFlag { + c.mu.RUnlock() + return 0, io.ErrClosedPipe + } + c.mu.RUnlock() + + packet := &UdpPacket{ + Buf: b, + LocalAddr: c.localAddr, + RemoteAddr: c.remoteAddr, + } + c.l.writeUdpPacket(packet) + + c.mu.Lock() + c.lastActive = time.Now() + c.mu.Unlock() + return len(b), nil +} + +func (c *FakeUdpConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closeFlag { + c.closeFlag = true + close(c.packets) + } + return nil +} + +func (c *FakeUdpConn) IsClosed() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.closeFlag +} + +func (c *FakeUdpConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *FakeUdpConn) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *FakeUdpConn) SetDeadline(t time.Time) error { + return nil +} + +func (c *FakeUdpConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (c *FakeUdpConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type UdpListener struct { + addr net.Addr + acceptCh chan net.Conn + writeCh chan *UdpPacket + readConn net.Conn + closeFlag bool + + fakeConns map[string]*FakeUdpConn +} + +func ListenUDP(bindAddr string, bindPort int) (l *UdpListener, err error) { + udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + if err != nil { + return l, err + } + readConn, err := net.ListenUDP("udp", udpAddr) + + l = &UdpListener{ + addr: udpAddr, + acceptCh: make(chan net.Conn), + writeCh: make(chan *UdpPacket, 1000), + fakeConns: make(map[string]*FakeUdpConn), + } + + // for reading + go func() { + for { + buf := pool.GetBuf(1450) + n, remoteAddr, err := readConn.ReadFromUDP(buf) + if err != nil { + close(l.acceptCh) + close(l.writeCh) + return + } + + fakeConn, exist := l.fakeConns[remoteAddr.String()] + if !exist || fakeConn.IsClosed() { + fakeConn = NewFakeUdpConn(l, l.Addr(), remoteAddr) + l.fakeConns[remoteAddr.String()] = fakeConn + } + fakeConn.putPacket(buf[:n]) + + l.acceptCh <- fakeConn + } + }() + + // for writing + go func() { + for { + packet, ok := <-l.writeCh + if !ok { + return + } + + if addr, ok := packet.RemoteAddr.(*net.UDPAddr); ok { + readConn.WriteToUDP(packet.Buf, addr) + } + } + }() + + return +} + +func (l *UdpListener) writeUdpPacket(packet *UdpPacket) (err error) { + defer func() { + if errRet := recover(); errRet != nil { + err = fmt.Errorf("udp write closed listener") + } + }() + l.writeCh <- packet + return +} + +func (l *UdpListener) WriteMsg(buf []byte, remoteAddr *net.UDPAddr) (err error) { + // only set remote addr here + packet := &UdpPacket{ + Buf: buf, + RemoteAddr: remoteAddr, + } + err = l.writeUdpPacket(packet) + return +} + +func (l *UdpListener) Accept() (net.Conn, error) { + conn, ok := <-l.acceptCh + if !ok { + return conn, fmt.Errorf("channel for udp listener closed") + } + return conn, nil +} + +func (l *UdpListener) Close() error { + if !l.closeFlag { + l.closeFlag = true + l.readConn.Close() + } + return nil +} + +func (l *UdpListener) Addr() net.Addr { + return l.addr +} diff --git a/utils/net/websocket.go b/utils/net/websocket.go new file mode 100644 index 0000000..36b6440 --- /dev/null +++ b/utils/net/websocket.go @@ -0,0 +1,103 @@ +package net + +import ( + "errors" + "fmt" + "net" + "net/http" + "net/url" + "time" + + "golang.org/x/net/websocket" +) + +var ( + ErrWebsocketListenerClosed = errors.New("websocket listener closed") +) + +const ( + FrpWebsocketPath = "/~!frp" +) + +type WebsocketListener struct { + ln net.Listener + acceptCh chan net.Conn + + server *http.Server + httpMutex *http.ServeMux +} + +// NewWebsocketListener to handle websocket connections +// ln: tcp listener for websocket connections +func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) { + wl = &WebsocketListener{ + acceptCh: make(chan net.Conn), + } + + muxer := http.NewServeMux() + muxer.Handle(FrpWebsocketPath, websocket.Handler(func(c *websocket.Conn) { + notifyCh := make(chan struct{}) + conn := WrapCloseNotifyConn(c, func() { + close(notifyCh) + }) + wl.acceptCh <- conn + <-notifyCh + })) + + wl.server = &http.Server{ + Addr: ln.Addr().String(), + Handler: muxer, + } + + go wl.server.Serve(ln) + return +} + +func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) { + tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort)) + if err != nil { + return nil, err + } + l := NewWebsocketListener(tcpLn) + return l, nil +} + +func (p *WebsocketListener) Accept() (net.Conn, error) { + c, ok := <-p.acceptCh + if !ok { + return nil, ErrWebsocketListenerClosed + } + return c, nil +} + +func (p *WebsocketListener) Close() error { + return p.server.Close() +} + +func (p *WebsocketListener) Addr() net.Addr { + return p.ln.Addr() +} + +// addr: domain:port +func ConnectWebsocketServer(addr string) (net.Conn, error) { + addr = "ws://" + addr + FrpWebsocketPath + uri, err := url.Parse(addr) + if err != nil { + return nil, err + } + + origin := "http://" + uri.Host + cfg, err := websocket.NewConfig(addr, origin) + if err != nil { + return nil, err + } + cfg.Dialer = &net.Dialer{ + Timeout: 10 * time.Second, + } + + conn, err := websocket.DialConfig(cfg) + if err != nil { + return nil, err + } + return conn, nil +} diff --git a/utils/util/util.go b/utils/util/util.go new file mode 100644 index 0000000..7ea4e83 --- /dev/null +++ b/utils/util/util.go @@ -0,0 +1,103 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package util + +import ( + "crypto/md5" + "crypto/rand" + "encoding/hex" + "fmt" + "strconv" + "strings" +) + +// RandId return a rand string used in frp. +func RandId() (id string, err error) { + return RandIdWithLen(8) +} + +// RandIdWithLen return a rand string with idLen length. +func RandIdWithLen(idLen int) (id string, err error) { + b := make([]byte, idLen) + _, err = rand.Read(b) + if err != nil { + return + } + + id = fmt.Sprintf("%x", b) + return +} + +func GetAuthKey(token string, timestamp int64) (key string) { + token = token + fmt.Sprintf("%d", timestamp) + md5Ctx := md5.New() + md5Ctx.Write([]byte(token)) + data := md5Ctx.Sum(nil) + return hex.EncodeToString(data) +} + +func CanonicalAddr(host string, port int) (addr string) { + if port == 80 || port == 443 { + addr = host + } else { + addr = fmt.Sprintf("%s:%d", host, port) + } + return +} + +func ParseRangeNumbers(rangeStr string) (numbers []int64, err error) { + rangeStr = strings.TrimSpace(rangeStr) + numbers = make([]int64, 0) + // e.g. 1000-2000,2001,2002,3000-4000 + numRanges := strings.Split(rangeStr, ",") + for _, numRangeStr := range numRanges { + // 1000-2000 or 2001 + numArray := strings.Split(numRangeStr, "-") + // length: only 1 or 2 is correct + rangeType := len(numArray) + if rangeType == 1 { + // single number + singleNum, errRet := strconv.ParseInt(strings.TrimSpace(numArray[0]), 10, 64) + if errRet != nil { + err = fmt.Errorf("range number is invalid, %v", errRet) + return + } + numbers = append(numbers, singleNum) + } else if rangeType == 2 { + // range numbers + min, errRet := strconv.ParseInt(strings.TrimSpace(numArray[0]), 10, 64) + if errRet != nil { + err = fmt.Errorf("range number is invalid, %v", errRet) + return + } + max, errRet := strconv.ParseInt(strings.TrimSpace(numArray[1]), 10, 64) + if errRet != nil { + err = fmt.Errorf("range number is invalid, %v", errRet) + return + } + if max < min { + err = fmt.Errorf("range number is invalid") + return + } + for i := min; i <= max; i++ { + numbers = append(numbers, i) + } + } else { + err = fmt.Errorf("range number is invalid") + return + } + } + return +} diff --git a/utils/util/util_test.go b/utils/util/util_test.go new file mode 100644 index 0000000..a7518f6 --- /dev/null +++ b/utils/util/util_test.go @@ -0,0 +1,48 @@ +package util + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRandId(t *testing.T) { + assert := assert.New(t) + id, err := RandId() + assert.NoError(err) + t.Log(id) + assert.Equal(16, len(id)) +} + +func TestGetAuthKey(t *testing.T) { + assert := assert.New(t) + key := GetAuthKey("1234", 1488720000) + t.Log(key) + assert.Equal("6df41a43725f0c770fd56379e12acf8c", key) +} + +func TestParseRangeNumbers(t *testing.T) { + assert := assert.New(t) + numbers, err := ParseRangeNumbers("2-5") + if assert.NoError(err) { + assert.Equal([]int64{2, 3, 4, 5}, numbers) + } + + numbers, err = ParseRangeNumbers("1") + if assert.NoError(err) { + assert.Equal([]int64{1}, numbers) + } + + numbers, err = ParseRangeNumbers("3-5,8") + if assert.NoError(err) { + assert.Equal([]int64{3, 4, 5, 8}, numbers) + } + + numbers, err = ParseRangeNumbers(" 3-5,8, 10-12 ") + if assert.NoError(err) { + assert.Equal([]int64{3, 4, 5, 8, 10, 11, 12}, numbers) + } + + _, err = ParseRangeNumbers("3-a") + assert.Error(err) +} diff --git a/utils/version/version.go b/utils/version/version.go new file mode 100644 index 0000000..dac3974 --- /dev/null +++ b/utils/version/version.go @@ -0,0 +1,82 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +import ( + "strconv" + "strings" +) + +var version string = "0.31.1" + +func Full() string { + return version +} + +func getSubVersion(v string, position int) int64 { + arr := strings.Split(v, ".") + if len(arr) < 3 { + return 0 + } + res, _ := strconv.ParseInt(arr[position], 10, 64) + return res +} + +func Proto(v string) int64 { + return getSubVersion(v, 0) +} + +func Major(v string) int64 { + return getSubVersion(v, 1) +} + +func Minor(v string) int64 { + return getSubVersion(v, 2) +} + +// add every case there if server will not accept client's protocol and return false +func Compat(client string) (ok bool, msg string) { + if LessThan(client, "0.18.0") { + return false, "Please upgrade your frpc version to at least 0.18.0" + } + return true, "" +} + +func LessThan(client string, server string) bool { + vc := Proto(client) + vs := Proto(server) + if vc > vs { + return false + } else if vc < vs { + return true + } + + vc = Major(client) + vs = Major(server) + if vc > vs { + return false + } else if vc < vs { + return true + } + + vc = Minor(client) + vs = Minor(server) + if vc > vs { + return false + } else if vc < vs { + return true + } + return false +} diff --git a/utils/version/version_test.go b/utils/version/version_test.go new file mode 100644 index 0000000..a77bf42 --- /dev/null +++ b/utils/version/version_test.go @@ -0,0 +1,65 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +import ( + "fmt" + "strconv" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFull(t *testing.T) { + assert := assert.New(t) + version := Full() + arr := strings.Split(version, ".") + assert.Equal(3, len(arr)) + + proto, err := strconv.ParseInt(arr[0], 10, 64) + assert.NoError(err) + assert.True(proto >= 0) + + major, err := strconv.ParseInt(arr[1], 10, 64) + assert.NoError(err) + assert.True(major >= 0) + + minor, err := strconv.ParseInt(arr[2], 10, 64) + assert.NoError(err) + assert.True(minor >= 0) +} + +func TestVersion(t *testing.T) { + assert := assert.New(t) + proto := Proto(Full()) + major := Major(Full()) + minor := Minor(Full()) + parseVerion := fmt.Sprintf("%d.%d.%d", proto, major, minor) + version := Full() + assert.Equal(parseVerion, version) +} + +func TestCompact(t *testing.T) { + assert := assert.New(t) + ok, _ := Compat("0.9.0") + assert.False(ok) + + ok, _ = Compat("10.0.0") + assert.True(ok) + + ok, _ = Compat("0.10.0") + assert.False(ok) +} diff --git a/utils/vhost/http.go b/utils/vhost/http.go new file mode 100644 index 0000000..85c86ac --- /dev/null +++ b/utils/vhost/http.go @@ -0,0 +1,216 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "bytes" + "context" + "errors" + "fmt" + "log" + "net" + "net/http" + "strings" + "time" + + frpLog "github.com/charlesbao/frpc/utils/log" + + "github.com/fatedier/golib/pool" +) + +var ( + ErrNoDomain = errors.New("no such domain") +) + +func getHostFromAddr(addr string) (host string) { + strs := strings.Split(addr, ":") + if len(strs) > 1 { + host = strs[0] + } else { + host = addr + } + return +} + +type HttpReverseProxyOptions struct { + ResponseHeaderTimeoutS int64 +} + +type HttpReverseProxy struct { + proxy *ReverseProxy + vhostRouter *VhostRouters + + responseHeaderTimeout time.Duration +} + +func NewHttpReverseProxy(option HttpReverseProxyOptions, vhostRouter *VhostRouters) *HttpReverseProxy { + if option.ResponseHeaderTimeoutS <= 0 { + option.ResponseHeaderTimeoutS = 60 + } + rp := &HttpReverseProxy{ + responseHeaderTimeout: time.Duration(option.ResponseHeaderTimeoutS) * time.Second, + vhostRouter: vhostRouter, + } + proxy := &ReverseProxy{ + Director: func(req *http.Request) { + req.URL.Scheme = "http" + url := req.Context().Value("url").(string) + oldHost := getHostFromAddr(req.Context().Value("host").(string)) + host := rp.GetRealHost(oldHost, url) + if host != "" { + req.Host = host + } + req.URL.Host = req.Host + + headers := rp.GetHeaders(oldHost, url) + for k, v := range headers { + req.Header.Set(k, v) + } + }, + Transport: &http.Transport{ + ResponseHeaderTimeout: rp.responseHeaderTimeout, + DisableKeepAlives: true, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + url := ctx.Value("url").(string) + host := getHostFromAddr(ctx.Value("host").(string)) + remote := ctx.Value("remote").(string) + return rp.CreateConnection(host, url, remote) + }, + }, + BufferPool: newWrapPool(), + ErrorLog: log.New(newWrapLogger(), "", 0), + ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { + frpLog.Warn("do http proxy request error: %v", err) + rw.WriteHeader(http.StatusNotFound) + rw.Write(getNotFoundPageContent()) + }, + } + rp.proxy = proxy + return rp +} + +// Register register the route config to reverse proxy +// reverse proxy will use CreateConnFn from routeCfg to create a connection to the remote service +func (rp *HttpReverseProxy) Register(routeCfg VhostRouteConfig) error { + err := rp.vhostRouter.Add(routeCfg.Domain, routeCfg.Location, &routeCfg) + if err != nil { + return err + } + return nil +} + +// UnRegister unregister route config by domain and location +func (rp *HttpReverseProxy) UnRegister(domain string, location string) { + rp.vhostRouter.Del(domain, location) +} + +func (rp *HttpReverseProxy) GetRealHost(domain string, location string) (host string) { + vr, ok := rp.getVhost(domain, location) + if ok { + host = vr.payload.(*VhostRouteConfig).RewriteHost + } + return +} + +func (rp *HttpReverseProxy) GetHeaders(domain string, location string) (headers map[string]string) { + vr, ok := rp.getVhost(domain, location) + if ok { + headers = vr.payload.(*VhostRouteConfig).Headers + } + return +} + +// CreateConnection create a new connection by route config +func (rp *HttpReverseProxy) CreateConnection(domain string, location string, remoteAddr string) (net.Conn, error) { + vr, ok := rp.getVhost(domain, location) + if ok { + fn := vr.payload.(*VhostRouteConfig).CreateConnFn + if fn != nil { + return fn(remoteAddr) + } + } + return nil, fmt.Errorf("%v: %s %s", ErrNoDomain, domain, location) +} + +func (rp *HttpReverseProxy) CheckAuth(domain, location, user, passwd string) bool { + vr, ok := rp.getVhost(domain, location) + if ok { + checkUser := vr.payload.(*VhostRouteConfig).Username + checkPasswd := vr.payload.(*VhostRouteConfig).Password + if (checkUser != "" || checkPasswd != "") && (checkUser != user || checkPasswd != passwd) { + return false + } + } + return true +} + +// getVhost get vhost router by domain and location +func (rp *HttpReverseProxy) getVhost(domain string, location string) (vr *VhostRouter, ok bool) { + // first we check the full hostname + // if not exist, then check the wildcard_domain such as *.example.com + vr, ok = rp.vhostRouter.Get(domain, location) + if ok { + return + } + + domainSplit := strings.Split(domain, ".") + if len(domainSplit) < 3 { + return nil, false + } + + for { + if len(domainSplit) < 3 { + return nil, false + } + + domainSplit[0] = "*" + domain = strings.Join(domainSplit, ".") + vr, ok = rp.vhostRouter.Get(domain, location) + if ok { + return vr, true + } + domainSplit = domainSplit[1:] + } + return +} + +func (rp *HttpReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + domain := getHostFromAddr(req.Host) + location := req.URL.Path + user, passwd, _ := req.BasicAuth() + if !rp.CheckAuth(domain, location, user, passwd) { + rw.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) + http.Error(rw, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + rp.proxy.ServeHTTP(rw, req) +} + +type wrapPool struct{} + +func newWrapPool() *wrapPool { return &wrapPool{} } + +func (p *wrapPool) Get() []byte { return pool.GetBuf(32 * 1024) } + +func (p *wrapPool) Put(buf []byte) { pool.PutBuf(buf) } + +type wrapLogger struct{} + +func newWrapLogger() *wrapLogger { return &wrapLogger{} } + +func (l *wrapLogger) Write(p []byte) (n int, err error) { + frpLog.Warn("%s", string(bytes.TrimRight(p, "\n"))) + return len(p), nil +} diff --git a/utils/vhost/https.go b/utils/vhost/https.go new file mode 100644 index 0000000..5317701 --- /dev/null +++ b/utils/vhost/https.go @@ -0,0 +1,194 @@ +// Copyright 2016 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "fmt" + "io" + "net" + "strings" + "time" + + gnet "github.com/fatedier/golib/net" + "github.com/fatedier/golib/pool" +) + +const ( + typeClientHello uint8 = 1 // Type client hello +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionALPN uint16 = 16 + extensionSCT uint16 = 18 + extensionSessionTicket uint16 = 35 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 +) + +type HttpsMuxer struct { + *VhostMuxer +} + +func NewHttpsMuxer(listener net.Listener, timeout time.Duration) (*HttpsMuxer, error) { + mux, err := NewVhostMuxer(listener, GetHttpsHostname, nil, nil, timeout) + return &HttpsMuxer{mux}, err +} + +func readHandshake(rd io.Reader) (host string, err error) { + data := pool.GetBuf(1024) + origin := data + defer pool.PutBuf(origin) + + _, err = io.ReadFull(rd, data[:47]) + if err != nil { + return + } + + length, err := rd.Read(data[47:]) + if err != nil { + return + } else { + length += 47 + } + data = data[:length] + if uint8(data[5]) != typeClientHello { + err = fmt.Errorf("readHandshake: type[%d] is not clientHello", uint16(data[5])) + return + } + + // session + sessionIdLen := int(data[43]) + if sessionIdLen > 32 || len(data) < 44+sessionIdLen { + err = fmt.Errorf("readHandshake: sessionIdLen[%d] is long", sessionIdLen) + return + } + data = data[44+sessionIdLen:] + if len(data) < 2 { + err = fmt.Errorf("readHandshake: dataLen[%d] after session is short", len(data)) + return + } + + // cipher suite numbers + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + err = fmt.Errorf("readHandshake: dataLen[%d] after cipher suite is short", len(data)) + return + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + err = fmt.Errorf("readHandshake: cipherSuiteLen[%d] is long", cipherSuiteLen) + return + } + + // compression method + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + err = fmt.Errorf("readHandshake: compressionMethodsLen[%d] is long", compressionMethodsLen) + return + } + + data = data[1+compressionMethodsLen:] + if len(data) == 0 { + // ClientHello is optionally followed by extension data + err = fmt.Errorf("readHandshake: there is no extension data to get servername") + return + } + if len(data) < 2 { + err = fmt.Errorf("readHandshake: extension dataLen[%d] is too short", len(data)) + return + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + err = fmt.Errorf("readHandshake: extensionsLen[%d] is not equal to dataLen[%d]", extensionsLength, len(data)) + return + } + for len(data) != 0 { + if len(data) < 4 { + err = fmt.Errorf("readHandshake: extensionsDataLen[%d] is too short", len(data)) + return + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + err = fmt.Errorf("readHandshake: extensionLen[%d] is long", length) + return + } + + switch extension { + case extensionRenegotiationInfo: + if length != 1 || data[0] != 0 { + err = fmt.Errorf("readHandshake: extension reNegotiationInfoLen[%d] is short", length) + return + } + case extensionNextProtoNeg: + case extensionStatusRequest: + case extensionServerName: + d := data[:length] + if len(d) < 2 { + err = fmt.Errorf("readHandshake: remiaining dataLen[%d] is short", len(d)) + return + } + namesLen := int(d[0])<<8 | int(d[1]) + d = d[2:] + if len(d) != namesLen { + err = fmt.Errorf("readHandshake: nameListLen[%d] is not equal to dataLen[%d]", namesLen, len(d)) + return + } + for len(d) > 0 { + if len(d) < 3 { + err = fmt.Errorf("readHandshake: extension serverNameLen[%d] is short", len(d)) + return + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + err = fmt.Errorf("readHandshake: nameLen[%d] is not equal to dataLen[%d]", nameLen, len(d)) + return + } + if nameType == 0 { + serverName := string(d[:nameLen]) + host = strings.TrimSpace(serverName) + return host, nil + } + d = d[nameLen:] + } + } + data = data[length:] + } + err = fmt.Errorf("Unknow error") + return +} + +func GetHttpsHostname(c net.Conn) (_ net.Conn, _ map[string]string, err error) { + reqInfoMap := make(map[string]string, 0) + sc, rd := gnet.NewSharedConn(c) + host, err := readHandshake(rd) + if err != nil { + return nil, reqInfoMap, err + } + reqInfoMap["Host"] = host + reqInfoMap["Scheme"] = "https" + return sc, reqInfoMap, nil +} diff --git a/utils/vhost/resource.go b/utils/vhost/resource.go new file mode 100644 index 0000000..6f5ec92 --- /dev/null +++ b/utils/vhost/resource.go @@ -0,0 +1,122 @@ +// Copyright 2017 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "bytes" + "io/ioutil" + "net/http" + + frpLog "github.com/charlesbao/frpc/utils/log" + "github.com/charlesbao/frpc/utils/version" +) + +var ( + NotFoundPagePath = "" +) + +const ( + NotFound = ` + + + Not Found + + + +
😭
+

Ooooops!

+

404 NOT FOUND.

+ + +` +) + +func getNotFoundPageContent() []byte { + var ( + buf []byte + err error + ) + if NotFoundPagePath != "" { + buf, err = ioutil.ReadFile(NotFoundPagePath) + if err != nil { + frpLog.Warn("read custom 404 page error: %v", err) + buf = []byte(NotFound) + } + } else { + buf = []byte(NotFound) + } + return buf +} + +func notFoundResponse() *http.Response { + header := make(http.Header) + header.Set("server", "frp/"+version.Full()) + header.Set("Content-Type", "text/html") + + res := &http.Response{ + Status: "Not Found", + StatusCode: 404, + Proto: "HTTP/1.0", + ProtoMajor: 1, + ProtoMinor: 0, + Header: header, + Body: ioutil.NopCloser(bytes.NewReader(getNotFoundPageContent())), + } + return res +} + +func noAuthResponse() *http.Response { + header := make(map[string][]string) + header["WWW-Authenticate"] = []string{`Basic realm="Restricted"`} + res := &http.Response{ + Status: "401 Not authorized", + StatusCode: 401, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: header, + } + return res +} diff --git a/utils/vhost/reverseproxy.go b/utils/vhost/reverseproxy.go new file mode 100644 index 0000000..f606f48 --- /dev/null +++ b/utils/vhost/reverseproxy.go @@ -0,0 +1,563 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package vhost + +import ( + "context" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "golang.org/x/net/http/httpguts" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + // Director must not access the provided Request + // after returning. + Director func(*http.Request) + + // The transport used to perform proxy requests. + // If nil, http.DefaultTransport is used. + Transport http.RoundTripper + + // FlushInterval specifies the flush interval + // to flush to the client while copying the + // response body. + // If zero, no periodic flushing is done. + // A negative value means to flush immediately + // after each write to the client. + // The FlushInterval is ignored when ReverseProxy + // recognizes a response as a streaming response; + // for such responses, writes are flushed to the client + // immediately. + FlushInterval time.Duration + + // ErrorLog specifies an optional logger for errors + // that occur when attempting to proxy the request. + // If nil, logging is done via the log package's standard logger. + ErrorLog *log.Logger + + // BufferPool optionally specifies a buffer pool to + // get byte slices for use by io.CopyBuffer when + // copying HTTP response bodies. + BufferPool BufferPool + + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. + ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) +} + +// A BufferPool is an interface for getting and returning temporary +// byte slices for use by io.CopyBuffer. +type BufferPool interface { + Get() []byte + Put([]byte) +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// NewSingleHostReverseProxy does not rewrite the Host header. +// To rewrite Host headers, use ReverseProxy directly with a custom +// Director policy. +func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { + targetQuery := target.RawQuery + director := func(req *http.Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if targetQuery == "" || req.URL.RawQuery == "" { + req.URL.RawQuery = targetQuery + req.URL.RawQuery + } else { + req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery + } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } + } + return &ReverseProxy{Director: director} +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + +// modifyResponse conditionally runs the optional ModifyResponse hook +// and reports whether the request should proceed. +func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool { + if p.ModifyResponse == nil { + return true + } + if err := p.ModifyResponse(res); err != nil { + res.Body.Close() + p.getErrorHandler()(rw, req, err) + return false + } + return true +} + +func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + transport := p.Transport + if transport == nil { + transport = http.DefaultTransport + } + + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + + outreq := req.WithContext(ctx) + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries + } + + // ============================= + // Modified for frp + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path)) + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host)) + outreq = outreq.WithContext(context.WithValue(outreq.Context(), "remote", req.RemoteAddr)) + // ============================= + + p.Director(outreq) + outreq.Close = false + + reqUpType := upgradeType(outreq.Header) + removeConnectionHeaders(outreq.Header) + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. + for _, h := range hopHeaders { + hv := outreq.Header.Get(h) + if hv == "" { + continue + } + if h == "Te" && hv == "trailers" { + // Issue 21096: tell backend applications that + // care about trailer support that we support + // trailers. (We do, but we don't go out of + // our way to advertise that unless the + // incoming client request thought it was + // worth mentioning) + continue + } + outreq.Header.Del(h) + } + + // After stripping all the hop-by-hop connection headers above, add back any + // necessary for protocol upgrades, such as for websockets. + if reqUpType != "" { + outreq.Header.Set("Connection", "Upgrade") + outreq.Header.Set("Upgrade", reqUpType) + } + + if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := outreq.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + outreq.Header.Set("X-Forwarded-For", clientIP) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + p.getErrorHandler()(rw, outreq, err) + return + } + + // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc) + if res.StatusCode == http.StatusSwitchingProtocols { + if !p.modifyResponse(rw, res, outreq) { + return + } + p.handleUpgradeResponse(rw, outreq, res) + return + } + + removeConnectionHeaders(res.Header) + + for _, h := range hopHeaders { + res.Header.Del(h) + } + + if !p.modifyResponse(rw, res, outreq) { + return + } + + copyHeader(rw.Header(), res.Header) + + // The "Trailer" header isn't included in the Transport's response, + // at least for *http.Transport. Build it up from Trailer. + announcedTrailers := len(res.Trailer) + if announcedTrailers > 0 { + trailerKeys := make([]string, 0, len(res.Trailer)) + for k := range res.Trailer { + trailerKeys = append(trailerKeys, k) + } + rw.Header().Add("Trailer", strings.Join(trailerKeys, ", ")) + } + + rw.WriteHeader(res.StatusCode) + + err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + if err != nil { + defer res.Body.Close() + // Since we're streaming the response, if we run into an error all we can do + // is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler + // on read error while copying body. + if !shouldPanicOnCopyError(req) { + p.logf("suppressing panic for copyResponse error in test; copy error: %v", err) + return + } + panic(http.ErrAbortHandler) + } + res.Body.Close() // close now, instead of defer, to populate res.Trailer + + if len(res.Trailer) > 0 { + // Force chunking if we saw a response trailer. + // This prevents net/http from calculating the length for short + // bodies and adding a Content-Length. + if fl, ok := rw.(http.Flusher); ok { + fl.Flush() + } + } + + if len(res.Trailer) == announcedTrailers { + copyHeader(rw.Header(), res.Trailer) + return + } + + for k, vv := range res.Trailer { + k = http.TrailerPrefix + k + for _, v := range vv { + rw.Header().Add(k, v) + } + } +} + +var inOurTests bool // whether we're in our own tests + +// shouldPanicOnCopyError reports whether the reverse proxy should +// panic with http.ErrAbortHandler. This is the right thing to do by +// default, but Go 1.10 and earlier did not, so existing unit tests +// weren't expecting panics. Only panic in our own tests, or when +// running under the HTTP server. +func shouldPanicOnCopyError(req *http.Request) bool { + if inOurTests { + // Our tests know to handle this panic. + return true + } + if req.Context().Value(http.ServerContextKey) != nil { + // We seem to be running under an HTTP server, so + // it'll recover the panic. + return true + } + // Otherwise act like Go 1.10 and earlier to not break + // existing tests. + return false +} + +// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeConnectionHeaders(h http.Header) { + for _, f := range h["Connection"] { + for _, sf := range strings.Split(f, ",") { + if sf = strings.TrimSpace(sf); sf != "" { + h.Del(sf) + } + } + } +} + +// flushInterval returns the p.FlushInterval value, conditionally +// overriding its value for a specific request/response. +func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { + resCT := res.Header.Get("Content-Type") + + // For Server-Sent Events responses, flush immediately. + // The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream + if resCT == "text/event-stream" { + return -1 // negative means immediately + } + + // TODO: more specific cases? e.g. res.ContentLength == -1? + return p.FlushInterval +} + +func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error { + if flushInterval != 0 { + if wf, ok := dst.(writeFlusher); ok { + mlw := &maxLatencyWriter{ + dst: wf, + latency: flushInterval, + } + defer mlw.stop() + + // set up initial timer so headers get flushed even if body writes are delayed + mlw.flushPending = true + mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush) + + dst = mlw + } + } + + var buf []byte + if p.BufferPool != nil { + buf = p.BufferPool.Get() + defer p.BufferPool.Put(buf) + } + _, err := p.copyBuffer(dst, src, buf) + return err +} + +// copyBuffer returns any write errors or non-EOF read errors, and the amount +// of bytes written. +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF && rerr != context.Canceled { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + if rerr == io.EOF { + rerr = nil + } + return written, rerr + } + } +} + +func (p *ReverseProxy) logf(format string, args ...interface{}) { + if p.ErrorLog != nil { + p.ErrorLog.Printf(format, args...) + } else { + log.Printf(format, args...) + } +} + +type writeFlusher interface { + io.Writer + http.Flusher +} + +type maxLatencyWriter struct { + dst writeFlusher + latency time.Duration // non-zero; negative means to flush immediately + + mu sync.Mutex // protects t, flushPending, and dst.Flush + t *time.Timer + flushPending bool +} + +func (m *maxLatencyWriter) Write(p []byte) (n int, err error) { + m.mu.Lock() + defer m.mu.Unlock() + n, err = m.dst.Write(p) + if m.latency < 0 { + m.dst.Flush() + return + } + if m.flushPending { + return + } + if m.t == nil { + m.t = time.AfterFunc(m.latency, m.delayedFlush) + } else { + m.t.Reset(m.latency) + } + m.flushPending = true + return +} + +func (m *maxLatencyWriter) delayedFlush() { + m.mu.Lock() + defer m.mu.Unlock() + if !m.flushPending { // if stop was called but AfterFunc already started this goroutine + return + } + m.dst.Flush() + m.flushPending = false +} + +func (m *maxLatencyWriter) stop() { + m.mu.Lock() + defer m.mu.Unlock() + m.flushPending = false + if m.t != nil { + m.t.Stop() + } +} + +func upgradeType(h http.Header) string { + if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") { + return "" + } + return strings.ToLower(h.Get("Upgrade")) +} + +func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) { + reqUpType := upgradeType(req.Header) + resUpType := upgradeType(res.Header) + if reqUpType != resUpType { + p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType)) + return + } + + copyHeader(res.Header, rw.Header()) + + hj, ok := rw.(http.Hijacker) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) + return + } + backConn, ok := res.Body.(io.ReadWriteCloser) + if !ok { + p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body")) + return + } + defer backConn.Close() + conn, brw, err := hj.Hijack() + if err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err)) + return + } + defer conn.Close() + res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above + if err := res.Write(brw); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) + return + } + if err := brw.Flush(); err != nil { + p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err)) + return + } + errc := make(chan error, 1) + spc := switchProtocolCopier{user: conn, backend: backConn} + go spc.copyToBackend(errc) + go spc.copyFromBackend(errc) + <-errc + return +} + +// switchProtocolCopier exists so goroutines proxying data back and +// forth have nice names in stacks. +type switchProtocolCopier struct { + user, backend io.ReadWriter +} + +func (c switchProtocolCopier) copyFromBackend(errc chan<- error) { + _, err := io.Copy(c.user, c.backend) + errc <- err +} + +func (c switchProtocolCopier) copyToBackend(errc chan<- error) { + _, err := io.Copy(c.backend, c.user) + errc <- err +} diff --git a/utils/vhost/router.go b/utils/vhost/router.go new file mode 100644 index 0000000..bfdcb50 --- /dev/null +++ b/utils/vhost/router.go @@ -0,0 +1,119 @@ +package vhost + +import ( + "errors" + "sort" + "strings" + "sync" +) + +var ( + ErrRouterConfigConflict = errors.New("router config conflict") +) + +type VhostRouters struct { + RouterByDomain map[string][]*VhostRouter + mutex sync.RWMutex +} + +type VhostRouter struct { + domain string + location string + + payload interface{} +} + +func NewVhostRouters() *VhostRouters { + return &VhostRouters{ + RouterByDomain: make(map[string][]*VhostRouter), + } +} + +func (r *VhostRouters) Add(domain, location string, payload interface{}) error { + r.mutex.Lock() + defer r.mutex.Unlock() + + if _, exist := r.exist(domain, location); exist { + return ErrRouterConfigConflict + } + + vrs, found := r.RouterByDomain[domain] + if !found { + vrs = make([]*VhostRouter, 0, 1) + } + + vr := &VhostRouter{ + domain: domain, + location: location, + payload: payload, + } + vrs = append(vrs, vr) + + sort.Sort(sort.Reverse(ByLocation(vrs))) + r.RouterByDomain[domain] = vrs + return nil +} + +func (r *VhostRouters) Del(domain, location string) { + r.mutex.Lock() + defer r.mutex.Unlock() + + vrs, found := r.RouterByDomain[domain] + if !found { + return + } + newVrs := make([]*VhostRouter, 0) + for _, vr := range vrs { + if vr.location != location { + newVrs = append(newVrs, vr) + } + } + r.RouterByDomain[domain] = newVrs +} + +func (r *VhostRouters) Get(host, path string) (vr *VhostRouter, exist bool) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + vrs, found := r.RouterByDomain[host] + if !found { + return + } + + // can't support load balance, will to do + for _, vr = range vrs { + if strings.HasPrefix(path, vr.location) { + return vr, true + } + } + + return +} + +func (r *VhostRouters) exist(host, path string) (vr *VhostRouter, exist bool) { + vrs, found := r.RouterByDomain[host] + if !found { + return + } + + for _, vr = range vrs { + if path == vr.location { + return vr, true + } + } + + return +} + +// sort by location +type ByLocation []*VhostRouter + +func (a ByLocation) Len() int { + return len(a) +} +func (a ByLocation) Swap(i, j int) { + a[i], a[j] = a[j], a[i] +} +func (a ByLocation) Less(i, j int) bool { + return strings.Compare(a[i].location, a[j].location) < 0 +} diff --git a/utils/vhost/vhost.go b/utils/vhost/vhost.go new file mode 100644 index 0000000..88035cc --- /dev/null +++ b/utils/vhost/vhost.go @@ -0,0 +1,227 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vhost + +import ( + "context" + "fmt" + "net" + "strings" + "time" + + "github.com/charlesbao/frpc/utils/log" + frpNet "github.com/charlesbao/frpc/utils/net" + "github.com/charlesbao/frpc/utils/xlog" + + "github.com/fatedier/golib/errors" +) + +type muxFunc func(net.Conn) (net.Conn, map[string]string, error) +type httpAuthFunc func(net.Conn, string, string, string) (bool, error) +type hostRewriteFunc func(net.Conn, string) (net.Conn, error) + +type VhostMuxer struct { + listener net.Listener + timeout time.Duration + vhostFunc muxFunc + authFunc httpAuthFunc + rewriteFunc hostRewriteFunc + registryRouter *VhostRouters +} + +func NewVhostMuxer(listener net.Listener, vhostFunc muxFunc, authFunc httpAuthFunc, rewriteFunc hostRewriteFunc, timeout time.Duration) (mux *VhostMuxer, err error) { + mux = &VhostMuxer{ + listener: listener, + timeout: timeout, + vhostFunc: vhostFunc, + authFunc: authFunc, + rewriteFunc: rewriteFunc, + registryRouter: NewVhostRouters(), + } + go mux.run() + return mux, nil +} + +type CreateConnFunc func(remoteAddr string) (net.Conn, error) + +// VhostRouteConfig is the params used to match HTTP requests +type VhostRouteConfig struct { + Domain string + Location string + RewriteHost string + Username string + Password string + Headers map[string]string + + CreateConnFn CreateConnFunc +} + +// listen for a new domain name, if rewriteHost is not empty and rewriteFunc is not nil +// then rewrite the host header to rewriteHost +func (v *VhostMuxer) Listen(ctx context.Context, cfg *VhostRouteConfig) (l *Listener, err error) { + l = &Listener{ + name: cfg.Domain, + location: cfg.Location, + rewriteHost: cfg.RewriteHost, + userName: cfg.Username, + passWord: cfg.Password, + mux: v, + accept: make(chan net.Conn), + ctx: ctx, + } + err = v.registryRouter.Add(cfg.Domain, cfg.Location, l) + if err != nil { + return + } + return l, nil +} + +func (v *VhostMuxer) getListener(name, path string) (l *Listener, exist bool) { + // first we check the full hostname + // if not exist, then check the wildcard_domain such as *.example.com + vr, found := v.registryRouter.Get(name, path) + if found { + return vr.payload.(*Listener), true + } + + domainSplit := strings.Split(name, ".") + if len(domainSplit) < 3 { + return + } + + for { + if len(domainSplit) < 3 { + return + } + + domainSplit[0] = "*" + name = strings.Join(domainSplit, ".") + + vr, found = v.registryRouter.Get(name, path) + if found { + return vr.payload.(*Listener), true + } + domainSplit = domainSplit[1:] + } + return +} + +func (v *VhostMuxer) run() { + for { + conn, err := v.listener.Accept() + if err != nil { + return + } + go v.handle(conn) + } +} + +func (v *VhostMuxer) handle(c net.Conn) { + if err := c.SetDeadline(time.Now().Add(v.timeout)); err != nil { + c.Close() + return + } + + sConn, reqInfoMap, err := v.vhostFunc(c) + if err != nil { + log.Warn("get hostname from http/https request error: %v", err) + c.Close() + return + } + + name := strings.ToLower(reqInfoMap["Host"]) + path := strings.ToLower(reqInfoMap["Path"]) + l, ok := v.getListener(name, path) + if !ok { + res := notFoundResponse() + res.Write(c) + log.Debug("http request for host [%s] path [%s] not found", name, path) + c.Close() + return + } + xl := xlog.FromContextSafe(l.ctx) + + // if authFunc is exist and userName/password is set + // then verify user access + if l.mux.authFunc != nil && l.userName != "" && l.passWord != "" { + bAccess, err := l.mux.authFunc(c, l.userName, l.passWord, reqInfoMap["Authorization"]) + if bAccess == false || err != nil { + xl.Debug("check http Authorization failed") + res := noAuthResponse() + res.Write(c) + c.Close() + return + } + } + + if err = sConn.SetDeadline(time.Time{}); err != nil { + c.Close() + return + } + c = sConn + + xl.Debug("get new http request host [%s] path [%s]", name, path) + err = errors.PanicToError(func() { + l.accept <- c + }) + if err != nil { + xl.Warn("listener is already closed, ignore this request") + } +} + +type Listener struct { + name string + location string + rewriteHost string + userName string + passWord string + mux *VhostMuxer // for closing VhostMuxer + accept chan net.Conn + ctx context.Context +} + +func (l *Listener) Accept() (net.Conn, error) { + xl := xlog.FromContextSafe(l.ctx) + conn, ok := <-l.accept + if !ok { + return nil, fmt.Errorf("Listener closed") + } + + // if rewriteFunc is exist + // rewrite http requests with a modified host header + // if l.rewriteHost is empty, nothing to do + if l.mux.rewriteFunc != nil { + sConn, err := l.mux.rewriteFunc(conn, l.rewriteHost) + if err != nil { + xl.Warn("host header rewrite failed: %v", err) + return nil, fmt.Errorf("host header rewrite failed") + } + xl.Debug("rewrite host to [%s] success", l.rewriteHost) + conn = sConn + } + return frpNet.NewContextConn(conn, l.ctx), nil +} + +func (l *Listener) Close() error { + l.mux.registryRouter.Del(l.name, l.location) + close(l.accept) + return nil +} + +func (l *Listener) Name() string { + return l.name +} + +func (l *Listener) Addr() net.Addr { + return (*net.TCPAddr)(nil) +} diff --git a/utils/xlog/ctx.go b/utils/xlog/ctx.go new file mode 100644 index 0000000..1d3619b --- /dev/null +++ b/utils/xlog/ctx.go @@ -0,0 +1,42 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xlog + +import ( + "context" +) + +type key int + +const ( + xlogKey key = 0 +) + +func NewContext(ctx context.Context, xl *Logger) context.Context { + return context.WithValue(ctx, xlogKey, xl) +} + +func FromContext(ctx context.Context) (xl *Logger, ok bool) { + xl, ok = ctx.Value(xlogKey).(*Logger) + return +} + +func FromContextSafe(ctx context.Context) *Logger { + xl, ok := ctx.Value(xlogKey).(*Logger) + if !ok { + xl = New() + } + return xl +} diff --git a/utils/xlog/xlog.go b/utils/xlog/xlog.go new file mode 100644 index 0000000..4dfd597 --- /dev/null +++ b/utils/xlog/xlog.go @@ -0,0 +1,73 @@ +// Copyright 2019 fatedier, fatedier@gmail.com +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package xlog + +import ( + "github.com/charlesbao/frpc/utils/log" +) + +// Logger is not thread safety for operations on prefix +type Logger struct { + prefixes []string + + prefixString string +} + +func New() *Logger { + return &Logger{ + prefixes: make([]string, 0), + } +} + +func (l *Logger) ResetPrefixes() (old []string) { + old = l.prefixes + l.prefixes = make([]string, 0) + l.prefixString = "" + return +} + +func (l *Logger) AppendPrefix(prefix string) *Logger { + l.prefixes = append(l.prefixes, prefix) + l.prefixString += "[" + prefix + "] " + return l +} + +func (l *Logger) Spawn() *Logger { + nl := New() + for _, v := range l.prefixes { + nl.AppendPrefix(v) + } + return nl +} + +func (l *Logger) Error(format string, v ...interface{}) { + log.Log.Error(l.prefixString+format, v...) +} + +func (l *Logger) Warn(format string, v ...interface{}) { + log.Log.Warn(l.prefixString+format, v...) +} + +func (l *Logger) Info(format string, v ...interface{}) { + log.Log.Info(l.prefixString+format, v...) +} + +func (l *Logger) Debug(format string, v ...interface{}) { + log.Log.Debug(l.prefixString+format, v...) +} + +func (l *Logger) Trace(format string, v ...interface{}) { + log.Log.Trace(l.prefixString+format, v...) +}