From f2abefa25e11941f293795f3a65818e4dffef875 Mon Sep 17 00:00:00 2001 From: Daniel Ding Date: Thu, 13 Mar 2025 20:38:15 +0800 Subject: [PATCH] fea: support forwarding socks5. --- pkg/config/proxy.go | 5 +- pkg/libol/logger.go | 4 ++ pkg/proxy/http.go | 4 +- pkg/proxy/socks.go | 9 ++- pkg/socks5/request.go | 1 - pkg/socks5/request_test.go | 4 -- pkg/socks5/socks5.go | 123 ++++++++++++++++++++++++++++++++++--- pkg/socks5/socks5_test.go | 3 - 8 files changed, 129 insertions(+), 24 deletions(-) diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index 2293546..dc51292 100755 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -19,8 +19,9 @@ type ShadowProxy struct { } type SocksProxy struct { - Listen string `json:"listen,omitempty"` - Auth *Password `json:"auth,omitempty"` + Listen string `json:"listen,omitempty" yaml:"listen,omitempty"` + Auth *Password `json:"auth,omitempty" yaml:"auth,omitempty"` + Backends []*HttpForward `json:"backends,omitempty" yaml:"backends,omitempty"` } type HttpForward struct { diff --git a/pkg/libol/logger.go b/pkg/libol/logger.go index d710078..8b97b7d 100755 --- a/pkg/libol/logger.go +++ b/pkg/libol/logger.go @@ -199,6 +199,10 @@ func (s *SubLogger) Print(format string, v ...interface{}) { s.logger.Write(PRINT, s.Fmt(format), v...) } +func (s *SubLogger) Printf(format string, v ...interface{}) { + s.logger.Write(PRINT, s.Fmt(format), v...) +} + func (s *SubLogger) Log(format string, v ...interface{}) { s.logger.Write(LOG, s.Fmt(format), v...) } diff --git a/pkg/proxy/http.go b/pkg/proxy/http.go index 78c1cd6..22fc812 100755 --- a/pkg/proxy/http.go +++ b/pkg/proxy/http.go @@ -60,7 +60,7 @@ type HttpProxy struct { } var ( - connectOkay = []byte("HTTP/1.1 200 Connection established\r\n\r\n") + httpOkay = "HTTP/1.1 200 OK\r\n\r\n" ) func decodeBasicAuth(auth string) (username, password string, ok bool) { @@ -430,7 +430,7 @@ func (t *HttpProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.out.Warn("HttpProxy.ServeHTTP %s: %s", r.URL.Host, err) return } - w.Write(connectOkay) + fmt.Fprint(w, httpOkay) t.toTunnel(w, conn, func(bytes int64) { t.doRecord(r, bytes) }) diff --git a/pkg/proxy/socks.go b/pkg/proxy/socks.go index d77f4b9..9803191 100755 --- a/pkg/proxy/socks.go +++ b/pkg/proxy/socks.go @@ -22,15 +22,20 @@ func NewSocksProxy(cfg *config.SocksProxy) *SocksProxy { // Create a SOCKS5 server auth := cfg.Auth authMethods := make([]socks5.Authenticator, 0, 2) - if len(auth.Username) > 0 { + if auth != nil && len(auth.Username) > 0 { author := socks5.UserPassAuthenticator{ Credentials: socks5.StaticCredentials{ auth.Username: auth.Password, }, } authMethods = append(authMethods, author) + + } + conf := &socks5.Config{ + Backends: cfg.Backends, + AuthMethods: authMethods, + Logger: s.out, } - conf := &socks5.Config{AuthMethods: authMethods} server, err := socks5.New(conf) if err != nil { s.out.Error("NewSocksProxy %s", err) diff --git a/pkg/socks5/request.go b/pkg/socks5/request.go index b615fcb..5d165bc 100644 --- a/pkg/socks5/request.go +++ b/pkg/socks5/request.go @@ -118,7 +118,6 @@ func NewRequest(bufConn io.Reader) (*Request, error) { // handleRequest is used for request processing after authentication func (s *Server) handleRequest(req *Request, conn conn) error { ctx := context.Background() - // Resolve the address if we have a FQDN dest := req.DestAddr if dest.FQDN != "" { diff --git a/pkg/socks5/request_test.go b/pkg/socks5/request_test.go index 5465113..8f8765c 100644 --- a/pkg/socks5/request_test.go +++ b/pkg/socks5/request_test.go @@ -4,9 +4,7 @@ import ( "bytes" "encoding/binary" "io" - "log" "net" - "os" "strings" "testing" ) @@ -52,7 +50,6 @@ func TestRequest_Connect(t *testing.T) { s := &Server{config: &Config{ Rules: PermitAll(), Resolver: DNSResolver{}, - Logger: log.New(os.Stdout, "", log.LstdFlags), }} // Create the connect request @@ -127,7 +124,6 @@ func TestRequest_Connect_RuleFail(t *testing.T) { s := &Server{config: &Config{ Rules: PermitNone(), Resolver: DNSResolver{}, - Logger: log.New(os.Stdout, "", log.LstdFlags), }} // Create the connect request diff --git a/pkg/socks5/socks5.go b/pkg/socks5/socks5.go index a17be68..9ddf917 100644 --- a/pkg/socks5/socks5.go +++ b/pkg/socks5/socks5.go @@ -2,11 +2,15 @@ package socks5 import ( "bufio" + "encoding/binary" "fmt" - "log" + "io" "net" - "os" + "regexp" + "time" + co "github.com/luscis/openlan/pkg/config" + "github.com/luscis/openlan/pkg/libol" "golang.org/x/net/context" ) @@ -44,10 +48,13 @@ type Config struct { // Logger can be used to provide a custom log target. // Defaults to stdout. - Logger *log.Logger + Logger *libol.SubLogger // Optional function for dialing out Dial func(ctx context.Context, network, addr string) (net.Conn, error) + + // Backends forwarding socks request + Backends []*co.HttpForward } // Server is reponsible for accepting connections and handling @@ -80,7 +87,7 @@ func New(conf *Config) (*Server, error) { // Ensure we have a log target if conf.Logger == nil { - conf.Logger = log.New(os.Stdout, "", log.LstdFlags) + conf.Logger = libol.NewSubLogger("") } server := &Server{ @@ -125,14 +132,14 @@ func (s *Server) ServeConn(conn net.Conn) error { // Read the version byte version := []byte{0} if _, err := bufConn.Read(version); err != nil { - s.config.Logger.Printf("[ERR] socks: Failed to get version byte: %v", err) + s.config.Logger.Error("socks: Failed to get version byte: %v", err) return err } // Ensure we are compatible if version[0] != socks5Version { err := fmt.Errorf("Unsupported SOCKS version: %v", version) - s.config.Logger.Printf("[ERR] socks: %v", err) + s.config.Logger.Error("socks: %v", err) return err } @@ -140,7 +147,7 @@ func (s *Server) ServeConn(conn net.Conn) error { authContext, err := s.authenticate(conn, bufConn) if err != nil { err = fmt.Errorf("Failed to authenticate: %v", err) - s.config.Logger.Printf("[ERR] socks: %v", err) + s.config.Logger.Error("socks: %v", err) return err } @@ -158,12 +165,108 @@ func (s *Server) ServeConn(conn net.Conn) error { request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} } - // Process the client request + dstAddr := request.DestAddr + via := s.findForward(dstAddr.Address()) + if via != nil { + if err := s.toForward(request, conn, via); err != nil { + s.config.Logger.Error("forward: %v", err) + return err + } + return nil + } + + s.config.Logger.Info("ServeConn: %s", dstAddr.Address()) + //Process the client request if err := s.handleRequest(request, conn); err != nil { err = fmt.Errorf("Failed to handle request: %v", err) - s.config.Logger.Printf("[ERR] socks: %v", err) - return err + s.config.Logger.Error("socks: %v", err) } return nil } + +func (s *Server) toTunnel(local net.Conn, target net.Conn) { + defer local.Close() + defer target.Close() + wait := libol.NewWaitOne(2) + + libol.Go(func() { + defer wait.Done() + io.Copy(local, target) + }) + libol.Go(func() { + defer wait.Done() + io.Copy(target, local) + }) + wait.Wait() +} + +func (s *Server) openConn(remote string) (net.Conn, error) { + return net.DialTimeout("tcp", remote, 10*time.Second) +} + +func (s *Server) isMatch(value string, rules []string) bool { + if len(rules) == 0 { + return true + } + for _, rule := range rules { + pattern := fmt.Sprintf(`(^|\.)%s(:\d+)?$`, regexp.QuoteMeta(rule)) + re := regexp.MustCompile(pattern) + if re.MatchString(value) { + return true + } + } + return false +} + +func (s *Server) findForward(host string) *co.HttpForward { + for _, via := range s.config.Backends { + if via != nil && s.isMatch(host, via.Match) { + return via + } + } + return nil +} + +func (s *Server) toForward(req *Request, local net.Conn, via *co.HttpForward) error { + dstAddr := req.DestAddr + s.config.Logger.Info("Connect %s via %s", dstAddr.Address(), via.Server) + + target, err := s.openConn(via.Server) + if err != nil { + sendReply(local, networkUnreachable, nil) + return err + } + + // Handshake: SOCKS5 no auth + _, err = target.Write([]byte{socks5Version, 1, 0}) + if err != nil { + sendReply(local, serverFailure, nil) + return err + } + + reply := make([]byte, 2) + _, err = target.Read(reply) + if reply[0] != socks5Version || reply[1] != successReply { + sendReply(local, serverFailure, nil) + return err + } + + domain := []byte(dstAddr.FQDN) + port := []byte{0, 0} + binary.BigEndian.PutUint16(port, uint16(dstAddr.Port)) + + // Request: CONNECT to domain + bind := []byte{socks5Version, 1, 0, 3} + bind = append(bind, byte(len(domain))) + bind = append(bind, domain...) + bind = append(bind, port...) + _, err = target.Write(bind) + if err != nil { + sendReply(local, serverFailure, nil) + return err + } + + s.toTunnel(local, target) + return nil +} diff --git a/pkg/socks5/socks5_test.go b/pkg/socks5/socks5_test.go index 8cfbee0..e91929c 100644 --- a/pkg/socks5/socks5_test.go +++ b/pkg/socks5/socks5_test.go @@ -4,9 +4,7 @@ import ( "bytes" "encoding/binary" "io" - "log" "net" - "os" "testing" "time" ) @@ -43,7 +41,6 @@ func TestSOCKS5_Connect(t *testing.T) { cator := UserPassAuthenticator{Credentials: creds} conf := &Config{ AuthMethods: []Authenticator{cator}, - Logger: log.New(os.Stdout, "", log.LstdFlags), } serv, err := New(conf) if err != nil {