diff --git a/gost/Makefile b/gost/Makefile deleted file mode 100644 index f3721b7d..00000000 --- a/gost/Makefile +++ /dev/null @@ -1,99 +0,0 @@ -NAME=gost -BINDIR=bin -VERSION=$(shell cat gost.go | grep 'Version =' | sed 's/.*\"\(.*\)\".*/\1/g') -GOBUILD=CGO_ENABLED=0 go build --ldflags="-s -w" -v -x -a -GOFILES=cmd/gost/* - -PLATFORM_LIST = \ - darwin-amd64 \ - darwin-arm64 \ - linux-386 \ - linux-amd64 \ - linux-armv5 \ - linux-armv6 \ - linux-armv7 \ - linux-armv8 \ - linux-mips-softfloat \ - linux-mips-hardfloat \ - linux-mipsle-softfloat \ - linux-mipsle-hardfloat \ - linux-mips64 \ - linux-mips64le \ - freebsd-386 \ - freebsd-amd64 - -WINDOWS_ARCH_LIST = \ - windows-386 \ - windows-amd64 - -all: linux-amd64 darwin-amd64 windows-amd64 # Most used - -darwin-amd64: - GOARCH=amd64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -darwin-arm64: - GOARCH=arm64 GOOS=darwin $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-386: - GOARCH=386 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-amd64: - GOARCH=amd64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-armv5: - GOARCH=arm GOOS=linux GOARM=5 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-armv6: - GOARCH=arm GOOS=linux GOARM=6 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-armv7: - GOARCH=arm GOOS=linux GOARM=7 $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-armv8: - GOARCH=arm64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mips-softfloat: - GOARCH=mips GOMIPS=softfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mips-hardfloat: - GOARCH=mips GOMIPS=hardfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mipsle-softfloat: - GOARCH=mipsle GOMIPS=softfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mipsle-hardfloat: - GOARCH=mipsle GOMIPS=hardfloat GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mips64: - GOARCH=mips64 GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -linux-mips64le: - GOARCH=mips64le GOOS=linux $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -freebsd-386: - GOARCH=386 GOOS=freebsd $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -freebsd-amd64: - GOARCH=amd64 GOOS=freebsd $(GOBUILD) -o $(BINDIR)/$(NAME)-$@ $(GOFILES) - -windows-386: - GOARCH=386 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe $(GOFILES) - -windows-amd64: - GOARCH=amd64 GOOS=windows $(GOBUILD) -o $(BINDIR)/$(NAME)-$@.exe $(GOFILES) - -gz_releases=$(addsuffix .gz, $(PLATFORM_LIST)) -zip_releases=$(addsuffix .zip, $(WINDOWS_ARCH_LIST)) - -$(gz_releases): %.gz : % - chmod +x $(BINDIR)/$(NAME)-$(basename $@) - gzip -f -S -$(VERSION).gz $(BINDIR)/$(NAME)-$(basename $@) - -$(zip_releases): %.zip : % - zip -m -j $(BINDIR)/$(NAME)-$(basename $@)-$(VERSION).zip $(BINDIR)/$(NAME)-$(basename $@).exe - -all-arch: $(PLATFORM_LIST) $(WINDOWS_ARCH_LIST) - -releases: $(gz_releases) $(zip_releases) -clean: - rm $(BINDIR)/* diff --git a/gost/bypass.go b/gost/bypass.go deleted file mode 100644 index 28ca8c87..00000000 --- a/gost/bypass.go +++ /dev/null @@ -1,298 +0,0 @@ -package gost - -import ( - "bufio" - "bytes" - "fmt" - "io" - "net" - "strconv" - "strings" - "sync" - "time" - - glob "github.com/gobwas/glob" -) - -// Matcher is a generic pattern matcher, -// it gives the match result of the given pattern for specific v. -type Matcher interface { - Match(v string) bool - String() string -} - -// NewMatcher creates a Matcher for the given pattern. -// The acutal Matcher depends on the pattern: -// IP Matcher if pattern is a valid IP address. -// CIDR Matcher if pattern is a valid CIDR address. -// Domain Matcher if both of the above are not. -func NewMatcher(pattern string) Matcher { - if pattern == "" { - return nil - } - if ip := net.ParseIP(pattern); ip != nil { - return IPMatcher(ip) - } - if _, inet, err := net.ParseCIDR(pattern); err == nil { - return CIDRMatcher(inet) - } - return DomainMatcher(pattern) -} - -type ipMatcher struct { - ip net.IP -} - -// IPMatcher creates a Matcher for a specific IP address. -func IPMatcher(ip net.IP) Matcher { - return &ipMatcher{ - ip: ip, - } -} - -func (m *ipMatcher) Match(ip string) bool { - if m == nil { - return false - } - return m.ip.Equal(net.ParseIP(ip)) -} - -func (m *ipMatcher) String() string { - return "ip " + m.ip.String() -} - -type cidrMatcher struct { - ipNet *net.IPNet -} - -// CIDRMatcher creates a Matcher for a specific CIDR notation IP address. -func CIDRMatcher(inet *net.IPNet) Matcher { - return &cidrMatcher{ - ipNet: inet, - } -} - -func (m *cidrMatcher) Match(ip string) bool { - if m == nil || m.ipNet == nil { - return false - } - return m.ipNet.Contains(net.ParseIP(ip)) -} - -func (m *cidrMatcher) String() string { - return "cidr " + m.ipNet.String() -} - -type domainMatcher struct { - pattern string - glob glob.Glob -} - -// DomainMatcher creates a Matcher for a specific domain pattern, -// the pattern can be a plain domain such as 'example.com', -// a wildcard such as '*.exmaple.com' or a special wildcard '.example.com'. -func DomainMatcher(pattern string) Matcher { - p := pattern - if strings.HasPrefix(pattern, ".") { - p = pattern[1:] // trim the prefix '.' - pattern = "*" + p - } - return &domainMatcher{ - pattern: p, - glob: glob.MustCompile(pattern), - } -} - -func (m *domainMatcher) Match(domain string) bool { - if m == nil || m.glob == nil { - return false - } - - if domain == m.pattern { - return true - } - return m.glob.Match(domain) -} - -func (m *domainMatcher) String() string { - return "domain " + m.pattern -} - -// Bypass is a filter for address (IP or domain). -// It contains a list of matchers. -type Bypass struct { - matchers []Matcher - period time.Duration // the period for live reloading - reversed bool - stopped chan struct{} - mux sync.RWMutex -} - -// NewBypass creates and initializes a new Bypass using matchers as its match rules. -// The rules will be reversed if the reversed is true. -func NewBypass(reversed bool, matchers ...Matcher) *Bypass { - return &Bypass{ - matchers: matchers, - reversed: reversed, - stopped: make(chan struct{}), - } -} - -// NewBypassPatterns creates and initializes a new Bypass using matcher patterns as its match rules. -// The rules will be reversed if the reverse is true. -func NewBypassPatterns(reversed bool, patterns ...string) *Bypass { - var matchers []Matcher - for _, pattern := range patterns { - if m := NewMatcher(pattern); m != nil { - matchers = append(matchers, m) - } - } - bp := NewBypass(reversed) - bp.AddMatchers(matchers...) - return bp -} - -// Contains reports whether the bypass includes addr. -func (bp *Bypass) Contains(addr string) bool { - if bp == nil || addr == "" { - return false - } - - // try to strip the port - if host, port, _ := net.SplitHostPort(addr); host != "" && port != "" { - if p, _ := strconv.Atoi(port); p > 0 { // port is valid - addr = host - } - } - - bp.mux.RLock() - defer bp.mux.RUnlock() - - if len(bp.matchers) == 0 { - return false - } - - var matched bool - for _, matcher := range bp.matchers { - if matcher == nil { - continue - } - if matcher.Match(addr) { - matched = true - break - } - } - return !bp.reversed && matched || - bp.reversed && !matched -} - -// AddMatchers appends matchers to the bypass matcher list. -func (bp *Bypass) AddMatchers(matchers ...Matcher) { - bp.mux.Lock() - defer bp.mux.Unlock() - - bp.matchers = append(bp.matchers, matchers...) -} - -// Matchers return the bypass matcher list. -func (bp *Bypass) Matchers() []Matcher { - bp.mux.RLock() - defer bp.mux.RUnlock() - - return bp.matchers -} - -// Reversed reports whether the rules of the bypass are reversed. -func (bp *Bypass) Reversed() bool { - bp.mux.RLock() - defer bp.mux.RUnlock() - - return bp.reversed -} - -// Reload parses config from r, then live reloads the bypass. -func (bp *Bypass) Reload(r io.Reader) error { - var matchers []Matcher - var period time.Duration - var reversed bool - - if r == nil || bp.Stopped() { - return nil - } - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - ss := splitLine(line) - if len(ss) == 0 { - continue - } - switch ss[0] { - case "reload": // reload option - if len(ss) > 1 { - period, _ = time.ParseDuration(ss[1]) - } - case "reverse": // reverse option - if len(ss) > 1 { - reversed, _ = strconv.ParseBool(ss[1]) - } - default: - matchers = append(matchers, NewMatcher(ss[0])) - } - } - - if err := scanner.Err(); err != nil { - return err - } - - bp.mux.Lock() - defer bp.mux.Unlock() - - bp.matchers = matchers - bp.period = period - bp.reversed = reversed - - return nil -} - -// Period returns the reload period. -func (bp *Bypass) Period() time.Duration { - if bp.Stopped() { - return -1 - } - - bp.mux.RLock() - defer bp.mux.RUnlock() - - return bp.period -} - -// Stop stops reloading. -func (bp *Bypass) Stop() { - select { - case <-bp.stopped: - default: - close(bp.stopped) - } -} - -// Stopped checks whether the reloader is stopped. -func (bp *Bypass) Stopped() bool { - select { - case <-bp.stopped: - return true - default: - return false - } -} - -func (bp *Bypass) String() string { - b := &bytes.Buffer{} - fmt.Fprintf(b, "reversed: %v\n", bp.Reversed()) - fmt.Fprintf(b, "reload: %v\n", bp.Period()) - for _, m := range bp.Matchers() { - b.WriteString(m.String()) - b.WriteByte('\n') - } - return b.String() -} diff --git a/gost/bypass_test.go b/gost/bypass_test.go deleted file mode 100644 index d895121c..00000000 --- a/gost/bypass_test.go +++ /dev/null @@ -1,303 +0,0 @@ -package gost - -import ( - "bytes" - "fmt" - "io" - "testing" - "time" -) - -var bypassContainTests = []struct { - patterns []string - reversed bool - addr string - bypassed bool -}{ - // empty pattern - {[]string{""}, false, "", false}, - {[]string{""}, false, "192.168.1.1", false}, - {[]string{""}, true, "", false}, - {[]string{""}, true, "192.168.1.1", false}, - - // IP address - {[]string{"192.168.1.1"}, false, "192.168.1.1", true}, - {[]string{"192.168.1.1"}, true, "192.168.1.1", false}, - {[]string{"192.168.1.1"}, false, "192.168.1.2", false}, - {[]string{"192.168.1.1"}, true, "192.168.1.2", true}, - {[]string{"0.0.0.0"}, false, "0.0.0.0", true}, - {[]string{"0.0.0.0"}, true, "0.0.0.0", false}, - - // CIDR address - {[]string{"192.168.1.0/0"}, false, "1.2.3.4", true}, - {[]string{"192.168.1.0/0"}, true, "1.2.3.4", false}, - {[]string{"192.168.1.0/8"}, false, "192.1.0.255", true}, - {[]string{"192.168.1.0/8"}, true, "192.1.0.255", false}, - {[]string{"192.168.1.0/8"}, false, "191.1.0.255", false}, - {[]string{"192.168.1.0/8"}, true, "191.1.0.255", true}, - {[]string{"192.168.1.0/16"}, false, "192.168.0.255", true}, - {[]string{"192.168.1.0/16"}, true, "192.168.0.255", false}, - {[]string{"192.168.1.0/16"}, false, "192.0.1.255", false}, - {[]string{"192.168.1.0/16"}, true, "192.0.0.255", true}, - {[]string{"192.168.1.0/24"}, false, "192.168.1.255", true}, - {[]string{"192.168.1.0/24"}, true, "192.168.1.255", false}, - {[]string{"192.168.1.0/24"}, false, "192.168.0.255", false}, - {[]string{"192.168.1.0/24"}, true, "192.168.0.255", true}, - {[]string{"192.168.1.1/32"}, false, "192.168.1.1", true}, - {[]string{"192.168.1.1/32"}, true, "192.168.1.1", false}, - {[]string{"192.168.1.1/32"}, false, "192.168.1.2", false}, - {[]string{"192.168.1.1/32"}, true, "192.168.1.2", true}, - - // plain domain - {[]string{"www.example.com"}, false, "www.example.com", true}, - {[]string{"www.example.com"}, true, "www.example.com", false}, - {[]string{"http://www.example.com"}, false, "http://www.example.com", true}, - {[]string{"http://www.example.com"}, true, "http://www.example.com", false}, - {[]string{"http://www.example.com"}, false, "http://example.com", false}, - {[]string{"http://www.example.com"}, true, "http://example.com", true}, - {[]string{"www.example.com"}, false, "example.com", false}, - {[]string{"www.example.com"}, true, "example.com", true}, - - // host:port - {[]string{"192.168.1.1"}, false, "192.168.1.1:80", true}, - {[]string{"192.168.1.1"}, true, "192.168.1.1:80", false}, - {[]string{"192.168.1.1:80"}, false, "192.168.1.1", false}, - {[]string{"192.168.1.1:80"}, true, "192.168.1.1", true}, - {[]string{"192.168.1.1:80"}, false, "192.168.1.1:80", false}, - {[]string{"192.168.1.1:80"}, true, "192.168.1.1:80", true}, - {[]string{"192.168.1.1:80"}, false, "192.168.1.1:8080", false}, - {[]string{"192.168.1.1:80"}, true, "192.168.1.1:8080", true}, - - {[]string{"example.com"}, false, "example.com:80", true}, - {[]string{"example.com"}, true, "example.com:80", false}, - {[]string{"example.com:80"}, false, "example.com", false}, - {[]string{"example.com:80"}, true, "example.com", true}, - {[]string{"example.com:80"}, false, "example.com:80", false}, - {[]string{"example.com:80"}, true, "example.com:80", true}, - {[]string{"example.com:80"}, false, "example.com:8080", false}, - {[]string{"example.com:80"}, true, "example.com:8080", true}, - - // domain wildcard - - {[]string{"*"}, false, "", false}, - {[]string{"*"}, false, "192.168.1.1", true}, - {[]string{"*"}, false, "192.168.0.0/16", true}, - {[]string{"*"}, false, "http://example.com", true}, - {[]string{"*"}, false, "example.com:80", true}, - {[]string{"*"}, true, "", false}, - {[]string{"*"}, true, "192.168.1.1", false}, - {[]string{"*"}, true, "192.168.0.0/16", false}, - {[]string{"*"}, true, "http://example.com", false}, - {[]string{"*"}, true, "example.com:80", false}, - - // sub-domain - {[]string{"*.example.com"}, false, "example.com", false}, - {[]string{"*.example.com"}, false, "http://example.com", false}, - {[]string{"*.example.com"}, false, "www.example.com", true}, - {[]string{"*.example.com"}, false, "http://www.example.com", true}, - {[]string{"*.example.com"}, false, "abc.def.example.com", true}, - - {[]string{"*.*.example.com"}, false, "example.com", false}, - {[]string{"*.*.example.com"}, false, "www.example.com", false}, - {[]string{"*.*.example.com"}, false, "abc.def.example.com", true}, - {[]string{"*.*.example.com"}, false, "abc.def.ghi.example.com", true}, - - {[]string{"**.example.com"}, false, "example.com", false}, - {[]string{"**.example.com"}, false, "www.example.com", true}, - {[]string{"**.example.com"}, false, "abc.def.ghi.example.com", true}, - - // prefix wildcard - {[]string{"*example.com"}, false, "example.com", true}, - {[]string{"*example.com"}, false, "www.example.com", true}, - {[]string{"*example.com"}, false, "abc.defexample.com", true}, - {[]string{"*example.com"}, false, "abc.def-example.com", true}, - {[]string{"*example.com"}, false, "abc.def.example.com", true}, - {[]string{"*example.com"}, false, "http://www.example.com", true}, - {[]string{"*example.com"}, false, "e-xample.com", false}, - - {[]string{"http://*.example.com"}, false, "example.com", false}, - {[]string{"http://*.example.com"}, false, "http://example.com", false}, - {[]string{"http://*.example.com"}, false, "http://www.example.com", true}, - {[]string{"http://*.example.com"}, false, "https://www.example.com", false}, - {[]string{"http://*.example.com"}, false, "http://abc.def.example.com", true}, - - {[]string{"www.*.com"}, false, "www.example.com", true}, - {[]string{"www.*.com"}, false, "www.abc.def.com", true}, - - {[]string{"www.*.*.com"}, false, "www.example.com", false}, - {[]string{"www.*.*.com"}, false, "www.abc.def.com", true}, - {[]string{"www.*.*.com"}, false, "www.abc.def.ghi.com", true}, - - {[]string{"www.*example*.com"}, false, "www.example.com", true}, - {[]string{"www.*example*.com"}, false, "www.abc.example.def.com", true}, - {[]string{"www.*example*.com"}, false, "www.e-xample.com", false}, - - {[]string{"www.example.*"}, false, "www.example.com", true}, - {[]string{"www.example.*"}, false, "www.example.io", true}, - {[]string{"www.example.*"}, false, "www.example.com.cn", true}, - - {[]string{".example.com"}, false, "www.example.com", true}, - {[]string{".example.com"}, false, "example.com", true}, - {[]string{".example.com"}, false, "www.example.com.cn", false}, - - {[]string{"example.com*"}, false, "example.com", true}, - {[]string{"example.com:*"}, false, "example.com", false}, - {[]string{"example.com:*"}, false, "example.com:80", false}, - {[]string{"example.com:*"}, false, "example.com:8080", false}, - {[]string{"example.com:*"}, false, "example.com:http", true}, - {[]string{"example.com:*"}, false, "http://example.com:80", false}, - - {[]string{"*example.com*"}, false, "example.com:80", true}, - {[]string{"*example.com:*"}, false, "example.com:80", false}, - - {[]string{".example.com:*"}, false, "www.example.com", false}, - {[]string{".example.com:*"}, false, "http://www.example.com", false}, - {[]string{".example.com:*"}, false, "example.com:80", false}, - {[]string{".example.com:*"}, false, "www.example.com:8080", false}, - {[]string{".example.com:*"}, false, "http://www.example.com:80", true}, -} - -func TestBypassContains(t *testing.T) { - for i, tc := range bypassContainTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - bp := NewBypassPatterns(tc.reversed, tc.patterns...) - if bp.Contains(tc.addr) != tc.bypassed { - t.Errorf("#%d test failed: %v, %s", i, tc.patterns, tc.addr) - } - }) - } -} - -var bypassReloadTests = []struct { - r io.Reader - - reversed bool - period time.Duration - - addr string - bypassed bool - stopped bool -}{ - { - r: nil, - reversed: false, - period: 0, - addr: "192.168.1.1", - bypassed: false, - stopped: false, - }, - { - r: bytes.NewBufferString(""), - reversed: false, - period: 0, - addr: "192.168.1.1", - bypassed: false, - stopped: false, - }, - { - r: bytes.NewBufferString("reverse true\nreload 10s"), - reversed: true, - period: 10 * time.Second, - addr: "192.168.1.1", - bypassed: false, - stopped: false, - }, - { - r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1"), - reversed: false, - period: 10 * time.Second, - addr: "192.168.1.1", - bypassed: true, - stopped: false, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.0.0/16"), - reversed: false, - period: 0, - addr: "192.168.10.2", - bypassed: true, - stopped: true, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.0/24 #comment"), - reversed: false, - period: 0, - addr: "192.168.10.2", - bypassed: false, - stopped: true, - }, - { - r: bytes.NewBufferString("reverse false\nreload 10s\n192.168.1.1\n#example.com"), - reversed: false, - period: 10 * time.Second, - addr: "example.com", - bypassed: false, - stopped: false, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n192.168.1.1\n#example.com"), - reversed: false, - period: 0, - addr: "192.168.1.1", - bypassed: true, - stopped: true, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\nexample.com #comment"), - reversed: false, - period: 0, - addr: "example.com", - bypassed: true, - stopped: true, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n.example.com"), - reversed: false, - period: 0, - addr: "example.com", - bypassed: true, - stopped: true, - }, - { - r: bytes.NewBufferString("#reverse true\n#reload 10s\n*.example.com"), - reversed: false, - period: 0, - addr: "example.com", - bypassed: false, - stopped: true, - }, -} - -func TestByapssReload(t *testing.T) { - for i, tc := range bypassReloadTests { - bp := NewBypass(false) - if err := bp.Reload(tc.r); err != nil { - t.Error(err) - } - t.Log(bp.String()) - - if bp.Reversed() != tc.reversed { - t.Errorf("#%d test failed: reversed value should be %v, got %v", - i, tc.reversed, bp.reversed) - } - if bp.Period() != tc.period { - t.Errorf("#%d test failed: period value should be %v, got %v", - i, tc.period, bp.Period()) - } - if bp.Contains(tc.addr) != tc.bypassed { - t.Errorf("#%d test failed: %v, %s", i, bp.reversed, tc.addr) - } - if tc.stopped { - bp.Stop() - if bp.Period() >= 0 { - t.Errorf("period of the stopped reloader should be minus value") - } - bp.Stop() - } - if bp.Stopped() != tc.stopped { - t.Errorf("#%d test failed: stopped value should be %v, got %v", - i, tc.stopped, bp.Stopped()) - } - } -} diff --git a/gost/chain.go b/gost/chain.go index bedeb6f2..cd55e1d9 100644 --- a/gost/chain.go +++ b/gost/chain.go @@ -3,10 +3,9 @@ package gost import ( "context" "errors" + "fmt" "net" "time" - - "github.com/go-log/log" ) var ( @@ -142,7 +141,7 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op ipAddr := address if address != "" { - ipAddr = c.resolve(address, options.Resolver, options.Hosts) + ipAddr = c.resolve(address) } timeout := options.Timeout @@ -179,22 +178,10 @@ func (c *Chain) dialWithOptions(ctx context.Context, network, address string, op return cc, nil } -func (*Chain) resolve(addr string, resolver Resolver, hosts *Hosts) string { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return addr - } - - if ip := hosts.Lookup(host); ip != nil { - return net.JoinHostPort(ip.String(), port) - } - if resolver != nil { - ips, err := resolver.Resolve(host) - if err != nil { - log.Logf("[resolver] %s: %v", host, err) - } - if len(ips) > 0 { - return net.JoinHostPort(ips[0].String(), port) +func (*Chain) resolve(addr string) string { + if host, port, err := net.SplitHostPort(addr); err == nil { + if ips, err := net.LookupIP(host); err == nil && len(ips) > 0 { + return fmt.Sprintf("%s:%s", ips[0].String(), port) } } return addr @@ -302,10 +289,6 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { return } - if node.Bypass.Contains(addr) { - break - } - if node.Client.Transporter.Multiplex() { node.DialOptions = append(node.DialOptions, ChainDialOption(route), @@ -324,10 +307,8 @@ func (c *Chain) selectRouteFor(addr string) (route *Chain, err error) { // ChainOptions holds options for Chain. type ChainOptions struct { - Retries int - Timeout time.Duration - Hosts *Hosts - Resolver Resolver + Retries int + Timeout time.Duration } // ChainOption allows a common way to set chain options. @@ -346,17 +327,3 @@ func TimeoutChainOption(timeout time.Duration) ChainOption { opts.Timeout = timeout } } - -// HostsChainOption specifies the hosts used by Chain.Dial. -func HostsChainOption(hosts *Hosts) ChainOption { - return func(opts *ChainOptions) { - opts.Hosts = hosts - } -} - -// ResolverChainOption specifies the Resolver used by Chain.Dial. -func ResolverChainOption(resolver Resolver) ChainOption { - return func(opts *ChainOptions) { - opts.Resolver = resolver - } -} diff --git a/gost/client.go b/gost/client.go index c840067f..679de9d8 100644 --- a/gost/client.go +++ b/gost/client.go @@ -3,6 +3,7 @@ package gost import ( "context" "crypto/tls" + "fmt" "net" "net/url" "time" @@ -19,24 +20,6 @@ type Client struct { Transporter } -// DefaultClient is a standard HTTP proxy client. -var DefaultClient = &Client{Connector: HTTPConnector(nil), Transporter: TCPTransporter()} - -// Dial connects to the address addr via the DefaultClient. -func Dial(addr string, options ...DialOption) (net.Conn, error) { - return DefaultClient.Dial(addr, options...) -} - -// Handshake performs a handshake via the DefaultClient. -func Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return DefaultClient.Handshake(conn, options...) -} - -// Connect connects to the address addr via the DefaultClient. -func Connect(conn net.Conn, addr string) (net.Conn, error) { - return DefaultClient.Connect(conn, addr) -} - // Connector is responsible for connecting to the destination address. type Connector interface { // Deprecated: use ConnectContext instead. @@ -63,7 +46,8 @@ func (c *autoConnector) ConnectContext(ctx context.Context, conn net.Conn, netwo var cnr Connector switch network { case "tcp", "tcp4", "tcp6": - cnr = &httpConnector{User: c.User} + fmt.Println("xxxxxxxxxxxxxxxxxxxxxxx------------------------------------------") + //cnr = &httpConnector{User: c.User} default: cnr = &socks5UDPTunConnector{User: c.User} } @@ -119,10 +103,7 @@ type HandshakeOptions struct { Interval time.Duration Retry int TLSConfig *tls.Config - WSOptions *WSOptions - KCPConfig *KCPConfig QUICConfig *QUICConfig - SSHConfig *SSHConfig } // HandshakeOption allows a common way to set HandshakeOptions. @@ -177,20 +158,6 @@ func TLSConfigHandshakeOption(config *tls.Config) HandshakeOption { } } -// WSOptionsHandshakeOption specifies the websocket options used by websocket handshake -func WSOptionsHandshakeOption(options *WSOptions) HandshakeOption { - return func(opts *HandshakeOptions) { - opts.WSOptions = options - } -} - -// KCPConfigHandshakeOption specifies the KCP config used by KCP handshake -func KCPConfigHandshakeOption(config *KCPConfig) HandshakeOption { - return func(opts *HandshakeOptions) { - opts.KCPConfig = config - } -} - // QUICConfigHandshakeOption specifies the QUIC config used by QUIC handshake func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { return func(opts *HandshakeOptions) { @@ -198,13 +165,6 @@ func QUICConfigHandshakeOption(config *QUICConfig) HandshakeOption { } } -// SSHConfigHandshakeOption specifies the ssh config used by SSH client handshake. -func SSHConfigHandshakeOption(config *SSHConfig) HandshakeOption { - return func(opts *HandshakeOptions) { - opts.SSHConfig = config - } -} - // ConnectOptions describes the options for Connector.Connect. type ConnectOptions struct { Addr string diff --git a/gost/dns.go b/gost/dns.go deleted file mode 100644 index 1b024040..00000000 --- a/gost/dns.go +++ /dev/null @@ -1,422 +0,0 @@ -package gost - -import ( - "bytes" - "context" - "crypto/tls" - "encoding/base64" - "errors" - "io" - "io/ioutil" - "net" - "net/http" - "strconv" - "strings" - "time" - - "github.com/go-log/log" - "github.com/miekg/dns" -) - -var ( - defaultResolver Resolver -) - -func init() { - defaultResolver = NewResolver( - DefaultResolverTimeout, - NameServer{ - Addr: "127.0.0.1:53", - Protocol: "udp", - }) - defaultResolver.Init() -} - -type dnsHandler struct { - options *HandlerOptions -} - -// DNSHandler creates a Handler for DNS server. -func DNSHandler(raddr string, opts ...HandlerOption) Handler { - h := &dnsHandler{} - - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *dnsHandler) Init(opts ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range opts { - opt(h.options) - } -} - -func (h *dnsHandler) Handle(conn net.Conn) { - defer conn.Close() - - b := mPool.Get().([]byte) - defer mPool.Put(b) - - n, err := conn.Read(b) - if err != nil { - log.Logf("[dns] %s - %s: %v", conn.RemoteAddr(), conn.LocalAddr(), err) - } - - mq := &dns.Msg{} - if err = mq.Unpack(b[:n]); err != nil { - log.Logf("[dns] %s - %s request unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - log.Logf("[dns] %s -> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mq)) - if Debug { - log.Logf("[dns] %s >>> %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mq.String()) - } - - start := time.Now() - - resolver := h.options.Resolver - if resolver == nil { - resolver = defaultResolver - } - reply, err := resolver.Exchange(context.Background(), b[:n]) - if err != nil { - log.Logf("[dns] %s - %s exchange: %v", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - - rtt := time.Since(start) - - mr := &dns.Msg{} - if err = mr.Unpack(reply); err != nil { - log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - log.Logf("[dns] %s <- %s: %s [%s]", - conn.RemoteAddr(), conn.LocalAddr(), h.dumpMsgHeader(mr), rtt) - if Debug { - log.Logf("[dns] %s <<< %s: %s", conn.RemoteAddr(), conn.LocalAddr(), mr.String()) - } - - if _, err = conn.Write(reply); err != nil { - log.Logf("[dns] %s - %s reply unpack: %v", conn.RemoteAddr(), conn.LocalAddr(), err) - } -} - -func (h *dnsHandler) dumpMsgHeader(m *dns.Msg) string { - buf := new(bytes.Buffer) - buf.WriteString(m.MsgHdr.String() + " ") - buf.WriteString("QUERY: " + strconv.Itoa(len(m.Question)) + ", ") - buf.WriteString("ANSWER: " + strconv.Itoa(len(m.Answer)) + ", ") - buf.WriteString("AUTHORITY: " + strconv.Itoa(len(m.Ns)) + ", ") - buf.WriteString("ADDITIONAL: " + strconv.Itoa(len(m.Extra))) - return buf.String() -} - -// DNSOptions is options for DNS Listener. -type DNSOptions struct { - Mode string - UDPSize int - ReadTimeout time.Duration - WriteTimeout time.Duration - TLSConfig *tls.Config -} - -type dnsListener struct { - addr net.Addr - server dnsServer - connChan chan net.Conn - errc chan error -} - -// DNSListener creates a Listener for DNS proxy server. -func DNSListener(addr string, options *DNSOptions) (Listener, error) { - if options == nil { - options = &DNSOptions{} - } - - tlsConfig := options.TLSConfig - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - - ln := &dnsListener{ - connChan: make(chan net.Conn, 128), - errc: make(chan error, 1), - } - - var srv dnsServer - var err error - switch strings.ToLower(options.Mode) { - case "tcp": - srv = &dns.Server{ - Net: "tcp", - Addr: addr, - Handler: ln, - ReadTimeout: options.ReadTimeout, - WriteTimeout: options.WriteTimeout, - } - case "tls": - srv = &dns.Server{ - Net: "tcp-tls", - Addr: addr, - Handler: ln, - TLSConfig: tlsConfig, - ReadTimeout: options.ReadTimeout, - WriteTimeout: options.WriteTimeout, - } - case "https": - srv = &dohServer{ - addr: addr, - tlsConfig: tlsConfig, - server: &http.Server{ - Handler: ln, - ReadTimeout: options.ReadTimeout, - WriteTimeout: options.WriteTimeout, - }, - } - - default: - ln.addr, err = net.ResolveTCPAddr("tcp", addr) - srv = &dns.Server{ - Net: "udp", - Addr: addr, - Handler: ln, - UDPSize: options.UDPSize, - ReadTimeout: options.ReadTimeout, - WriteTimeout: options.WriteTimeout, - } - } - if err != nil { - return nil, err - } - - if ln.addr == nil { - ln.addr, err = net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - } - - ln.server = srv - - go func() { - if err := ln.server.ListenAndServe(); err != nil { - ln.errc <- err - return - } - }() - - select { - case err := <-ln.errc: - return nil, err - default: - } - - return ln, nil -} - -func (l *dnsListener) serve(w dnsResponseWriter, mq []byte) (err error) { - conn := newDNSServerConn(l.addr, w.RemoteAddr()) - conn.mq <- mq - - select { - case l.connChan <- conn: - default: - return errors.New("connection queue is full") - } - - select { - case mr := <-conn.mr: - _, err = w.Write(mr) - case <-conn.cclose: - err = io.EOF - } - return -} - -func (l *dnsListener) ServeDNS(w dns.ResponseWriter, m *dns.Msg) { - b, err := m.Pack() - if err != nil { - log.Logf("[dns] %s: %v", l.addr, err) - return - } - if err := l.serve(w, b); err != nil { - log.Logf("[dns] %s: %v", l.addr, err) - } -} - -// Based on https://github.com/semihalev/sdns -func (l *dnsListener) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var buf []byte - var err error - switch r.Method { - case http.MethodGet: - buf, err = base64.RawURLEncoding.DecodeString(r.URL.Query().Get("dns")) - if len(buf) == 0 || err != nil { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - case http.MethodPost: - if r.Header.Get("Content-Type") != "application/dns-message" { - http.Error(w, http.StatusText(http.StatusUnsupportedMediaType), http.StatusUnsupportedMediaType) - return - } - - buf, err = ioutil.ReadAll(r.Body) - if err != nil { - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return - } - default: - http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) - return - } - - mq := &dns.Msg{} - if err := mq.Unpack(buf); err != nil { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) - return - } - - w.Header().Set("Server", "SDNS") - w.Header().Set("Content-Type", "application/dns-message") - - raddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - if err := l.serve(newDoHResponseWriter(raddr, w), buf); err != nil { - log.Logf("[dns] %s: %v", l.addr, err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - } -} - -func (l *dnsListener) Accept() (conn net.Conn, err error) { - select { - case conn = <-l.connChan: - case err = <-l.errc: - } - return -} - -func (l *dnsListener) Close() error { - return l.server.Shutdown() -} - -func (l *dnsListener) Addr() net.Addr { - return l.addr -} - -type dnsServer interface { - ListenAndServe() error - Shutdown() error -} - -type dohServer struct { - addr string - tlsConfig *tls.Config - server *http.Server -} - -func (s *dohServer) ListenAndServe() error { - ln, err := net.Listen("tcp", s.addr) - if err != nil { - return err - } - ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, s.tlsConfig) - return s.server.Serve(ln) -} - -func (s *dohServer) Shutdown() error { - return s.server.Shutdown(context.Background()) -} - -type dnsServerConn struct { - mq chan []byte - mr chan []byte - cclose chan struct{} - laddr, raddr net.Addr -} - -func newDNSServerConn(laddr, raddr net.Addr) *dnsServerConn { - return &dnsServerConn{ - mq: make(chan []byte, 1), - mr: make(chan []byte, 1), - laddr: laddr, - raddr: raddr, - cclose: make(chan struct{}), - } -} - -func (c *dnsServerConn) Read(b []byte) (n int, err error) { - select { - case mb := <-c.mq: - n = copy(b, mb) - case <-c.cclose: - err = errors.New("connection is closed") - } - return -} - -func (c *dnsServerConn) Write(b []byte) (n int, err error) { - select { - case c.mr <- b: - n = len(b) - case <-c.cclose: - err = errors.New("broken pipe") - } - - return -} - -func (c *dnsServerConn) Close() error { - select { - case <-c.cclose: - default: - close(c.cclose) - } - return nil -} - -func (c *dnsServerConn) LocalAddr() net.Addr { - return c.laddr -} - -func (c *dnsServerConn) RemoteAddr() net.Addr { - return c.raddr -} - -func (c *dnsServerConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *dnsServerConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *dnsServerConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "dns", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -type dnsResponseWriter interface { - io.Writer - RemoteAddr() net.Addr -} - -type dohResponseWriter struct { - raddr net.Addr - http.ResponseWriter -} - -func newDoHResponseWriter(raddr net.Addr, w http.ResponseWriter) dnsResponseWriter { - return &dohResponseWriter{ - raddr: raddr, - ResponseWriter: w, - } -} - -func (w *dohResponseWriter) RemoteAddr() net.Addr { - return w.raddr -} diff --git a/gost/ftcp.go b/gost/ftcp.go deleted file mode 100644 index a1cfcf0b..00000000 --- a/gost/ftcp.go +++ /dev/null @@ -1,175 +0,0 @@ -package gost - -import ( - "errors" - "net" - "time" - - "github.com/go-log/log" - "github.com/xtaci/tcpraw" -) - -type fakeTCPTransporter struct{} - -// FakeTCPTransporter creates a Transporter that is used by fake tcp client. -func FakeTCPTransporter() Transporter { - return &fakeTCPTransporter{} -} - -func (tr *fakeTCPTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - raddr, er := net.ResolveTCPAddr("tcp", addr) - if er != nil { - return nil, er - } - c, err := tcpraw.Dial("tcp", addr) - if err != nil { - return - } - conn = &fakeTCPConn{ - raddr: raddr, - PacketConn: c, - } - return conn, nil -} - -func (tr *fakeTCPTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - -func (tr *fakeTCPTransporter) Multiplex() bool { - return false -} - -// FakeTCPListenConfig is config for fake TCP Listener. -type FakeTCPListenConfig struct { - TTL time.Duration - Backlog int - QueueSize int -} - -type fakeTCPListener struct { - ln net.PacketConn - connChan chan net.Conn - errChan chan error - connMap udpConnMap - config *FakeTCPListenConfig -} - -// FakeTCPListener creates a Listener for fake TCP server. -func FakeTCPListener(addr string, cfg *FakeTCPListenConfig) (Listener, error) { - ln, err := tcpraw.Listen("tcp", addr) - if err != nil { - return nil, err - } - - if cfg == nil { - cfg = &FakeTCPListenConfig{} - } - - backlog := cfg.Backlog - if backlog <= 0 { - backlog = defaultBacklog - } - - l := &fakeTCPListener{ - ln: ln, - connChan: make(chan net.Conn, backlog), - errChan: make(chan error, 1), - config: cfg, - } - go l.listenLoop() - return l, nil -} - -func (l *fakeTCPListener) listenLoop() { - for { - b := make([]byte, mediumBufferSize) - n, raddr, err := l.ln.ReadFrom(b) - if err != nil { - log.Logf("[ftcp] peer -> %s : %s", l.Addr(), err) - l.Close() - l.errChan <- err - close(l.errChan) - return - } - - conn, ok := l.connMap.Get(raddr.String()) - if !ok { - conn = newUDPServerConn(l.ln, raddr, &udpServerConnConfig{ - ttl: l.config.TTL, - qsize: l.config.QueueSize, - onClose: func() { - l.connMap.Delete(raddr.String()) - log.Logf("[ftcp] %s closed (%d)", raddr, l.connMap.Size()) - }, - }) - - select { - case l.connChan <- conn: - l.connMap.Set(raddr.String(), conn) - log.Logf("[ftcp] %s -> %s (%d)", raddr, l.Addr(), l.connMap.Size()) - default: - conn.Close() - log.Logf("[ftcp] %s - %s: connection queue is full (%d)", raddr, l.Addr(), cap(l.connChan)) - } - } - - select { - case conn.rChan <- b[:n]: - if Debug { - log.Logf("[ftcp] %s >>> %s : length %d", raddr, l.Addr(), n) - } - default: - log.Logf("[ftcp] %s -> %s : recv queue is full (%d)", raddr, l.Addr(), cap(conn.rChan)) - } - } -} - -func (l *fakeTCPListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *fakeTCPListener) Addr() net.Addr { - return l.ln.LocalAddr() -} - -func (l *fakeTCPListener) Close() error { - err := l.ln.Close() - l.connMap.Range(func(k interface{}, v *udpServerConn) bool { - v.Close() - return true - }) - - return err -} - -type fakeTCPConn struct { - raddr net.Addr - net.PacketConn -} - -func (c *fakeTCPConn) Read(b []byte) (n int, err error) { - n, _, err = c.ReadFrom(b) - return -} - -func (c *fakeTCPConn) Write(b []byte) (n int, err error) { - return c.WriteTo(b, c.raddr) -} - -func (c *fakeTCPConn) RemoteAddr() net.Addr { - return c.raddr -} diff --git a/gost/handler.go b/gost/handler.go index db531472..c22d523c 100644 --- a/gost/handler.go +++ b/gost/handler.go @@ -7,7 +7,6 @@ import ( "net/url" "time" - "github.com/ginuerzh/gosocks4" "github.com/ginuerzh/gosocks5" "github.com/go-log/log" ) @@ -25,16 +24,11 @@ type HandlerOptions struct { Users []*url.Userinfo Authenticator Authenticator TLSConfig *tls.Config - Whitelist *Permissions - Blacklist *Permissions Strategy Strategy MaxFails int FailTimeout time.Duration - Bypass *Bypass Retries int Timeout time.Duration - Resolver Resolver - Hosts *Hosts ProbeResist string KnockingHost string Node Node @@ -92,27 +86,6 @@ func TLSConfigHandlerOption(config *tls.Config) HandlerOption { } } -// WhitelistHandlerOption sets the Whitelist option of HandlerOptions. -func WhitelistHandlerOption(whitelist *Permissions) HandlerOption { - return func(opts *HandlerOptions) { - opts.Whitelist = whitelist - } -} - -// BlacklistHandlerOption sets the Blacklist option of HandlerOptions. -func BlacklistHandlerOption(blacklist *Permissions) HandlerOption { - return func(opts *HandlerOptions) { - opts.Blacklist = blacklist - } -} - -// BypassHandlerOption sets the bypass option of HandlerOptions. -func BypassHandlerOption(bypass *Bypass) HandlerOption { - return func(opts *HandlerOptions) { - opts.Bypass = bypass - } -} - // StrategyHandlerOption sets the strategy option of HandlerOptions. func StrategyHandlerOption(strategy Strategy) HandlerOption { return func(opts *HandlerOptions) { @@ -148,27 +121,6 @@ func TimeoutHandlerOption(timeout time.Duration) HandlerOption { } } -// ResolverHandlerOption sets the resolver option of HandlerOptions. -func ResolverHandlerOption(resolver Resolver) HandlerOption { - return func(opts *HandlerOptions) { - opts.Resolver = resolver - } -} - -// HostsHandlerOption sets the Hosts option of HandlerOptions. -func HostsHandlerOption(hosts *Hosts) HandlerOption { - return func(opts *HandlerOptions) { - opts.Hosts = hosts - } -} - -// ProbeResistHandlerOption adds the probe resistance for HTTP proxy. -func ProbeResistHandlerOption(pr string) HandlerOption { - return func(opts *HandlerOptions) { - opts.ProbeResist = pr - } -} - // KnockingHandlerOption adds the knocking host for probe resistance. func KnockingHandlerOption(host string) HandlerOption { return func(opts *HandlerOptions) { @@ -243,18 +195,8 @@ func (h *autoHandler) Handle(conn net.Conn) { cc := &bufferdConn{Conn: conn, br: br} var handler Handler switch b[0] { - case gosocks4.Ver4: - // SOCKS4(a) does not suppport authentication method, - // so we ignore it when credentials are specified for security reason. - if len(h.options.Users) > 0 { - cc.Close() - return - } - handler = &socks4Handler{options: h.options} case gosocks5.Ver5: // socks5 handler = &socks5Handler{options: h.options} - default: // http - handler = &httpHandler{options: h.options} } handler.Init() handler.Handle(cc) diff --git a/gost/hosts.go b/gost/hosts.go deleted file mode 100644 index 6df0325e..00000000 --- a/gost/hosts.go +++ /dev/null @@ -1,160 +0,0 @@ -package gost - -import ( - "bufio" - "io" - "net" - "sync" - "time" - - "github.com/go-log/log" -) - -// Host is a static mapping from hostname to IP. -type Host struct { - IP net.IP - Hostname string - Aliases []string -} - -// NewHost creates a Host. -func NewHost(ip net.IP, hostname string, aliases ...string) Host { - return Host{ - IP: ip, - Hostname: hostname, - Aliases: aliases, - } -} - -// Hosts is a static table lookup for hostnames. -// For each host a single line should be present with the following information: -// IP_address canonical_hostname [aliases...] -// Fields of the entry are separated by any number of blanks and/or tab characters. -// Text from a "#" character until the end of the line is a comment, and is ignored. -type Hosts struct { - hosts []Host - period time.Duration - stopped chan struct{} - mux sync.RWMutex -} - -// NewHosts creates a Hosts with optional list of hosts. -func NewHosts(hosts ...Host) *Hosts { - return &Hosts{ - hosts: hosts, - stopped: make(chan struct{}), - } -} - -// AddHost adds host(s) to the host table. -func (h *Hosts) AddHost(host ...Host) { - h.mux.Lock() - defer h.mux.Unlock() - - h.hosts = append(h.hosts, host...) -} - -// Lookup searches the IP address corresponds to the given host from the host table. -func (h *Hosts) Lookup(host string) (ip net.IP) { - if h == nil || host == "" { - return - } - - h.mux.RLock() - defer h.mux.RUnlock() - - for _, h := range h.hosts { - if h.Hostname == host { - ip = h.IP - break - } - for _, alias := range h.Aliases { - if alias == host { - ip = h.IP - break - } - } - } - if ip != nil && Debug { - log.Logf("[hosts] hit: %s %s", host, ip.String()) - } - return -} - -// Reload parses config from r, then live reloads the hosts. -func (h *Hosts) Reload(r io.Reader) error { - var period time.Duration - var hosts []Host - - if r == nil || h.Stopped() { - return nil - } - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := scanner.Text() - ss := splitLine(line) - if len(ss) < 2 { - continue // invalid lines are ignored - } - - switch ss[0] { - case "reload": // reload option - period, _ = time.ParseDuration(ss[1]) - default: - ip := net.ParseIP(ss[0]) - if ip == nil { - break // invalid IP addresses are ignored - } - host := Host{ - IP: ip, - Hostname: ss[1], - } - if len(ss) > 2 { - host.Aliases = ss[2:] - } - hosts = append(hosts, host) - } - } - if err := scanner.Err(); err != nil { - return err - } - - h.mux.Lock() - h.period = period - h.hosts = hosts - h.mux.Unlock() - - return nil -} - -// Period returns the reload period -func (h *Hosts) Period() time.Duration { - if h.Stopped() { - return -1 - } - - h.mux.RLock() - defer h.mux.RUnlock() - - return h.period -} - -// Stop stops reloading. -func (h *Hosts) Stop() { - select { - case <-h.stopped: - default: - close(h.stopped) - } -} - -// Stopped checks whether the reloader is stopped. -func (h *Hosts) Stopped() bool { - select { - case <-h.stopped: - return true - default: - return false - } -} diff --git a/gost/hosts_test.go b/gost/hosts_test.go deleted file mode 100644 index 2fbae328..00000000 --- a/gost/hosts_test.go +++ /dev/null @@ -1,130 +0,0 @@ -package gost - -import ( - "bytes" - "io" - "net" - "testing" - "time" -) - -var hostsLookupTests = []struct { - hosts []Host - host string - ip net.IP -}{ - {nil, "", nil}, - {nil, "example.com", nil}, - {[]Host{}, "", nil}, - {[]Host{}, "example.com", nil}, - {[]Host{NewHost(nil, "")}, "", nil}, - {[]Host{NewHost(nil, "example.com")}, "example.com", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "")}, "", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example.com", net.IPv4(192, 168, 1, 1)}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com")}, "example", nil}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "example", net.IPv4(192, 168, 1, 1)}, - {[]Host{NewHost(net.IPv4(192, 168, 1, 1), "example.com", "example", "examples")}, "examples", net.IPv4(192, 168, 1, 1)}, -} - -func TestHostsLookup(t *testing.T) { - for i, tc := range hostsLookupTests { - hosts := NewHosts() - hosts.AddHost(tc.hosts...) - ip := hosts.Lookup(tc.host) - if !ip.Equal(tc.ip) { - t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) - } - } -} - -var HostsReloadTests = []struct { - r io.Reader - period time.Duration - host string - ip net.IP - stopped bool -}{ - { - r: nil, - period: 0, - host: "", - ip: nil, - }, - { - r: bytes.NewBufferString(""), - period: 0, - host: "example.com", - ip: nil, - }, - { - r: bytes.NewBufferString("reload 10s"), - period: 10 * time.Second, - host: "example.com", - ip: nil, - }, - { - r: bytes.NewBufferString("#reload 10s\ninvalid.ip.addr example.com"), - period: 0, - ip: nil, - }, - { - r: bytes.NewBufferString("reload 10s\n192.168.1.1"), - period: 10 * time.Second, - host: "", - ip: nil, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com"), - period: 0, - host: "example.com", - ip: net.IPv4(192, 168, 1, 1), - }, - { - r: bytes.NewBufferString("#reload 10s\n#192.168.1.1 example.com"), - period: 0, - host: "example.com", - ip: nil, - stopped: true, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), - period: 0, - host: "example", - ip: net.IPv4(192, 168, 1, 1), - stopped: true, - }, - { - r: bytes.NewBufferString("#reload 10s\n192.168.1.1 example.com example examples"), - period: 0, - host: "examples", - ip: net.IPv4(192, 168, 1, 1), - stopped: true, - }, -} - -func TestHostsReload(t *testing.T) { - for i, tc := range HostsReloadTests { - hosts := NewHosts() - if err := hosts.Reload(tc.r); err != nil { - t.Error(err) - } - if hosts.Period() != tc.period { - t.Errorf("#%d test failed: period value should be %v, got %v", - i, tc.period, hosts.Period()) - } - ip := hosts.Lookup(tc.host) - if !ip.Equal(tc.ip) { - t.Errorf("#%d test failed: lookup should be %s, got %s", i, tc.ip, ip) - } - if tc.stopped { - hosts.Stop() - if hosts.Period() >= 0 { - t.Errorf("period of the stopped reloader should be minus value") - } - } - if hosts.Stopped() != tc.stopped { - t.Errorf("#%d test failed: stopped value should be %v, got %v", - i, tc.stopped, hosts.Stopped()) - } - } -} diff --git a/gost/http.go b/gost/http.go deleted file mode 100644 index 1b17e6fd..00000000 --- a/gost/http.go +++ /dev/null @@ -1,474 +0,0 @@ -package gost - -import ( - "bufio" - "bytes" - "context" - "encoding/base64" - "fmt" - "net" - "net/http" - "net/http/httputil" - "net/url" - "os" - "strconv" - "strings" - "time" - - "github.com/go-log/log" -) - -type httpConnector struct { - User *url.Userinfo -} - -// HTTPConnector creates a Connector for HTTP proxy client. -// It accepts an optional auth info for HTTP Basic Authentication. -func HTTPConnector(user *url.Userinfo) Connector { - return &httpConnector{User: user} -} - -func (c *httpConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return c.ConnectContext(context.Background(), conn, "tcp", address, options...) -} - -func (c *httpConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - return nil, fmt.Errorf("%s unsupported", network) - } - - opts := &ConnectOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = ConnectTimeout - } - ua := opts.UserAgent - if ua == "" { - ua = DefaultUserAgent - } - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Host: address}, - Host: address, - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - } - req.Header.Set("User-Agent", ua) - req.Header.Set("Proxy-Connection", "keep-alive") - - user := opts.User - if user == nil { - user = c.User - } - - if user != nil { - u := user.Username() - p, _ := user.Password() - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) - } - - if err := req.Write(conn); err != nil { - return nil, err - } - - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Log(string(dump)) - } - - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - return nil, err - } - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Log(string(dump)) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("%s", resp.Status) - } - - return conn, nil -} - -type httpHandler struct { - options *HandlerOptions -} - -// HTTPHandler creates a server Handler for HTTP proxy server. -func HTTPHandler(opts ...HandlerOption) Handler { - h := &httpHandler{} - h.Init(opts...) - return h -} - -func (h *httpHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - for _, opt := range options { - opt(h.options) - } -} - -func (h *httpHandler) Handle(conn net.Conn) { - defer conn.Close() - - req, err := http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - log.Logf("[http] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - defer req.Body.Close() - - h.handleRequest(conn, req) -} - -func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { - if req == nil { - return - } - - // try to get the actual host. - if v := req.Header.Get("Gost-Target"); v != "" { - if h, err := decodeServerName(v); err == nil { - req.Host = h - } - } - - host := req.Host - if _, port, _ := net.SplitHostPort(host); port == "" { - host = net.JoinHostPort(host, "80") - } - - u, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) - if u != "" { - u += "@" - } - log.Logf("[http] %s%s -> %s -> %s", - u, conn.RemoteAddr(), h.options.Node.String(), host) - - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Logf("[http] %s -> %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - req.Header.Del("Gost-Target") - - resp := &http.Response{ - ProtoMajor: 1, - ProtoMinor: 1, - Header: http.Header{}, - } - resp.Header.Add("Proxy-Agent", "gost/"+Version) - - if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[http] %s - %s : Unauthorized to tcp connect to %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - resp.StatusCode = http.StatusForbidden - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return - } - - if h.options.Bypass.Contains(host) { - resp.StatusCode = http.StatusForbidden - - log.Logf("[http] %s - %s bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return - } - - if !h.authenticate(conn, req, resp) { - return - } - - if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { - resp.StatusCode = http.StatusBadRequest - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return - } - - req.Header.Del("Proxy-Authorization") - - retries := 1 - if h.options.Chain != nil && h.options.Chain.Retries > 0 { - retries = h.options.Chain.Retries - } - if h.options.Retries > 0 { - retries = h.options.Retries - } - - var err error - var cc net.Conn - var route *Chain - for i := 0; i < retries; i++ { - route, err = h.options.Chain.selectRouteFor(host) - if err != nil { - log.Logf("[http] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - continue - } - - buf := bytes.Buffer{} - fmt.Fprintf(&buf, "%s -> %s -> ", - conn.RemoteAddr(), h.options.Node.String()) - for _, nd := range route.route { - fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) - } - fmt.Fprintf(&buf, "%s", host) - log.Log("[route]", buf.String()) - - // forward http request - lastNode := route.LastNode() - if req.Method != http.MethodConnect && lastNode.Protocol == "http" { - err = h.forwardRequest(conn, req, route) - if err == nil { - return - } - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - continue - } - - cc, err = route.Dial(host, - TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), - ) - if err == nil { - break - } - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - } - - if err != nil { - resp.StatusCode = http.StatusServiceUnavailable - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return - } - defer cc.Close() - - if req.Method == http.MethodConnect { - b := []byte("HTTP/1.1 200 Connection established\r\n" + - "Proxy-Agent: gost/" + Version + "\r\n\r\n") - if Debug { - log.Logf("[http] %s <- %s\n%s", conn.RemoteAddr(), conn.LocalAddr(), string(b)) - } - conn.Write(b) - } else { - req.Header.Del("Proxy-Connection") - - if err = req.Write(cc); err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - } - - log.Logf("[http] %s <-> %s", conn.RemoteAddr(), host) - transport(conn, cc) - log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) -} - -func (h *httpHandler) authenticate(conn net.Conn, req *http.Request, resp *http.Response) (ok bool) { - u, p, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) - if Debug && (u != "" || p != "") { - log.Logf("[http] %s -> %s : Authorization '%s' '%s'", - conn.RemoteAddr(), conn.LocalAddr(), u, p) - } - if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { - return true - } - - // probing resistance is enabled, and knocking host is mismatch. - if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 && - (h.options.KnockingHost == "" || !strings.EqualFold(req.URL.Hostname(), h.options.KnockingHost)) { - resp.StatusCode = http.StatusServiceUnavailable // default status code - - switch ss[0] { - case "code": - resp.StatusCode, _ = strconv.Atoi(ss[1]) - case "web": - url := ss[1] - if !strings.HasPrefix(url, "http") { - url = "http://" + url - } - if r, err := http.Get(url); err == nil { - resp = r - } - case "host": - cc, err := net.Dial("tcp", ss[1]) - if err == nil { - defer cc.Close() - - req.Write(cc) - log.Logf("[http] %s <-> %s : forward to %s", - conn.RemoteAddr(), conn.LocalAddr(), ss[1]) - transport(conn, cc) - log.Logf("[http] %s >-< %s : forward to %s", - conn.RemoteAddr(), conn.LocalAddr(), ss[1]) - return - } - case "file": - f, _ := os.Open(ss[1]) - if f != nil { - resp.StatusCode = http.StatusOK - if finfo, _ := f.Stat(); finfo != nil { - resp.ContentLength = finfo.Size() - } - resp.Header.Set("Content-Type", "text/html") - resp.Body = f - } - } - } - - if resp.StatusCode == 0 { - log.Logf("[http] %s <- %s : proxy authentication required", - conn.RemoteAddr(), conn.LocalAddr()) - resp.StatusCode = http.StatusProxyAuthRequired - resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") - if strings.ToLower(req.Header.Get("Proxy-Connection")) == "keep-alive" { - // XXX libcurl will keep sending auth request in same conn - // which we don't supported yet. - resp.Header.Add("Connection", "close") - resp.Header.Add("Proxy-Connection", "close") - } - } else { - resp.Header = http.Header{} - resp.Header.Set("Server", "nginx/1.14.1") - resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) - if resp.StatusCode == http.StatusOK { - resp.Header.Set("Connection", "keep-alive") - } - } - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - - resp.Write(conn) - return -} - -func (h *httpHandler) forwardRequest(conn net.Conn, req *http.Request, route *Chain) error { - if route.IsEmpty() { - return nil - } - - host := req.Host - var userpass string - - if user := route.LastNode().User; user != nil { - u := user.Username() - p, _ := user.Password() - userpass = base64.StdEncoding.EncodeToString([]byte(u + ":" + p)) - } - - cc, err := route.Conn() - if err != nil { - return err - } - defer cc.Close() - - errc := make(chan error, 1) - go func() { - errc <- copyBuffer(conn, cc) - }() - - go func() { - for { - if userpass != "" { - req.Header.Set("Proxy-Authorization", "Basic "+userpass) - } - - cc.SetWriteDeadline(time.Now().Add(WriteTimeout)) - if !req.URL.IsAbs() { - req.URL.Scheme = "http" // make sure that the URL is absolute - } - err := req.WriteProxy(cc) - if err != nil { - log.Logf("[http] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - errc <- err - return - } - cc.SetWriteDeadline(time.Time{}) - - req, err = http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - errc <- err - return - } - - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Logf("[http] %s -> %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), string(dump)) - } - } - }() - - log.Logf("[http] %s <-> %s", conn.RemoteAddr(), host) - <-errc - log.Logf("[http] %s >-< %s", conn.RemoteAddr(), host) - - return nil -} - -func basicProxyAuth(proxyAuth string) (username, password string, ok bool) { - if proxyAuth == "" { - return - } - - if !strings.HasPrefix(proxyAuth, "Basic ") { - return - } - c, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(proxyAuth, "Basic ")) - if err != nil { - return - } - cs := string(c) - s := strings.IndexByte(cs, ':') - if s < 0 { - return - } - - return cs[:s], cs[s+1:], true -} diff --git a/gost/http2.go b/gost/http2.go deleted file mode 100644 index 1c3cf5a9..00000000 --- a/gost/http2.go +++ /dev/null @@ -1,965 +0,0 @@ -package gost - -import ( - "bufio" - "bytes" - "context" - "crypto/tls" - "encoding/base64" - "errors" - "fmt" - "io" - "io/ioutil" - "net" - "net/http" - "net/http/httputil" - "net/url" - "os" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-log/log" - "golang.org/x/net/http2" -) - -type http2Connector struct { - User *url.Userinfo -} - -// HTTP2Connector creates a Connector for HTTP2 proxy client. -// It accepts an optional auth info for HTTP Basic Authentication. -func HTTP2Connector(user *url.Userinfo) Connector { - return &http2Connector{User: user} -} - -func (c *http2Connector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return c.ConnectContext(context.Background(), conn, "tcp", address, options...) -} - -func (c *http2Connector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - return nil, fmt.Errorf("%s unsupported", network) - } - - opts := &ConnectOptions{} - for _, option := range options { - option(opts) - } - ua := opts.UserAgent - if ua == "" { - ua = DefaultUserAgent - } - - cc, ok := conn.(*http2ClientConn) - if !ok { - return nil, errors.New("wrong connection type") - } - - pr, pw := io.Pipe() - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: cc.addr}, - Header: make(http.Header), - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Body: pr, - Host: address, - ContentLength: -1, - } - req.Header.Set("User-Agent", ua) - - user := opts.User - if user == nil { - user = c.User - } - - if user != nil { - u := user.Username() - p, _ := user.Password() - req.Header.Set("Proxy-Authorization", - "Basic "+base64.StdEncoding.EncodeToString([]byte(u+":"+p))) - } - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Log("[http2]", string(dump)) - } - resp, err := cc.client.Do(req) - if err != nil { - cc.Close() - return nil, err - } - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Log("[http2]", string(dump)) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, errors.New(resp.Status) - } - hc := &http2Conn{ - r: resp.Body, - w: pw, - closed: make(chan struct{}), - } - - hc.remoteAddr, _ = net.ResolveTCPAddr("tcp", address) - hc.localAddr, _ = net.ResolveTCPAddr("tcp", cc.addr) - - return hc, nil -} - -type http2Transporter struct { - clients map[string]*http.Client - clientMutex sync.Mutex - tlsConfig *tls.Config -} - -// HTTP2Transporter creates a Transporter that is used by HTTP2 h2 proxy client. -func HTTP2Transporter(config *tls.Config) Transporter { - if config == nil { - config = &tls.Config{InsecureSkipVerify: true} - } - return &http2Transporter{ - clients: make(map[string]*http.Client), - tlsConfig: config, - } -} - -func (tr *http2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.clientMutex.Lock() - defer tr.clientMutex.Unlock() - - client, ok := tr.clients[addr] - if !ok { - // NOTE: There is no real connection to the HTTP2 server at this moment. - // So we try to connect to the server to check the server health. - conn, err := opts.Chain.Dial(addr) - if err != nil { - log.Log("http2 dial:", addr, err) - return nil, err - } - conn.Close() - - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - transport := http2.Transport{ - TLSClientConfig: tr.tlsConfig, - DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(adr) - if err != nil { - return nil, err - } - return wrapTLSClient(conn, cfg, timeout) - }, - } - client = &http.Client{ - Transport: &transport, - // Timeout: timeout, - } - tr.clients[addr] = client - } - - return &http2ClientConn{ - addr: addr, - client: client, - onClose: func() { - tr.clientMutex.Lock() - defer tr.clientMutex.Unlock() - delete(tr.clients, addr) - }, - }, nil -} - -func (tr *http2Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - return conn, nil -} - -func (tr *http2Transporter) Multiplex() bool { - return true -} - -// TODO: clean closed clients -type h2Transporter struct { - clients map[string]*http.Client - clientMutex sync.Mutex - tlsConfig *tls.Config - path string -} - -// H2Transporter creates a Transporter that is used by HTTP2 h2 tunnel client. -func H2Transporter(config *tls.Config, path string) Transporter { - if config == nil { - config = &tls.Config{InsecureSkipVerify: true} - } - return &h2Transporter{ - clients: make(map[string]*http.Client), - tlsConfig: config, - path: path, - } -} - -// H2CTransporter creates a Transporter that is used by HTTP2 h2c tunnel client. -func H2CTransporter(path string) Transporter { - return &h2Transporter{ - clients: make(map[string]*http.Client), - path: path, - } -} - -func (tr *h2Transporter) Dial(addr string, options ...DialOption) (net.Conn, error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.clientMutex.Lock() - client, ok := tr.clients[addr] - if !ok { - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - transport := http2.Transport{ - TLSClientConfig: tr.tlsConfig, - DialTLS: func(network, adr string, cfg *tls.Config) (net.Conn, error) { - conn, err := opts.Chain.Dial(addr) - if err != nil { - return nil, err - } - if tr.tlsConfig == nil { - return conn, nil - } - return wrapTLSClient(conn, cfg, timeout) - }, - } - client = &http.Client{ - Transport: &transport, - // Timeout: timeout, - } - tr.clients[addr] = client - } - tr.clientMutex.Unlock() - - pr, pw := io.Pipe() - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: opts.Host}, - Header: make(http.Header), - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Body: pr, - Host: opts.Host, - ContentLength: -1, - } - if tr.path != "" { - req.Method = http.MethodGet - req.URL.Path = tr.path - } - - if Debug { - dump, _ := httputil.DumpRequest(req, false) - log.Log("[http2]", string(dump)) - } - resp, err := client.Do(req) - if err != nil { - return nil, err - } - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Log("[http2]", string(dump)) - } - - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - return nil, errors.New(resp.Status) - } - conn := &http2Conn{ - r: resp.Body, - w: pw, - closed: make(chan struct{}), - } - conn.remoteAddr, _ = net.ResolveTCPAddr("tcp", addr) - conn.localAddr = &net.TCPAddr{IP: net.IPv4zero, Port: 0} - return conn, nil -} - -func (tr *h2Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - return conn, nil -} - -func (tr *h2Transporter) Multiplex() bool { - return true -} - -type http2Handler struct { - options *HandlerOptions -} - -// HTTP2Handler creates a server Handler for HTTP2 proxy server. -func HTTP2Handler(opts ...HandlerOption) Handler { - h := &http2Handler{} - h.Init(opts...) - - return h -} - -func (h *http2Handler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - for _, opt := range options { - opt(h.options) - } -} - -func (h *http2Handler) Handle(conn net.Conn) { - defer conn.Close() - - h2c, ok := conn.(*http2ServerConn) - if !ok { - log.Log("[http2] wrong connection type") - return - } - - h.roundTrip(h2c.w, h2c.r) -} - -func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { - host := r.Header.Get("Gost-Target") - if host == "" { - host = r.Host - } - - if _, port, _ := net.SplitHostPort(host); port == "" { - host = net.JoinHostPort(host, "80") - } - - laddr := h.options.Addr - u, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) - if u != "" { - u += "@" - } - log.Logf("[http2] %s%s -> %s -> %s", - u, r.RemoteAddr, h.options.Node.String(), host) - - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Logf("[http2] %s - %s\n%s", r.RemoteAddr, laddr, string(dump)) - } - - w.Header().Set("Proxy-Agent", "gost/"+Version) - - if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[http2] %s - %s : Unauthorized to tcp connect to %s", - r.RemoteAddr, laddr, host) - w.WriteHeader(http.StatusForbidden) - return - } - - if h.options.Bypass.Contains(host) { - log.Logf("[http2] %s - %s bypass %s", - r.RemoteAddr, laddr, host) - w.WriteHeader(http.StatusForbidden) - return - } - - resp := &http.Response{ - ProtoMajor: 2, - ProtoMinor: 0, - Header: http.Header{}, - Body: ioutil.NopCloser(bytes.NewReader([]byte{})), - } - - if !h.authenticate(w, r, resp) { - return - } - - // delete the proxy related headers. - r.Header.Del("Proxy-Authorization") - r.Header.Del("Proxy-Connection") - - retries := 1 - if h.options.Chain != nil && h.options.Chain.Retries > 0 { - retries = h.options.Chain.Retries - } - if h.options.Retries > 0 { - retries = h.options.Retries - } - - var err error - var cc net.Conn - var route *Chain - for i := 0; i < retries; i++ { - route, err = h.options.Chain.selectRouteFor(host) - if err != nil { - log.Logf("[http2] %s -> %s : %s", - r.RemoteAddr, laddr, err) - continue - } - - buf := bytes.Buffer{} - fmt.Fprintf(&buf, "%s -> %s -> ", - r.RemoteAddr, h.options.Node.String()) - for _, nd := range route.route { - fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) - } - fmt.Fprintf(&buf, "%s", host) - log.Log("[route]", buf.String()) - - cc, err = route.Dial(host, - TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), - ) - if err == nil { - break - } - log.Logf("[http2] %s -> %s : %s", r.RemoteAddr, laddr, err) - } - - if err != nil { - w.WriteHeader(http.StatusServiceUnavailable) - return - } - defer cc.Close() - - if r.Method == http.MethodConnect { - w.WriteHeader(http.StatusOK) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() - } - - // compatible with HTTP1.x - if hj, ok := w.(http.Hijacker); ok && r.ProtoMajor == 1 { - // we take over the underly connection - conn, _, err := hj.Hijack() - if err != nil { - log.Logf("[http2] %s -> %s : %s", - r.RemoteAddr, laddr, err) - w.WriteHeader(http.StatusInternalServerError) - return - } - defer conn.Close() - - log.Logf("[http2] %s <-> %s : downgrade to HTTP/1.1", r.RemoteAddr, host) - transport(conn, cc) - log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) - return - } - - log.Logf("[http2] %s <-> %s", r.RemoteAddr, host) - transport(&readWriter{r: r.Body, w: flushWriter{w}}, cc) - log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) - return - } - - log.Logf("[http2] %s <-> %s", r.RemoteAddr, host) - if err := h.forwardRequest(w, r, cc); err != nil { - log.Logf("[http2] %s - %s : %s", r.RemoteAddr, host, err) - } - log.Logf("[http2] %s >-< %s", r.RemoteAddr, host) -} - -func (h *http2Handler) authenticate(w http.ResponseWriter, r *http.Request, resp *http.Response) (ok bool) { - laddr := h.options.Addr - u, p, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) - if Debug && (u != "" || p != "") { - log.Logf("[http2] %s - %s : Authorization '%s' '%s'", r.RemoteAddr, laddr, u, p) - } - if h.options.Authenticator == nil || h.options.Authenticator.Authenticate(u, p) { - return true - } - - // probing resistance is enabled, and knocking host is mismatch. - if ss := strings.SplitN(h.options.ProbeResist, ":", 2); len(ss) == 2 && - (h.options.KnockingHost == "" || !strings.EqualFold(r.URL.Hostname(), h.options.KnockingHost)) { - resp.StatusCode = http.StatusServiceUnavailable // default status code - w.Header().Del("Proxy-Agent") - - switch ss[0] { - case "code": - resp.StatusCode, _ = strconv.Atoi(ss[1]) - case "web": - url := ss[1] - if !strings.HasPrefix(url, "http") { - url = "http://" + url - } - if r, err := http.Get(url); err == nil { - resp = r - } - case "host": - cc, err := net.Dial("tcp", ss[1]) - if err == nil { - defer cc.Close() - log.Logf("[http2] %s <-> %s : forward to %s", r.RemoteAddr, laddr, ss[1]) - if err := h.forwardRequest(w, r, cc); err != nil { - log.Logf("[http2] %s - %s : %s", r.RemoteAddr, laddr, err) - } - log.Logf("[http2] %s >-< %s : forward to %s", r.RemoteAddr, laddr, ss[1]) - return - } - case "file": - f, _ := os.Open(ss[1]) - if f != nil { - resp.StatusCode = http.StatusOK - if finfo, _ := f.Stat(); finfo != nil { - resp.ContentLength = finfo.Size() - } - resp.Body = f - } - } - } - - if resp.StatusCode == 0 { - log.Logf("[http2] %s <- %s : proxy authentication required", r.RemoteAddr, laddr) - resp.StatusCode = http.StatusProxyAuthRequired - resp.Header.Add("Proxy-Authenticate", "Basic realm=\"gost\"") - } else { - resp.Header = http.Header{} - resp.Header.Set("Server", "nginx/1.14.1") - resp.Header.Set("Date", time.Now().Format(http.TimeFormat)) - if resp.ContentLength > 0 { - resp.Header.Set("Content-Type", "text/html") - } - if resp.StatusCode == http.StatusOK { - resp.Header.Set("Connection", "keep-alive") - } - } - - if Debug { - dump, _ := httputil.DumpResponse(resp, false) - log.Logf("[http2] %s <- %s\n%s", r.RemoteAddr, laddr, string(dump)) - } - - h.writeResponse(w, resp) - resp.Body.Close() - - return -} - -func (h *http2Handler) forwardRequest(w http.ResponseWriter, r *http.Request, rw io.ReadWriter) (err error) { - if err = r.Write(rw); err != nil { - return - } - - resp, err := http.ReadResponse(bufio.NewReader(rw), r) - if err != nil { - return - } - defer resp.Body.Close() - - return h.writeResponse(w, resp) -} - -func (h *http2Handler) writeResponse(w http.ResponseWriter, resp *http.Response) error { - for k, v := range resp.Header { - for _, vv := range v { - w.Header().Add(k, vv) - } - } - w.WriteHeader(resp.StatusCode) - _, err := io.Copy(flushWriter{w}, resp.Body) - return err -} - -type http2Listener struct { - server *http.Server - connChan chan *http2ServerConn - addr net.Addr - errChan chan error -} - -// HTTP2Listener creates a Listener for HTTP2 proxy server. -func HTTP2Listener(addr string, config *tls.Config) (Listener, error) { - l := &http2Listener{ - connChan: make(chan *http2ServerConn, 1024), - errChan: make(chan error, 1), - } - if config == nil { - config = DefaultTLSConfig - } - server := &http.Server{ - Addr: addr, - Handler: http.HandlerFunc(l.handleFunc), - TLSConfig: config, - } - if err := http2.ConfigureServer(server, nil); err != nil { - return nil, err - } - l.server = server - - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - l.addr = ln.Addr() - - ln = tls.NewListener(tcpKeepAliveListener{ln.(*net.TCPListener)}, config) - go func() { - err := server.Serve(ln) - if err != nil { - log.Log("[http2]", err) - } - }() - - return l, nil -} - -func (l *http2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { - conn := &http2ServerConn{ - r: r, - w: w, - closed: make(chan struct{}), - } - select { - case l.connChan <- conn: - default: - log.Logf("[http2] %s - %s: connection queue is full", r.RemoteAddr, l.server.Addr) - return - } - - <-conn.closed -} - -func (l *http2Listener) Accept() (conn net.Conn, err error) { - select { - case conn = <-l.connChan: - case err = <-l.errChan: - if err == nil { - err = errors.New("accpet on closed listener") - } - } - return -} - -func (l *http2Listener) Addr() net.Addr { - return l.addr -} - -func (l *http2Listener) Close() (err error) { - select { - case <-l.errChan: - default: - err = l.server.Close() - l.errChan <- err - close(l.errChan) - } - return nil -} - -type h2Listener struct { - net.Listener - server *http2.Server - tlsConfig *tls.Config - path string - connChan chan net.Conn - errChan chan error -} - -// H2Listener creates a Listener for HTTP2 h2 tunnel server. -func H2Listener(addr string, config *tls.Config, path string) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - if config == nil { - config = DefaultTLSConfig - } - - l := &h2Listener{ - Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, - server: &http2.Server{ - // MaxConcurrentStreams: 1000, - PermitProhibitedCipherSuites: true, - IdleTimeout: 5 * time.Minute, - }, - tlsConfig: config, - path: path, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -// H2CListener creates a Listener for HTTP2 h2c tunnel server. -func H2CListener(addr string, path string) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - l := &h2Listener{ - Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, - server: &http2.Server{ - // MaxConcurrentStreams: 1000, - }, - path: path, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -func (l *h2Listener) listenLoop() { - for { - conn, err := l.Listener.Accept() - if err != nil { - log.Log("[http2] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.handleLoop(conn) - } -} - -func (l *h2Listener) handleLoop(conn net.Conn) { - if l.tlsConfig != nil { - conn = tls.Server(conn, l.tlsConfig) - } - - if tc, ok := conn.(*tls.Conn); ok { - // NOTE: HTTP2 server will check the TLS version, - // so we must ensure that the TLS connection is handshake completed. - if err := tc.Handshake(); err != nil { - log.Logf("[http2] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - } - - opt := http2.ServeConnOpts{ - Handler: http.HandlerFunc(l.handleFunc), - } - l.server.ServeConn(conn, &opt) -} - -func (l *h2Listener) handleFunc(w http.ResponseWriter, r *http.Request) { - log.Logf("[http2] %s -> %s %s %s %s", - r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto) - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Log("[http2]", string(dump)) - } - w.Header().Set("Proxy-Agent", "gost/"+Version) - conn, err := l.upgrade(w, r) - if err != nil { - log.Logf("[http2] %s - %s %s %s %s: %s", - r.RemoteAddr, r.Host, r.Method, r.RequestURI, r.Proto, err) - return - } - select { - case l.connChan <- conn: - default: - conn.Close() - log.Logf("[http2] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) - } - - <-conn.closed // NOTE: we need to wait for streaming end, or the connection will be closed -} - -func (l *h2Listener) upgrade(w http.ResponseWriter, r *http.Request) (*http2Conn, error) { - if l.path == "" && r.Method != http.MethodConnect { - w.WriteHeader(http.StatusMethodNotAllowed) - return nil, errors.New("method not allowed") - } - - if l.path != "" && r.RequestURI != l.path { - w.WriteHeader(http.StatusBadRequest) - return nil, errors.New("bad request") - } - - w.WriteHeader(http.StatusOK) - if fw, ok := w.(http.Flusher); ok { - fw.Flush() // write header to client - } - - remoteAddr, _ := net.ResolveTCPAddr("tcp", r.RemoteAddr) - if remoteAddr == nil { - remoteAddr = &net.TCPAddr{ - IP: net.IPv4zero, - Port: 0, - } - } - conn := &http2Conn{ - r: r.Body, - w: flushWriter{w}, - localAddr: l.Listener.Addr(), - remoteAddr: remoteAddr, - closed: make(chan struct{}), - } - return conn, nil -} - -func (l *h2Listener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -// HTTP2 connection, wrapped up just like a net.Conn -type http2Conn struct { - r io.Reader - w io.Writer - remoteAddr net.Addr - localAddr net.Addr - closed chan struct{} -} - -func (c *http2Conn) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} - -func (c *http2Conn) Write(b []byte) (n int, err error) { - return c.w.Write(b) -} - -func (c *http2Conn) Close() (err error) { - select { - case <-c.closed: - return - default: - close(c.closed) - } - if rc, ok := c.r.(io.Closer); ok { - err = rc.Close() - } - if w, ok := c.w.(io.Closer); ok { - err = w.Close() - } - return -} - -func (c *http2Conn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *http2Conn) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *http2Conn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2Conn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2Conn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -// a dummy HTTP2 server conn used by HTTP2 handler -type http2ServerConn struct { - r *http.Request - w http.ResponseWriter - closed chan struct{} -} - -func (c *http2ServerConn) Read(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "read", Net: "http2", Source: nil, Addr: nil, Err: errors.New("read not supported")} -} - -func (c *http2ServerConn) Write(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "write", Net: "http2", Source: nil, Addr: nil, Err: errors.New("write not supported")} -} - -func (c *http2ServerConn) Close() error { - select { - case <-c.closed: - default: - close(c.closed) - } - return nil -} - -func (c *http2ServerConn) LocalAddr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", c.r.Host) - return addr -} - -func (c *http2ServerConn) RemoteAddr() net.Addr { - addr, _ := net.ResolveTCPAddr("tcp", c.r.RemoteAddr) - return addr -} - -func (c *http2ServerConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2ServerConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *http2ServerConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "http2", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -// a dummy HTTP2 client conn used by HTTP2 client connector -type http2ClientConn struct { - nopConn - addr string - client *http.Client - onClose func() -} - -func (c *http2ClientConn) Close() error { - if c.onClose != nil { - c.onClose() - } - return nil -} - -type flushWriter struct { - w io.Writer -} - -func (fw flushWriter) Write(p []byte) (n int, err error) { - defer func() { - if r := recover(); r != nil { - if s, ok := r.(string); ok { - err = errors.New(s) - log.Log("[http2]", err) - return - } - err = r.(error) - } - }() - - n, err = fw.w.Write(p) - if err != nil { - // log.Log("flush writer:", err) - return - } - if f, ok := fw.w.(http.Flusher); ok { - f.Flush() - } - return -} diff --git a/gost/http2_test.go b/gost/http2_test.go deleted file mode 100644 index 762c6dfb..00000000 --- a/gost/http2_test.go +++ /dev/null @@ -1,1151 +0,0 @@ -package gost - -import ( - "bytes" - "crypto/rand" - "crypto/tls" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/httptest" - "net/url" - "testing" -) - -func http2ProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - ln, err := HTTP2Listener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTP2Connector(clientInfo), - Transporter: HTTP2Transporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTP2ProxyAuth(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := http2ProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTP2Proxy(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := HTTP2Listener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(url.UserPassword("admin", "123456")), - Transporter: HTTP2Transporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTP2ProxyParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := HTTP2Listener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(url.UserPassword("admin", "123456")), - Transporter: HTTP2Transporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func httpOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := H2Listener("", tlsConfig, "/h2") - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverH2Roundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverH2(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := H2Listener("", nil, "/h2") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverH2Parallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := H2Listener("", nil, "/h2") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := H2Listener("", tlsConfig, "/h2") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverH2Roundtrip(httpSrv.URL, sendData, - nil, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { - ln, err := H2Listener("", tlsConfig, "/h2") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverH2Roundtrip(httpSrv.URL, sendData, nil) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { - ln, err := H2Listener("", tlsConfig, "/h2") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverH2Roundtrip(httpSrv.URL, sendData, nil) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverH2Roundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := H2Listener("", tlsConfig, "/h2") - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverH2Roundtrip(httpSrv.URL, sendData, - nil, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverH2Roundtrip(targetURL string, data []byte, host string) error { - ln, err := H2Listener("", nil, "/h2") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: H2Transporter(nil, "/h2"), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverH2(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverH2Roundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func h2ForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := H2Listener("", nil, "") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: H2Transporter(nil, ""), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestH2ForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := h2ForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func httpOverH2CRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverH2CRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverH2C(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := H2CListener("", "/h2c") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverH2CParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := H2CListener("", "/h2c") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverH2CRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverH2CRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverH2CRoundtrip(targetURL string, data []byte) error { - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverH2CRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverH2CRoundtrip(targetURL string, data []byte) error { - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverH2CRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverH2CRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverH2CRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverH2CRoundtrip(targetURL string, data []byte, host string) error { - ln, err := H2CListener("", "/h2c") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: H2CTransporter("/h2c"), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverH2C(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverH2CRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func h2cForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := H2CListener("", "") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: H2CTransporter(""), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestH2CForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := h2cForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func TestHTTP2ProxyWithCodeProbeResist(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - ln, err := HTTP2Listener("", nil) - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(nil), - Transporter: HTTP2Transporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("code:400"), - ), - } - go server.Run() - defer server.Close() - - err = proxyRoundtrip(client, server, httpSrv.URL, nil) - if err == nil { - t.Error("should failed with status code 400") - } else if err.Error() != "400 Bad Request" { - t.Error("should failed with status code 400, got", err.Error()) - } -} - -func TestHTTP2ProxyWithWebProbeResist(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - ln, err := HTTP2Listener("", nil) - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(nil), - Transporter: HTTP2Transporter(nil), - } - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("web:"+u.Host), - ), - } - go server.Run() - defer server.Close() - - conn, err := proxyConn(client, server) - if err != nil { - t.Error(err) - } - defer conn.Close() - - conn, err = client.Connect(conn, "github.com:443") - if err != nil { - t.Error(err) - } - recv, _ := ioutil.ReadAll(conn) - if !bytes.Equal(recv, []byte("Hello World!")) { - t.Error("data not equal") - } -} - -func TestHTTP2ProxyWithHostProbeResist(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := HTTP2Listener("", nil) - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(nil), - Transporter: HTTP2Transporter(nil), - } - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("host:"+u.Host), - ), - } - go server.Run() - defer server.Close() - - conn, err := proxyConn(client, server) - if err != nil { - t.Error(err) - } - defer conn.Close() - - cc, ok := conn.(*http2ClientConn) - if !ok { - t.Error("wrong connection type") - } - - req := &http.Request{ - Method: http.MethodConnect, - URL: &url.URL{Scheme: "https", Host: cc.addr}, - Header: make(http.Header), - Proto: "HTTP/2.0", - ProtoMajor: 2, - ProtoMinor: 0, - Body: ioutil.NopCloser(bytes.NewReader(sendData)), - Host: "github.com:443", - ContentLength: int64(len(sendData)), - } - - resp, err := cc.client.Do(req) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - t.Error("got non-200 status:", resp.Status) - } - - recv, _ := ioutil.ReadAll(resp.Body) - if !bytes.Equal(sendData, recv) { - t.Error("data not equal") - } -} - -func TestHTTP2ProxyWithFileProbeResist(t *testing.T) { - ln, err := HTTP2Listener("", nil) - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(nil), - Transporter: HTTP2Transporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("file:.config/probe_resist.txt"), - ), - } - go server.Run() - defer server.Close() - - conn, err := proxyConn(client, server) - if err != nil { - t.Error(err) - } - defer conn.Close() - - conn, err = client.Connect(conn, "github.com:443") - if err != nil { - t.Error(err) - } - recv, _ := ioutil.ReadAll(conn) - if !bytes.Equal(recv, []byte("Hello World!")) { - t.Error("data not equal") - } -} - -func TestHTTP2ProxyWithBypass(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - ln, err := HTTP2Listener("", nil) - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTP2Connector(nil), - Transporter: HTTP2Transporter(nil), - } - - host := u.Host - if h, _, _ := net.SplitHostPort(u.Host); h != "" { - host = h - } - server := &Server{ - Listener: ln, - Handler: HTTP2Handler( - BypassHandlerOption(NewBypassPatterns(false, host)), - ), - } - go server.Run() - defer server.Close() - - if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { - t.Error("should failed") - } -} diff --git a/gost/http_test.go b/gost/http_test.go deleted file mode 100644 index a4f10411..00000000 --- a/gost/http_test.go +++ /dev/null @@ -1,378 +0,0 @@ -package gost - -import ( - "bytes" - "crypto/rand" - "fmt" - "io/ioutil" - "net" - "net/http" - "net/http/httptest" - "net/url" - "testing" -) - -var httpProxyTests = []struct { - cliUser *url.Userinfo - srvUsers []*url.Userinfo - errStr string -}{ - {nil, nil, ""}, - {nil, []*url.Userinfo{url.User("admin")}, "407 Proxy Authentication Required"}, - {nil, []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, - {url.User("admin"), []*url.Userinfo{url.User("test")}, "407 Proxy Authentication Required"}, - {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "123456")}, "407 Proxy Authentication Required"}, - {url.User("admin"), []*url.Userinfo{url.User("admin")}, ""}, - {url.User("admin"), []*url.Userinfo{url.UserPassword("admin", "")}, ""}, - {url.UserPassword("admin", "123456"), nil, ""}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.User("admin")}, ""}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, "407 Proxy Authentication Required"}, - {url.UserPassword("", "123456"), []*url.Userinfo{url.UserPassword("", "123456")}, ""}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("admin", "123456")}, ""}, - {url.UserPassword("admin", "123456"), []*url.Userinfo{url.UserPassword("user", "pass"), url.UserPassword("admin", "123456")}, ""}, -} - -func httpProxyRoundtrip(targetURL string, data []byte, clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - ln, err := TCPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: TCPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPProxyAuth(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := httpProxyRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - }) - } -} - -func TestHTTPProxyWithInvalidRequest(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler(), - } - go server.Run() - defer server.Close() - - r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData)) - if err != nil { - t.Error(err) - } - resp, err := http.DefaultClient.Do(r) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusBadRequest { - t.Error("got status:", resp.Status) - } -} - -func BenchmarkHTTPProxy(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: TCPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPProxyParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: TCPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func TestHTTPProxyWithCodeProbeResist(t *testing.T) { - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("code:400"), - ), - } - go server.Run() - defer server.Close() - - resp, err := http.Get("http://" + ln.Addr().String()) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 400 { - t.Error("should failed with status code 400, got", resp.Status) - } -} - -func TestHTTPProxyWithWebProbeResist(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("web:"+u.Host), - ), - } - go server.Run() - defer server.Close() - - r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) - if err != nil { - t.Error(err) - } - resp, err := http.DefaultClient.Do(r) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Error("got status:", resp.Status) - } - - recv, _ := ioutil.ReadAll(resp.Body) - if !bytes.Equal(recv, []byte("Hello World!")) { - t.Error("data not equal") - } -} - -func TestHTTPProxyWithHostProbeResist(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("host:"+u.Host), - ), - } - go server.Run() - defer server.Close() - - r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), bytes.NewReader(sendData)) - if err != nil { - t.Error(err) - } - resp, err := http.DefaultClient.Do(r) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Error("got status:", resp.Status) - } - - recv, _ := ioutil.ReadAll(resp.Body) - if !bytes.Equal(sendData, recv) { - t.Error("data not equal") - } -} - -func TestHTTPProxyWithFileProbeResist(t *testing.T) { - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ProbeResistHandlerOption("file:.config/probe_resist.txt"), - ), - } - go server.Run() - defer server.Close() - - r, err := http.NewRequest("GET", "http://"+ln.Addr().String(), nil) - if err != nil { - t.Error(err) - } - resp, err := http.DefaultClient.Do(r) - if err != nil { - t.Error(err) - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - t.Error("got status:", resp.Status) - } - - recv, _ := ioutil.ReadAll(resp.Body) - if !bytes.Equal(recv, []byte("Hello World!")) { - t.Error("data not equal, got:", string(recv)) - } -} - -func TestHTTPProxyWithBypass(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - u, err := url.Parse(httpSrv.URL) - if err != nil { - t.Error(err) - } - ln, err := TCPListener("") - if err != nil { - t.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(nil), - Transporter: TCPTransporter(), - } - - host := u.Host - if h, _, _ := net.SplitHostPort(u.Host); h != "" { - host = h - } - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - BypassHandlerOption(NewBypassPatterns(false, host)), - ), - } - go server.Run() - defer server.Close() - - if err = proxyRoundtrip(client, server, httpSrv.URL, sendData); err == nil { - t.Error("should failed") - } -} diff --git a/gost/kcp.go b/gost/kcp.go deleted file mode 100644 index cbb7e9e4..00000000 --- a/gost/kcp.go +++ /dev/null @@ -1,503 +0,0 @@ -package gost - -import ( - "crypto/sha1" - "encoding/csv" - "errors" - "fmt" - "net" - "os" - "time" - - "golang.org/x/crypto/pbkdf2" - - "sync" - - "github.com/go-log/log" - "github.com/klauspost/compress/snappy" - "github.com/xtaci/kcp-go" - "github.com/xtaci/smux" - "github.com/xtaci/tcpraw" -) - -var ( - // KCPSalt is the default salt for KCP cipher. - KCPSalt = "kcp-go" -) - -// KCPConfig describes the config for KCP. -type KCPConfig struct { - Key string `json:"key"` - Crypt string `json:"crypt"` - Mode string `json:"mode"` - MTU int `json:"mtu"` - SndWnd int `json:"sndwnd"` - RcvWnd int `json:"rcvwnd"` - DataShard int `json:"datashard"` - ParityShard int `json:"parityshard"` - DSCP int `json:"dscp"` - NoComp bool `json:"nocomp"` - AckNodelay bool `json:"acknodelay"` - NoDelay int `json:"nodelay"` - Interval int `json:"interval"` - Resend int `json:"resend"` - NoCongestion int `json:"nc"` - SockBuf int `json:"sockbuf"` - KeepAlive int `json:"keepalive"` - SnmpLog string `json:"snmplog"` - SnmpPeriod int `json:"snmpperiod"` - Signal bool `json:"signal"` // Signal enables the signal SIGUSR1 feature. - TCP bool `json:"tcp"` -} - -// Init initializes the KCP config. -func (c *KCPConfig) Init() { - switch c.Mode { - case "normal": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 40, 2, 1 - case "fast": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 0, 30, 2, 1 - case "fast2": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 20, 2, 1 - case "fast3": - c.NoDelay, c.Interval, c.Resend, c.NoCongestion = 1, 10, 2, 1 - } -} - -var ( - // DefaultKCPConfig is the default KCP config. - DefaultKCPConfig = KCPConfig{ - Key: "it's a secrect", - Crypt: "aes", - Mode: "fast", - MTU: 1350, - SndWnd: 1024, - RcvWnd: 1024, - DataShard: 10, - ParityShard: 3, - DSCP: 0, - NoComp: false, - AckNodelay: false, - NoDelay: 0, - Interval: 50, - Resend: 0, - NoCongestion: 0, - SockBuf: 4194304, - KeepAlive: 10, - SnmpLog: "", - SnmpPeriod: 60, - Signal: false, - TCP: false, - } -) - -type kcpTransporter struct { - sessions map[string]*muxSession - sessionMutex sync.Mutex - config *KCPConfig -} - -// KCPTransporter creates a Transporter that is used by KCP proxy client. -func KCPTransporter(config *KCPConfig) Transporter { - if config == nil { - config = &KCPConfig{} - *config = DefaultKCPConfig - } - config.Init() - - go snmpLogger(config.SnmpLog, config.SnmpPeriod) - if config.Signal { - go kcpSigHandler() - } - - return &kcpTransporter{ - config: config, - sessions: make(map[string]*muxSession), - } -} - -func (tr *kcpTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if session != nil && session.session != nil && session.session.IsClosed() { - session.Close() - delete(tr.sessions, addr) // session is dead - ok = false - } - if !ok { - raddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - if tr.config.TCP { - pc, err := tcpraw.Dial("tcp", addr) - if err != nil { - return nil, err - } - conn = &fakeTCPConn{ - raddr: raddr, - PacketConn: pc, - } - } else { - conn, err = net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - } - session = &muxSession{conn: conn} - tr.sessions[addr] = session - } - return session.conn, nil -} - -func (tr *kcpTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - config := tr.config - if opts.KCPConfig != nil { - config = opts.KCPConfig - } - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - session, ok := tr.sessions[opts.Addr] - if !ok || session.session == nil { - s, err := tr.initSession(opts.Addr, conn, config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - session = s - tr.sessions[opts.Addr] = session - } - cc, err := session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - return cc, nil -} - -func (tr *kcpTransporter) initSession(addr string, conn net.Conn, config *KCPConfig) (*muxSession, error) { - pc, ok := conn.(net.PacketConn) - if !ok { - return nil, errors.New("kcp: wrong connection type") - } - - kcpconn, err := kcp.NewConn(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), - config.DataShard, config.ParityShard, pc) - if err != nil { - return nil, err - } - - kcpconn.SetStreamMode(true) - kcpconn.SetWriteDelay(false) - kcpconn.SetNoDelay(config.NoDelay, config.Interval, config.Resend, config.NoCongestion) - kcpconn.SetWindowSize(config.SndWnd, config.RcvWnd) - kcpconn.SetMtu(config.MTU) - kcpconn.SetACKNoDelay(config.AckNodelay) - - if config.DSCP > 0 { - if err := kcpconn.SetDSCP(config.DSCP); err != nil { - log.Log("[kcp]", err) - } - } - if err := kcpconn.SetReadBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - if err := kcpconn.SetWriteBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - - // stream multiplex - smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = config.SockBuf - smuxConfig.KeepAliveInterval = time.Duration(config.KeepAlive) * time.Second - var cc net.Conn = kcpconn - if !config.NoComp { - cc = newCompStreamConn(kcpconn) - } - session, err := smux.Client(cc, smuxConfig) - if err != nil { - return nil, err - } - return &muxSession{conn: conn, session: session}, nil -} - -func (tr *kcpTransporter) Multiplex() bool { - return true -} - -type kcpListener struct { - config *KCPConfig - ln *kcp.Listener - connChan chan net.Conn - errChan chan error -} - -// KCPListener creates a Listener for KCP proxy server. -func KCPListener(addr string, config *KCPConfig) (Listener, error) { - if config == nil { - config = &KCPConfig{} - *config = DefaultKCPConfig - } - config.Init() - - var err error - var ln *kcp.Listener - if config.TCP { - var conn net.PacketConn - conn, err = tcpraw.Listen("tcp", addr) - if err != nil { - return nil, err - } - ln, err = kcp.ServeConn( - blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard, conn) - if err != nil { - return nil, err - } - } else { - ln, err = kcp.ListenWithOptions(addr, - blockCrypt(config.Key, config.Crypt, KCPSalt), config.DataShard, config.ParityShard) - } - if err != nil { - return nil, err - } - if config.DSCP > 0 { - if err = ln.SetDSCP(config.DSCP); err != nil { - log.Log("[kcp]", err) - } - } - if err = ln.SetReadBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - if err = ln.SetWriteBuffer(config.SockBuf); err != nil { - log.Log("[kcp]", err) - } - - go snmpLogger(config.SnmpLog, config.SnmpPeriod) - if config.Signal { - go kcpSigHandler() - } - - l := &kcpListener{ - config: config, - ln: ln, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - go l.listenLoop() - - return l, nil -} - -func (l *kcpListener) listenLoop() { - for { - conn, err := l.ln.AcceptKCP() - if err != nil { - log.Log("[kcp] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - conn.SetStreamMode(true) - conn.SetWriteDelay(false) - conn.SetNoDelay(l.config.NoDelay, l.config.Interval, l.config.Resend, l.config.NoCongestion) - conn.SetMtu(l.config.MTU) - conn.SetWindowSize(l.config.SndWnd, l.config.RcvWnd) - conn.SetACKNoDelay(l.config.AckNodelay) - go l.mux(conn) - } -} - -func (l *kcpListener) mux(conn net.Conn) { - smuxConfig := smux.DefaultConfig() - smuxConfig.MaxReceiveBuffer = l.config.SockBuf - smuxConfig.KeepAliveInterval = time.Duration(l.config.KeepAlive) * time.Second - - log.Logf("[kcp] %s - %s", conn.RemoteAddr(), l.Addr()) - - if !l.config.NoComp { - conn = newCompStreamConn(conn) - } - - mux, err := smux.Server(conn, smuxConfig) - if err != nil { - log.Log("[kcp]", err) - return - } - defer mux.Close() - - log.Logf("[kcp] %s <-> %s", conn.RemoteAddr(), l.Addr()) - defer log.Logf("[kcp] %s >-< %s", conn.RemoteAddr(), l.Addr()) - - for { - stream, err := mux.AcceptStream() - if err != nil { - log.Log("[kcp] accept stream:", err) - return - } - - cc := &muxStreamConn{Conn: conn, stream: stream} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[kcp] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) - } - } -} - -func (l *kcpListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} -func (l *kcpListener) Addr() net.Addr { - return l.ln.Addr() -} - -func (l *kcpListener) Close() error { - return l.ln.Close() -} - -func blockCrypt(key, crypt, salt string) (block kcp.BlockCrypt) { - pass := pbkdf2.Key([]byte(key), []byte(salt), 4096, 32, sha1.New) - - switch crypt { - case "sm4": - block, _ = kcp.NewSM4BlockCrypt(pass[:16]) - case "tea": - block, _ = kcp.NewTEABlockCrypt(pass[:16]) - case "xor": - block, _ = kcp.NewSimpleXORBlockCrypt(pass) - case "none": - block, _ = kcp.NewNoneBlockCrypt(pass) - case "aes-128": - block, _ = kcp.NewAESBlockCrypt(pass[:16]) - case "aes-192": - block, _ = kcp.NewAESBlockCrypt(pass[:24]) - case "blowfish": - block, _ = kcp.NewBlowfishBlockCrypt(pass) - case "twofish": - block, _ = kcp.NewTwofishBlockCrypt(pass) - case "cast5": - block, _ = kcp.NewCast5BlockCrypt(pass[:16]) - case "3des": - block, _ = kcp.NewTripleDESBlockCrypt(pass[:24]) - case "xtea": - block, _ = kcp.NewXTEABlockCrypt(pass[:16]) - case "salsa20": - block, _ = kcp.NewSalsa20BlockCrypt(pass) - case "aes": - fallthrough - default: // aes - block, _ = kcp.NewAESBlockCrypt(pass) - } - return -} - -func snmpLogger(format string, interval int) { - if format == "" || interval == 0 { - return - } - ticker := time.NewTicker(time.Duration(interval) * time.Second) - defer ticker.Stop() - for { - select { - case <-ticker.C: - f, err := os.OpenFile(time.Now().Format(format), os.O_RDWR|os.O_CREATE|os.O_APPEND, 0666) - if err != nil { - log.Log("[kcp]", err) - return - } - w := csv.NewWriter(f) - // write header in empty file - if stat, err := f.Stat(); err == nil && stat.Size() == 0 { - if err := w.Write(append([]string{"Unix"}, kcp.DefaultSnmp.Header()...)); err != nil { - log.Log("[kcp]", err) - } - } - if err := w.Write(append([]string{fmt.Sprint(time.Now().Unix())}, kcp.DefaultSnmp.ToSlice()...)); err != nil { - log.Log("[kcp]", err) - } - kcp.DefaultSnmp.Reset() - w.Flush() - f.Close() - } - } -} - -type compStreamConn struct { - conn net.Conn - w *snappy.Writer - r *snappy.Reader -} - -func newCompStreamConn(conn net.Conn) *compStreamConn { - c := new(compStreamConn) - c.conn = conn - c.w = snappy.NewBufferedWriter(conn) - c.r = snappy.NewReader(conn) - return c -} - -func (c *compStreamConn) Read(b []byte) (n int, err error) { - return c.r.Read(b) -} - -func (c *compStreamConn) Write(b []byte) (n int, err error) { - n, err = c.w.Write(b) - err = c.w.Flush() - return n, err -} - -func (c *compStreamConn) Close() error { - return c.conn.Close() -} - -func (c *compStreamConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *compStreamConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *compStreamConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *compStreamConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *compStreamConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/gost/kcp_test.go b/gost/kcp_test.go deleted file mode 100644 index d9e6ba00..00000000 --- a/gost/kcp_test.go +++ /dev/null @@ -1,408 +0,0 @@ -package gost - -import ( - "crypto/rand" - "fmt" - "net/http/httptest" - "net/url" - "testing" -) - -func httpOverKCPRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverKCP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverKCPRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverKCP(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := KCPListener("", nil) - if err != nil { - b.Error(err) - } - b.Log(ln.Addr()) - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverKCPParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := KCPListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverKCPRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverKCP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverKCPRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverKCPRoundtrip(targetURL string, data []byte) error { - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverKCP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverKCPRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverKCPRoundtrip(targetURL string, data []byte) error { - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverKCP(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverKCPRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverKCPRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverKCP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverKCPRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverKCPRoundtrip(targetURL string, data []byte, host string) error { - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverKCP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverKCPRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func kcpForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := KCPListener("localhost:0", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: KCPTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestKCPForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := kcpForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} diff --git a/gost/node.go b/gost/node.go index f64afc43..c0915625 100644 --- a/gost/node.go +++ b/gost/node.go @@ -31,7 +31,6 @@ type Node struct { ConnectOptions []ConnectOption Client *Client marker *failMarker - Bypass *Bypass } // ParseNode parses the node info. diff --git a/gost/obfs.go b/gost/obfs.go deleted file mode 100644 index 44546841..00000000 --- a/gost/obfs.go +++ /dev/null @@ -1,818 +0,0 @@ -// obfs4 connection wrappers - -package gost - -import ( - "bufio" - "bytes" - "crypto/rand" - "crypto/tls" - "errors" - "fmt" - "io" - "net" - "net/http" - "net/http/httputil" - "net/url" - "sync" - "time" - - "github.com/go-log/log" - - pt "git.torproject.org/pluggable-transports/goptlib.git" - "git.torproject.org/pluggable-transports/obfs4.git/transports/base" - "git.torproject.org/pluggable-transports/obfs4.git/transports/obfs4" - dissector "github.com/ginuerzh/tls-dissector" -) - -const ( - maxTLSDataLen = 16384 -) - -type obfsHTTPTransporter struct { - tcpTransporter -} - -// ObfsHTTPTransporter creates a Transporter that is used by HTTP obfuscating tunnel client. -func ObfsHTTPTransporter() Transporter { - return &obfsHTTPTransporter{} -} - -func (tr *obfsHTTPTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - return &obfsHTTPConn{Conn: conn, host: opts.Host}, nil -} - -type obfsHTTPListener struct { - net.Listener -} - -// ObfsHTTPListener creates a Listener for HTTP obfuscating tunnel server. -func ObfsHTTPListener(addr string) (Listener, error) { - laddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", laddr) - if err != nil { - return nil, err - } - return &obfsHTTPListener{Listener: tcpKeepAliveListener{ln}}, nil -} - -func (l *obfsHTTPListener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } - - return &obfsHTTPConn{Conn: conn, isServer: true}, nil -} - -type obfsHTTPConn struct { - net.Conn - host string - rbuf bytes.Buffer - wbuf bytes.Buffer - isServer bool - headerDrained bool - handshaked bool - handshakeMutex sync.Mutex -} - -func (c *obfsHTTPConn) Handshake() (err error) { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if c.handshaked { - return nil - } - - if c.isServer { - err = c.serverHandshake() - } else { - err = c.clientHandshake() - } - if err != nil { - return - } - - c.handshaked = true - return nil -} - -func (c *obfsHTTPConn) serverHandshake() (err error) { - br := bufio.NewReader(c.Conn) - r, err := http.ReadRequest(br) - if err != nil { - return - } - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Logf("[ohttp] %s -> %s\n%s", c.RemoteAddr(), c.LocalAddr(), string(dump)) - } - - if r.ContentLength > 0 { - _, err = io.Copy(&c.rbuf, r.Body) - } else { - var b []byte - b, err = br.Peek(br.Buffered()) - if len(b) > 0 { - _, err = c.rbuf.Write(b) - } - } - if err != nil { - log.Logf("[ohttp] %s -> %s : %v", c.Conn.RemoteAddr(), c.Conn.LocalAddr(), err) - return - } - - b := bytes.Buffer{} - - if r.Method != http.MethodGet || r.Header.Get("Upgrade") != "websocket" { - b.WriteString("HTTP/1.1 503 Service Unavailable\r\n") - b.WriteString("Content-Length: 0\r\n") - b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("\r\n") - - if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) - } - - b.WriteTo(c.Conn) - return errors.New("bad request") - } - - b.WriteString("HTTP/1.1 101 Switching Protocols\r\n") - b.WriteString("Server: nginx/1.10.0\r\n") - b.WriteString("Date: " + time.Now().Format(time.RFC1123) + "\r\n") - b.WriteString("Connection: Upgrade\r\n") - b.WriteString("Upgrade: websocket\r\n") - b.WriteString(fmt.Sprintf("Sec-WebSocket-Accept: %s\r\n", computeAcceptKey(r.Header.Get("Sec-WebSocket-Key")))) - b.WriteString("\r\n") - - if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.RemoteAddr(), c.LocalAddr(), b.String()) - } - - if c.rbuf.Len() > 0 { - c.wbuf = b // cache the response header if there are extra data in the request body. - return - } - - _, err = b.WriteTo(c.Conn) - return -} - -func (c *obfsHTTPConn) clientHandshake() (err error) { - r := &http.Request{ - Method: http.MethodGet, - ProtoMajor: 1, - ProtoMinor: 1, - URL: &url.URL{Scheme: "http", Host: c.host}, - Header: make(http.Header), - } - r.Header.Set("User-Agent", DefaultUserAgent) - r.Header.Set("Connection", "Upgrade") - r.Header.Set("Upgrade", "websocket") - key, _ := generateChallengeKey() - r.Header.Set("Sec-WebSocket-Key", key) - - // cache the request header - if err = r.Write(&c.wbuf); err != nil { - return - } - - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Logf("[ohttp] %s -> %s\n%s", c.LocalAddr(), c.RemoteAddr(), string(dump)) - } - - return nil -} - -func (c *obfsHTTPConn) Read(b []byte) (n int, err error) { - if err = c.Handshake(); err != nil { - return - } - - if !c.isServer { - if err = c.drainHeader(); err != nil { - return - } - } - - if c.rbuf.Len() > 0 { - return c.rbuf.Read(b) - } - return c.Conn.Read(b) -} - -func (c *obfsHTTPConn) drainHeader() (err error) { - if c.headerDrained { - return - } - c.headerDrained = true - - br := bufio.NewReader(c.Conn) - // drain and discard the response header - var line string - var buf bytes.Buffer - for { - line, err = br.ReadString('\n') - if err != nil { - return - } - buf.WriteString(line) - if line == "\r\n" { - break - } - } - - if Debug { - log.Logf("[ohttp] %s <- %s\n%s", c.LocalAddr(), c.RemoteAddr(), buf.String()) - } - // cache the extra data for next read. - var b []byte - b, err = br.Peek(br.Buffered()) - if len(b) > 0 { - _, err = c.rbuf.Write(b) - } - return -} - -func (c *obfsHTTPConn) Write(b []byte) (n int, err error) { - if err = c.Handshake(); err != nil { - return - } - if c.wbuf.Len() > 0 { - c.wbuf.Write(b) // append the data to the cached header - _, err = c.wbuf.WriteTo(c.Conn) - n = len(b) // exclude the header length - return - } - return c.Conn.Write(b) -} - -type obfsTLSTransporter struct { - tcpTransporter -} - -// ObfsTLSTransporter creates a Transporter that is used by TLS obfuscating. -func ObfsTLSTransporter() Transporter { - return &obfsTLSTransporter{} -} - -func (tr *obfsTLSTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - return ClientObfsTLSConn(conn, opts.Host), nil -} - -type obfsTLSListener struct { - net.Listener -} - -// ObfsTLSListener creates a Listener for TLS obfuscating server. -func ObfsTLSListener(addr string) (Listener, error) { - laddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - ln, err := net.ListenTCP("tcp", laddr) - if err != nil { - return nil, err - } - return &obfsTLSListener{Listener: tcpKeepAliveListener{ln}}, nil -} - -func (l *obfsTLSListener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } - - return ServerObfsTLSConn(conn, ""), nil -} - -var ( - cipherSuites = []uint16{ - 0xc02c, 0xc030, 0x009f, 0xcca9, 0xcca8, 0xccaa, 0xc02b, 0xc02f, - 0x009e, 0xc024, 0xc028, 0x006b, 0xc023, 0xc027, 0x0067, 0xc00a, - 0xc014, 0x0039, 0xc009, 0xc013, 0x0033, 0x009d, 0x009c, 0x003d, - 0x003c, 0x0035, 0x002f, 0x00ff, - } - - compressionMethods = []uint8{0x00} - - algorithms = []uint16{ - 0x0601, 0x0602, 0x0603, 0x0501, 0x0502, 0x0503, 0x0401, 0x0402, - 0x0403, 0x0301, 0x0302, 0x0303, 0x0201, 0x0202, 0x0203, - } - - tlsRecordTypes = []uint8{0x16, 0x14, 0x16, 0x17} - tlsVersionMinors = []uint8{0x01, 0x03, 0x03, 0x03} - - ErrBadType = errors.New("bad type") - ErrBadMajorVersion = errors.New("bad major version") - ErrBadMinorVersion = errors.New("bad minor version") - ErrMaxDataLen = errors.New("bad tls data len") -) - -const ( - tlsRecordStateType = iota - tlsRecordStateVersion0 - tlsRecordStateVersion1 - tlsRecordStateLength0 - tlsRecordStateLength1 - tlsRecordStateData -) - -type obfsTLSParser struct { - step uint8 - state uint8 - length uint16 -} - -type obfsTLSConn struct { - net.Conn - rbuf bytes.Buffer - wbuf bytes.Buffer - host string - isServer bool - handshaked chan struct{} - parser *obfsTLSParser - handshakeMutex sync.Mutex -} - -func (r *obfsTLSParser) Parse(b []byte) (int, error) { - i := 0 - last := 0 - length := len(b) - - for i < length { - ch := b[i] - switch r.state { - case tlsRecordStateType: - if tlsRecordTypes[r.step] != ch { - return 0, ErrBadType - } - r.state = tlsRecordStateVersion0 - i++ - case tlsRecordStateVersion0: - if ch != 0x03 { - return 0, ErrBadMajorVersion - } - r.state = tlsRecordStateVersion1 - i++ - case tlsRecordStateVersion1: - if ch != tlsVersionMinors[r.step] { - return 0, ErrBadMinorVersion - } - r.state = tlsRecordStateLength0 - i++ - case tlsRecordStateLength0: - r.length = uint16(ch) << 8 - r.state = tlsRecordStateLength1 - i++ - case tlsRecordStateLength1: - r.length |= uint16(ch) - if r.step == 0 { - r.length = 91 - } else if r.step == 1 { - r.length = 1 - } else if r.length > maxTLSDataLen { - return 0, ErrMaxDataLen - } - if r.length > 0 { - r.state = tlsRecordStateData - } else { - r.state = tlsRecordStateType - r.step++ - } - i++ - case tlsRecordStateData: - left := uint16(length - i) - if left > r.length { - left = r.length - } - if r.step >= 2 { - skip := i - last - copy(b[last:], b[i:length]) - length -= int(skip) - last += int(left) - i = last - } else { - i += int(left) - } - r.length -= left - if r.length == 0 { - if r.step < 3 { - r.step++ - } - r.state = tlsRecordStateType - } - } - } - - if last == 0 { - return 0, nil - } else if last < length { - length -= last - } - - return length, nil -} - -// ClientObfsTLSConn creates a connection for obfs-tls client. -func ClientObfsTLSConn(conn net.Conn, host string) net.Conn { - return &obfsTLSConn{ - Conn: conn, - host: host, - handshaked: make(chan struct{}), - parser: &obfsTLSParser{}, - } -} - -// ServerObfsTLSConn creates a connection for obfs-tls server. -func ServerObfsTLSConn(conn net.Conn, host string) net.Conn { - return &obfsTLSConn{ - Conn: conn, - host: host, - isServer: true, - handshaked: make(chan struct{}), - } -} - -func (c *obfsTLSConn) Handshaked() bool { - select { - case <-c.handshaked: - return true - default: - return false - } -} - -func (c *obfsTLSConn) Handshake(payload []byte) (err error) { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - - if c.Handshaked() { - return - } - - if c.isServer { - err = c.serverHandshake() - } else { - err = c.clientHandshake(payload) - } - if err != nil { - return - } - - close(c.handshaked) - return nil -} - -func (c *obfsTLSConn) clientHandshake(payload []byte) error { - clientMsg := &dissector.ClientHelloMsg{ - Version: tls.VersionTLS12, - SessionID: make([]byte, 32), - CipherSuites: cipherSuites, - CompressionMethods: compressionMethods, - Extensions: []dissector.Extension{ - &dissector.SessionTicketExtension{ - Data: payload, - }, - &dissector.ServerNameExtension{ - Name: c.host, - }, - &dissector.ECPointFormatsExtension{ - Formats: []uint8{0x01, 0x00, 0x02}, - }, - &dissector.SupportedGroupsExtension{ - Groups: []uint16{0x001d, 0x0017, 0x0019, 0x0018}, - }, - &dissector.SignatureAlgorithmsExtension{ - Algorithms: algorithms, - }, - &dissector.EncryptThenMacExtension{}, - &dissector.ExtendedMasterSecretExtension{}, - }, - } - clientMsg.Random.Time = uint32(time.Now().Unix()) - rand.Read(clientMsg.Random.Opaque[:]) - rand.Read(clientMsg.SessionID) - b, err := clientMsg.Encode() - if err != nil { - return err - } - - record := &dissector.Record{ - Type: dissector.Handshake, - Version: tls.VersionTLS10, - Opaque: b, - } - if _, err := record.WriteTo(c.Conn); err != nil { - return err - } - return err -} - -func (c *obfsTLSConn) serverHandshake() error { - record := &dissector.Record{} - if _, err := record.ReadFrom(c.Conn); err != nil { - log.Log(err) - return err - } - if record.Type != dissector.Handshake { - return dissector.ErrBadType - } - - clientMsg := &dissector.ClientHelloMsg{} - if err := clientMsg.Decode(record.Opaque); err != nil { - log.Log(err) - return err - } - - for _, ext := range clientMsg.Extensions { - if ext.Type() == dissector.ExtSessionTicket { - b, err := ext.Encode() - if err != nil { - log.Log(err) - return err - } - c.rbuf.Write(b) - break - } - } - - serverMsg := &dissector.ServerHelloMsg{ - Version: tls.VersionTLS12, - SessionID: clientMsg.SessionID, - CipherSuite: 0xcca8, - CompressionMethod: 0x00, - Extensions: []dissector.Extension{ - &dissector.RenegotiationInfoExtension{}, - &dissector.ExtendedMasterSecretExtension{}, - &dissector.ECPointFormatsExtension{ - Formats: []uint8{0x00}, - }, - }, - } - - serverMsg.Random.Time = uint32(time.Now().Unix()) - rand.Read(serverMsg.Random.Opaque[:]) - b, err := serverMsg.Encode() - if err != nil { - return err - } - - record = &dissector.Record{ - Type: dissector.Handshake, - Version: tls.VersionTLS10, - Opaque: b, - } - - if _, err := record.WriteTo(&c.wbuf); err != nil { - return err - } - - record = &dissector.Record{ - Type: dissector.ChangeCipherSpec, - Version: tls.VersionTLS12, - Opaque: []byte{0x01}, - } - if _, err := record.WriteTo(&c.wbuf); err != nil { - return err - } - return nil -} - -func (c *obfsTLSConn) Read(b []byte) (n int, err error) { - if c.isServer { // NOTE: only Write performs the handshake operation on client side. - if err = c.Handshake(nil); err != nil { - return - } - } - - select { - case <-c.handshaked: - } - - if c.isServer { - if c.rbuf.Len() > 0 { - return c.rbuf.Read(b) - } - record := &dissector.Record{} - if _, err = record.ReadFrom(c.Conn); err != nil { - return - } - n = copy(b, record.Opaque) - _, err = c.rbuf.Write(record.Opaque[n:]) - } else { - n, err = c.Conn.Read(b) - if err != nil { - return - } - if n > 0 { - n, err = c.parser.Parse(b[:n]) - } - } - return -} - -func (c *obfsTLSConn) Write(b []byte) (n int, err error) { - n = len(b) - if !c.Handshaked() { - if err = c.Handshake(b); err != nil { - return - } - if !c.isServer { // the data b has been sended during handshake phase. - return - } - } - - for len(b) > 0 { - data := b - if len(b) > maxTLSDataLen { - data = b[:maxTLSDataLen] - b = b[maxTLSDataLen:] - } else { - b = b[:0] - } - record := &dissector.Record{ - Type: dissector.AppData, - Version: tls.VersionTLS12, - Opaque: data, - } - - if c.wbuf.Len() > 0 { - record.Type = dissector.Handshake - record.WriteTo(&c.wbuf) - _, err = c.wbuf.WriteTo(c.Conn) - return - } - - if _, err = record.WriteTo(c.Conn); err != nil { - return - } - } - return -} - -type obfs4Context struct { - cf base.ClientFactory - cargs interface{} // type obfs4ClientArgs - sf base.ServerFactory - sargs *pt.Args -} - -var obfs4Map = make(map[string]obfs4Context) - -// Obfs4Init initializes the obfs client or server based on isServeNode -func Obfs4Init(node Node, isServeNode bool) error { - if _, ok := obfs4Map[node.Addr]; ok { - return fmt.Errorf("obfs4 context already inited") - } - - t := new(obfs4.Transport) - - stateDir := node.Values.Get("state-dir") - if stateDir == "" { - stateDir = "." - } - - ptArgs := pt.Args(node.Values) - - if !isServeNode { - cf, err := t.ClientFactory(stateDir) - if err != nil { - return err - } - - cargs, err := cf.ParseArgs(&ptArgs) - if err != nil { - return err - } - - obfs4Map[node.Addr] = obfs4Context{cf: cf, cargs: cargs} - } else { - sf, err := t.ServerFactory(stateDir, &ptArgs) - if err != nil { - return err - } - - sargs := sf.Args() - - obfs4Map[node.Addr] = obfs4Context{sf: sf, sargs: sargs} - - log.Log("[obfs4] server inited:", obfs4ServerURL(node)) - } - - return nil -} - -func obfs4GetContext(addr string) (obfs4Context, error) { - ctx, ok := obfs4Map[addr] - if !ok { - return obfs4Context{}, fmt.Errorf("obfs4 context not inited") - } - return ctx, nil -} - -func obfs4ServerURL(node Node) string { - ctx, err := obfs4GetContext(node.Addr) - if err != nil { - return "" - } - - values := (*url.Values)(ctx.sargs) - query := values.Encode() - return fmt.Sprintf( - "%s+%s://%s/?%s", //obfs4-cert=%s&iat-mode=%s", - node.Protocol, - node.Transport, - node.Addr, - query, - ) -} - -func obfs4ClientConn(addr string, conn net.Conn) (net.Conn, error) { - ctx, err := obfs4GetContext(addr) - if err != nil { - return nil, err - } - - pseudoDial := func(a, b string) (net.Conn, error) { return conn, nil } - return ctx.cf.Dial("tcp", "", pseudoDial, ctx.cargs) -} - -func obfs4ServerConn(addr string, conn net.Conn) (net.Conn, error) { - ctx, err := obfs4GetContext(addr) - if err != nil { - return nil, err - } - - return ctx.sf.WrapConn(conn) -} - -type obfs4Transporter struct { - tcpTransporter -} - -// Obfs4Transporter creates a Transporter that is used by obfs4 client. -func Obfs4Transporter() Transporter { - return &obfs4Transporter{} -} - -func (tr *obfs4Transporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - return obfs4ClientConn(opts.Addr, conn) -} - -type obfs4Listener struct { - addr string - net.Listener -} - -// Obfs4Listener creates a Listener for obfs4 server. -func Obfs4Listener(addr string) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - l := &obfs4Listener{ - addr: addr, - Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, - } - return l, nil -} - -func (l *obfs4Listener) Accept() (net.Conn, error) { - conn, err := l.Listener.Accept() - if err != nil { - return nil, err - } - cc, err := obfs4ServerConn(l.addr, conn) - if err != nil { - conn.Close() - return nil, err - } - return cc, nil -} diff --git a/gost/obfs_test.go b/gost/obfs_test.go deleted file mode 100644 index cd33702b..00000000 --- a/gost/obfs_test.go +++ /dev/null @@ -1,424 +0,0 @@ -package gost - -import ( - "crypto/rand" - "fmt" - "net/http/httptest" - "net/url" - "testing" -) - -func httpOverObfsHTTPRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverObfsHTTP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := httpOverObfsHTTPRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - }) - } -} - -func BenchmarkHTTPOverObfsHTTP(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := ObfsHTTPListener("") - if err != nil { - b.Error(err) - } - // b.Log(ln.Addr()) - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverObfsHTTPParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := ObfsHTTPListener("") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverObfsHTTPRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverObfsHTTP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverObfsHTTPRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverObfsHTTPRoundtrip(targetURL string, data []byte) error { - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverObfsHTTP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverObfsHTTPRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverObfsHTTPRoundtrip(targetURL string, data []byte) error { - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverObfsHTTP(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverObfsHTTPRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverObfsHTTPRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverObfsHTTP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverObfsHTTPRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverObfsHTTPRoundtrip(targetURL string, data []byte, host string) error { - ln, err := ObfsHTTPListener("") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: ObfsHTTPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverObfsHTTP(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverObfsHTTPRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func httpOverObfs4Roundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := Obfs4Listener("") - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: Obfs4Transporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func _TestHTTPOverObfs4(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := httpOverObfs4Roundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - }) - } -} diff --git a/gost/permissions.go b/gost/permissions.go deleted file mode 100644 index a2943faf..00000000 --- a/gost/permissions.go +++ /dev/null @@ -1,223 +0,0 @@ -package gost - -import ( - "errors" - "fmt" - "net" - "strconv" - "strings" - - glob "github.com/ryanuber/go-glob" -) - -// Permission is a rule for blacklist and whitelist. -type Permission struct { - Actions StringSet - Hosts StringSet - Ports PortSet -} - -// PortRange specifies the range of port, such as 1000-2000. -type PortRange struct { - Min, Max int -} - -// ParsePortRange parses the s to a PortRange. -// The s may be a '*' means 0-65535. -func ParsePortRange(s string) (*PortRange, error) { - if s == "*" { - return &PortRange{Min: 0, Max: 65535}, nil - } - - minmax := strings.Split(s, "-") - switch len(minmax) { - case 1: - port, err := strconv.Atoi(s) - if err != nil { - return nil, err - } - if port < 0 || port > 65535 { - return nil, fmt.Errorf("invalid port: %s", s) - } - return &PortRange{Min: port, Max: port}, nil - case 2: - min, err := strconv.Atoi(minmax[0]) - if err != nil { - return nil, err - } - max, err := strconv.Atoi(minmax[1]) - if err != nil { - return nil, err - } - - realmin := maxint(0, minint(min, max)) - realmax := minint(65535, maxint(min, max)) - - return &PortRange{Min: realmin, Max: realmax}, nil - default: - return nil, fmt.Errorf("invalid range: %s", s) - } -} - -// Contains checks whether the value is within this range. -func (ir *PortRange) Contains(value int) bool { - return value >= ir.Min && value <= ir.Max -} - -// PortSet is a set of PortRange -type PortSet []PortRange - -// ParsePortSet parses the s to a PortSet. -// The s shoud be a comma separated string. -func ParsePortSet(s string) (*PortSet, error) { - ps := &PortSet{} - - if s == "" { - return nil, errors.New("must specify at least one port") - } - - ranges := strings.Split(s, ",") - - for _, r := range ranges { - portRange, err := ParsePortRange(r) - - if err != nil { - return nil, err - } - - *ps = append(*ps, *portRange) - } - - return ps, nil -} - -// Contains checks whether the value is within this port set. -func (ps *PortSet) Contains(value int) bool { - for _, portRange := range *ps { - if portRange.Contains(value) { - return true - } - } - - return false -} - -// StringSet is a set of string. -type StringSet []string - -// ParseStringSet parses the s to a StringSet. -// The s shoud be a comma separated string. -func ParseStringSet(s string) (*StringSet, error) { - ss := &StringSet{} - if s == "" { - return nil, errors.New("cannot be empty") - } - - *ss = strings.Split(s, ",") - - return ss, nil -} - -// Contains checks whether the string subj within this StringSet. -func (ss *StringSet) Contains(subj string) bool { - for _, s := range *ss { - if glob.Glob(s, subj) { - return true - } - } - - return false -} - -// Permissions is a set of Permission. -type Permissions []Permission - -// ParsePermissions parses the s to a Permissions. -func ParsePermissions(s string) (*Permissions, error) { - ps := &Permissions{} - - if s == "" { - return &Permissions{}, nil - } - - perms := strings.Split(s, " ") - - for _, perm := range perms { - parts := strings.Split(perm, ":") - - switch len(parts) { - case 3: - actions, err := ParseStringSet(parts[0]) - - if err != nil { - return nil, fmt.Errorf("action list must look like connect,bind given: %s", parts[0]) - } - - hosts, err := ParseStringSet(parts[1]) - - if err != nil { - return nil, fmt.Errorf("hosts list must look like google.pl,*.google.com given: %s", parts[1]) - } - - ports, err := ParsePortSet(parts[2]) - - if err != nil { - return nil, fmt.Errorf("ports list must look like 80,8000-9000, given: %s", parts[2]) - } - - permission := Permission{Actions: *actions, Hosts: *hosts, Ports: *ports} - - *ps = append(*ps, permission) - default: - return nil, fmt.Errorf("permission must have format [actions]:[hosts]:[ports] given: %s", perm) - } - } - - return ps, nil -} - -// Can tests whether the given action and host:port is allowed by this Permissions. -func (ps *Permissions) Can(action string, host string, port int) bool { - for _, p := range *ps { - if p.Actions.Contains(action) && p.Hosts.Contains(host) && p.Ports.Contains(port) { - return true - } - } - - return false -} - -func minint(x, y int) int { - if x < y { - return x - } - return y -} - -func maxint(x, y int) int { - if x > y { - return x - } - return y -} - -// Can tests whether the given action and address is allowed by the whitelist and blacklist. -func Can(action string, addr string, whitelist, blacklist *Permissions) bool { - if !strings.Contains(addr, ":") { - addr = addr + ":80" - } - host, strport, err := net.SplitHostPort(addr) - - if err != nil { - return false - } - - port, err := strconv.Atoi(strport) - - if err != nil { - return false - } - - return (whitelist == nil || whitelist.Can(action, host, port)) && - (blacklist == nil || !blacklist.Can(action, host, port)) -} diff --git a/gost/permissions_test.go b/gost/permissions_test.go deleted file mode 100644 index bc99824a..00000000 --- a/gost/permissions_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package gost - -import ( - "fmt" - "testing" -) - -var portRangeTests = []struct { - in string - out *PortRange -}{ - {"1", &PortRange{Min: 1, Max: 1}}, - {"1-3", &PortRange{Min: 1, Max: 3}}, - {"3-1", &PortRange{Min: 1, Max: 3}}, - {"0-100000", &PortRange{Min: 0, Max: 65535}}, - {"*", &PortRange{Min: 0, Max: 65535}}, -} - -var stringSetTests = []struct { - in string - out *StringSet -}{ - {"*", &StringSet{"*"}}, - {"google.pl,google.com", &StringSet{"google.pl", "google.com"}}, -} - -var portSetTests = []struct { - in string - out *PortSet -}{ - {"1,3", &PortSet{PortRange{Min: 1, Max: 1}, PortRange{Min: 3, Max: 3}}}, - {"1-3,7-5", &PortSet{PortRange{Min: 1, Max: 3}, PortRange{Min: 5, Max: 7}}}, - {"0-100000", &PortSet{PortRange{Min: 0, Max: 65535}}}, - {"*", &PortSet{PortRange{Min: 0, Max: 65535}}}, -} - -var permissionsTests = []struct { - in string - out *Permissions -}{ - {"", &Permissions{}}, - {"*:*:*", &Permissions{ - Permission{ - Actions: StringSet{"*"}, - Hosts: StringSet{"*"}, - Ports: PortSet{PortRange{Min: 0, Max: 65535}}, - }, - }}, - {"bind:127.0.0.1,localhost:80,443,8000-8100 connect:*.google.pl:80,443", &Permissions{ - Permission{ - Actions: StringSet{"bind"}, - Hosts: StringSet{"127.0.0.1", "localhost"}, - Ports: PortSet{ - PortRange{Min: 80, Max: 80}, - PortRange{Min: 443, Max: 443}, - PortRange{Min: 8000, Max: 8100}, - }, - }, - Permission{ - Actions: StringSet{"connect"}, - Hosts: StringSet{"*.google.pl"}, - Ports: PortSet{ - PortRange{Min: 80, Max: 80}, - PortRange{Min: 443, Max: 443}, - }, - }, - }}, -} - -func TestPortRangeParse(t *testing.T) { - for _, test := range portRangeTests { - actual, err := ParsePortRange(test.in) - if err != nil { - t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) - } else if *actual != *test.out { - t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestPortRangeContains(t *testing.T) { - actual, _ := ParsePortRange("5-10") - - if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { - t.Errorf("5-10 should contain 5, 7 and 10") - } - - if actual.Contains(4) || actual.Contains(11) { - t.Errorf("5-10 should not contain 4, 11") - } -} - -func TestStringSetParse(t *testing.T) { - for _, test := range stringSetTests { - actual, err := ParseStringSet(test.in) - if err != nil { - t.Errorf("ParseStringSet(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParseStringSet(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestStringSetContains(t *testing.T) { - ss, _ := ParseStringSet("google.pl,*.google.com") - - if !ss.Contains("google.pl") || !ss.Contains("www.google.com") { - t.Errorf("google.pl,*.google.com should contain google.pl and www.google.com") - } - - if ss.Contains("www.google.pl") || ss.Contains("foobar.com") { - t.Errorf("google.pl,*.google.com shound not contain www.google.pl and foobar.com") - } -} - -func TestPortSetParse(t *testing.T) { - for _, test := range portSetTests { - actual, err := ParsePortSet(test.in) - if err != nil { - t.Errorf("ParsePortRange(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParsePortRange(%q): got %v, want %v", test.in, actual, test.out) - } - } -} - -func TestPortSetContains(t *testing.T) { - actual, _ := ParsePortSet("5-10,20-30") - - if !actual.Contains(5) || !actual.Contains(7) || !actual.Contains(10) { - t.Errorf("5-10,20-30 should contain 5, 7 and 10") - } - - if !actual.Contains(20) || !actual.Contains(27) || !actual.Contains(30) { - t.Errorf("5-10,20-30 should contain 20, 27 and 30") - } - - if actual.Contains(4) || actual.Contains(11) || actual.Contains(31) { - t.Errorf("5-10,20-30 should not contain 4, 11, 31") - } -} - -func TestPermissionsParse(t *testing.T) { - for _, test := range permissionsTests { - actual, err := ParsePermissions(test.in) - if err != nil { - t.Errorf("ParsePermissions(%q) returned error: %v", test.in, err) - } else if fmt.Sprintln(actual) != fmt.Sprintln(test.out) { - t.Errorf("ParsePermissions(%q): got %v, want %v", test.in, actual, test.out) - } - } -} diff --git a/gost/redirect.go b/gost/redirect.go deleted file mode 100644 index 0ab59d10..00000000 --- a/gost/redirect.go +++ /dev/null @@ -1,241 +0,0 @@ -// +build linux - -package gost - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "syscall" - "time" - - "github.com/LiamHaworth/go-tproxy" - "github.com/go-log/log" -) - -type tcpRedirectHandler struct { - options *HandlerOptions -} - -// TCPRedirectHandler creates a server Handler for TCP transparent server. -func TCPRedirectHandler(opts ...HandlerOption) Handler { - h := &tcpRedirectHandler{} - h.Init(opts...) - - return h -} - -func (h *tcpRedirectHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } -} - -func (h *tcpRedirectHandler) Handle(c net.Conn) { - conn, ok := c.(*net.TCPConn) - if !ok { - log.Log("[red-tcp] not a TCP connection") - } - - srcAddr := conn.RemoteAddr() - dstAddr, conn, err := h.getOriginalDstAddr(conn) - if err != nil { - log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) - return - } - defer conn.Close() - - log.Logf("[red-tcp] %s -> %s", srcAddr, dstAddr) - - cc, err := h.options.Chain.DialContext(context.Background(), - "tcp", dstAddr.String(), - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) - if err != nil { - log.Logf("[red-tcp] %s -> %s : %s", srcAddr, dstAddr, err) - return - } - defer cc.Close() - - log.Logf("[red-tcp] %s <-> %s", srcAddr, dstAddr) - transport(conn, cc) - log.Logf("[red-tcp] %s >-< %s", srcAddr, dstAddr) -} - -func (h *tcpRedirectHandler) getOriginalDstAddr(conn *net.TCPConn) (addr net.Addr, c *net.TCPConn, err error) { - defer conn.Close() - - fc, err := conn.File() - if err != nil { - return - } - defer fc.Close() - - mreq, err := syscall.GetsockoptIPv6Mreq(int(fc.Fd()), syscall.IPPROTO_IP, 80) - if err != nil { - return - } - - // only ipv4 support - ip := net.IPv4(mreq.Multiaddr[4], mreq.Multiaddr[5], mreq.Multiaddr[6], mreq.Multiaddr[7]) - port := uint16(mreq.Multiaddr[2])<<8 + uint16(mreq.Multiaddr[3]) - addr, err = net.ResolveTCPAddr("tcp4", fmt.Sprintf("%s:%d", ip.String(), port)) - if err != nil { - return - } - - cc, err := net.FileConn(fc) - if err != nil { - return - } - - c, ok := cc.(*net.TCPConn) - if !ok { - err = errors.New("not a TCP connection") - } - return -} - -type udpRedirectHandler struct { - options *HandlerOptions -} - -// UDPRedirectHandler creates a server Handler for UDP transparent server. -func UDPRedirectHandler(opts ...HandlerOption) Handler { - h := &udpRedirectHandler{} - h.Init(opts...) - - return h -} - -func (h *udpRedirectHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } -} - -func (h *udpRedirectHandler) Handle(conn net.Conn) { - defer conn.Close() - - raddr, ok := conn.LocalAddr().(*net.UDPAddr) - if !ok { - log.Log("[red-udp] wrong connection type") - return - } - - cc, err := h.options.Chain.DialContext(context.Background(), - "udp", raddr.String(), - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) - if err != nil { - log.Logf("[red-udp] %s - %s : %s", conn.RemoteAddr(), raddr, err) - return - } - defer cc.Close() - - log.Logf("[red-udp] %s <-> %s", conn.RemoteAddr(), raddr) - transport(conn, cc) - log.Logf("[red-udp] %s >-< %s", conn.RemoteAddr(), raddr) -} - -type udpRedirectListener struct { - *net.UDPConn - config *UDPListenConfig -} - -// UDPRedirectListener creates a Listener for UDP transparent proxy server. -func UDPRedirectListener(addr string, cfg *UDPListenConfig) (Listener, error) { - laddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - - ln, err := tproxy.ListenUDP("udp", laddr) - if err != nil { - return nil, err - } - - if cfg == nil { - cfg = &UDPListenConfig{} - } - return &udpRedirectListener{ - UDPConn: ln, - config: cfg, - }, nil -} - -func (l *udpRedirectListener) Accept() (conn net.Conn, err error) { - b := make([]byte, mediumBufferSize) - - n, raddr, dstAddr, err := tproxy.ReadFromUDP(l.UDPConn, b) - if err != nil { - log.Logf("[red-udp] %s : %s", l.Addr(), err) - return - } - log.Logf("[red-udp] %s: %s -> %s", l.Addr(), raddr, dstAddr) - - c, err := tproxy.DialUDP("udp", dstAddr, raddr) - if err != nil { - log.Logf("[red-udp] %s -> %s : %s", raddr, dstAddr, err) - return - } - - ttl := l.config.TTL - if ttl <= 0 { - ttl = defaultTTL - } - - conn = &udpRedirectServerConn{ - Conn: c, - buf: b[:n], - ttl: ttl, - } - return -} - -func (l *udpRedirectListener) Addr() net.Addr { - return l.UDPConn.LocalAddr() -} - -type udpRedirectServerConn struct { - net.Conn - buf []byte - ttl time.Duration - once sync.Once -} - -func (c *udpRedirectServerConn) Read(b []byte) (n int, err error) { - if c.ttl > 0 { - c.SetReadDeadline(time.Now().Add(c.ttl)) - defer c.SetReadDeadline(time.Time{}) - } - c.once.Do(func() { - n = copy(b, c.buf) - c.buf = nil - }) - - if n == 0 { - n, err = c.Conn.Read(b) - } - return -} - -func (c *udpRedirectServerConn) Write(b []byte) (n int, err error) { - if c.ttl > 0 { - c.SetWriteDeadline(time.Now().Add(c.ttl)) - defer c.SetWriteDeadline(time.Time{}) - } - return c.Conn.Write(b) -} diff --git a/gost/redirect_other.go b/gost/redirect_other.go deleted file mode 100644 index d2a576a6..00000000 --- a/gost/redirect_other.go +++ /dev/null @@ -1,57 +0,0 @@ -// +build !linux - -package gost - -import ( - "errors" - "net" - - "github.com/go-log/log" -) - -type tcpRedirectHandler struct { - options *HandlerOptions -} - -// TCPRedirectHandler creates a server Handler for TCP redirect server. -func TCPRedirectHandler(opts ...HandlerOption) Handler { - h := &tcpRedirectHandler{ - options: &HandlerOptions{ - Chain: new(Chain), - }, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *tcpRedirectHandler) Init(options ...HandlerOption) { - log.Log("[red-tcp] TCP redirect is not available on the Windows platform") -} - -func (h *tcpRedirectHandler) Handle(c net.Conn) { - log.Log("[red-tcp] TCP redirect is not available on the Windows platform") - c.Close() -} - -type udpRedirectHandler struct { -} - -// UDPRedirectHandler creates a server Handler for UDP transparent server. -func UDPRedirectHandler(opts ...HandlerOption) Handler { - return &udpRedirectHandler{} -} - -func (h *udpRedirectHandler) Init(options ...HandlerOption) { -} - -func (h *udpRedirectHandler) Handle(conn net.Conn) { - log.Log("[red-udp] UDP redirect is not available on the Windows platform") - conn.Close() -} - -// UDPRedirectListener creates a Listener for UDP transparent proxy server. -func UDPRedirectListener(addr string, cfg *UDPListenConfig) (Listener, error) { - return nil, errors.New("UDP redirect is not available on the Windows platform") -} diff --git a/gost/relay.go b/gost/relay.go deleted file mode 100644 index c5d140bd..00000000 --- a/gost/relay.go +++ /dev/null @@ -1,369 +0,0 @@ -package gost - -import ( - "bytes" - "context" - "encoding/binary" - "errors" - "fmt" - "io" - "net" - "net/url" - "strconv" - "sync" - "time" - - "github.com/go-gost/relay" - "github.com/go-log/log" -) - -type relayConnector struct { - user *url.Userinfo - remoteAddr string -} - -// RelayConnector creates a Connector for TCP/UDP data relay. -func RelayConnector(user *url.Userinfo) Connector { - return &relayConnector{ - user: user, - } -} - -func (c *relayConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return conn, nil -} - -func (c *relayConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { - opts := &ConnectOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = ConnectTimeout - } - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - var udp bool - if network == "udp" || network == "udp4" || network == "udp6" { - udp = true - } - - req := &relay.Request{ - Version: relay.Version1, - } - if udp { - req.Flags |= relay.FUDP - } - - if c.user != nil { - pwd, _ := c.user.Password() - req.Features = append(req.Features, &relay.UserAuthFeature{ - Username: c.user.Username(), - Password: pwd, - }) - } - if address != "" { - host, port, _ := net.SplitHostPort(address) - nport, _ := strconv.ParseUint(port, 10, 16) - if host == "" { - host = net.IPv4zero.String() - } - - if nport > 0 { - var atype uint8 - ip := net.ParseIP(host) - if ip == nil { - atype = relay.AddrDomain - } else if ip.To4() == nil { - atype = relay.AddrIPv6 - } else { - atype = relay.AddrIPv4 - } - - req.Features = append(req.Features, &relay.TargetAddrFeature{ - AType: atype, - Host: host, - Port: uint16(nport), - }) - } - } - - rc := &relayConn{ - udp: udp, - Conn: conn, - } - - // write the header at once. - if opts.NoDelay { - if _, err := req.WriteTo(rc); err != nil { - return nil, err - } - } else { - if _, err := req.WriteTo(&rc.wbuf); err != nil { - return nil, err - } - } - - return rc, nil -} - -type relayHandler struct { - *baseForwardHandler -} - -// RelayHandler creates a server Handler for TCP/UDP relay server. -func RelayHandler(raddr string, opts ...HandlerOption) Handler { - h := &relayHandler{ - baseForwardHandler: &baseForwardHandler{ - raddr: raddr, - group: NewNodeGroup(), - options: &HandlerOptions{}, - }, - } - for _, opt := range opts { - opt(h.options) - } - return h -} - -func (h *relayHandler) Init(options ...HandlerOption) { - h.baseForwardHandler.Init(options...) -} - -func (h *relayHandler) Handle(conn net.Conn) { - defer conn.Close() - - req := &relay.Request{} - if _, err := req.ReadFrom(conn); err != nil { - log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - - if req.Version != relay.Version1 { - log.Logf("[relay] %s - %s : bad version", conn.RemoteAddr(), conn.LocalAddr()) - return - } - - var user, pass string - var raddr string - for _, f := range req.Features { - if f.Type() == relay.FeatureUserAuth { - feature := f.(*relay.UserAuthFeature) - user, pass = feature.Username, feature.Password - } - if f.Type() == relay.FeatureTargetAddr { - feature := f.(*relay.TargetAddrFeature) - raddr = net.JoinHostPort(feature.Host, strconv.Itoa(int(feature.Port))) - } - } - - resp := &relay.Response{ - Version: relay.Version1, - Status: relay.StatusOK, - } - if h.options.Authenticator != nil && !h.options.Authenticator.Authenticate(user, pass) { - resp.Status = relay.StatusUnauthorized - resp.WriteTo(conn) - log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) - return - } - - if raddr != "" { - if len(h.group.Nodes()) > 0 { - resp.Status = relay.StatusForbidden - resp.WriteTo(conn) - log.Logf("[relay] %s -> %s : relay to %s is forbidden", - conn.RemoteAddr(), conn.LocalAddr(), raddr) - return - } - } else { - if len(h.group.Nodes()) == 0 { - resp.Status = relay.StatusBadRequest - resp.WriteTo(conn) - log.Logf("[relay] %s -> %s : bad request, target addr is needed", - conn.RemoteAddr(), conn.LocalAddr()) - return - } - } - - udp := (req.Flags & relay.FUDP) == relay.FUDP - retries := 1 - if h.options.Chain != nil && h.options.Chain.Retries > 0 { - retries = h.options.Chain.Retries - } - if h.options.Retries > 0 { - retries = h.options.Retries - } - - network := "tcp" - if udp { - network = "udp" - } - if !Can(network, raddr, h.options.Whitelist, h.options.Blacklist) { - resp.Status = relay.StatusForbidden - resp.WriteTo(conn) - log.Logf("[relay] %s -> %s : relay to %s is forbidden", - conn.RemoteAddr(), conn.LocalAddr(), raddr) - return - } - - ctx := context.TODO() - var cc net.Conn - var node Node - var err error - for i := 0; i < retries; i++ { - if len(h.group.Nodes()) > 0 { - node, err = h.group.Next() - if err != nil { - resp.Status = relay.StatusServiceUnavailable - resp.WriteTo(conn) - log.Logf("[relay] %s - %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - raddr = node.Addr - } - - log.Logf("[relay] %s -> %s -> %s", conn.RemoteAddr(), conn.LocalAddr(), raddr) - cc, err = h.options.Chain.DialContext(ctx, - network, raddr, - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - ) - if err != nil { - log.Logf("[relay] %s -> %s : %s", conn.RemoteAddr(), raddr, err) - node.MarkDead() - } else { - break - } - } - if err != nil { - resp.Status = relay.StatusServiceUnavailable - resp.WriteTo(conn) - return - } - - node.ResetDead() - defer cc.Close() - - sc := &relayConn{ - Conn: conn, - isServer: true, - udp: udp, - } - resp.WriteTo(&sc.wbuf) - conn = sc - - log.Logf("[relay] %s <-> %s", conn.RemoteAddr(), raddr) - transport(conn, cc) - log.Logf("[relay] %s >-< %s", conn.RemoteAddr(), raddr) -} - -type relayConn struct { - net.Conn - isServer bool - udp bool - wbuf bytes.Buffer - once sync.Once - headerSent bool -} - -func (c *relayConn) Read(b []byte) (n int, err error) { - c.once.Do(func() { - if c.isServer { - return - } - resp := new(relay.Response) - _, err = resp.ReadFrom(c.Conn) - if err != nil { - return - } - if resp.Version != relay.Version1 { - err = relay.ErrBadVersion - return - } - if resp.Status != relay.StatusOK { - err = fmt.Errorf("status %d", resp.Status) - return - } - }) - - if err != nil { - log.Logf("[relay] %s <- %s: %s", c.Conn.LocalAddr(), c.Conn.RemoteAddr(), err) - return - } - - if !c.udp { - return c.Conn.Read(b) - } - var bb [2]byte - _, err = io.ReadFull(c.Conn, bb[:]) - if err != nil { - return - } - dlen := int(binary.BigEndian.Uint16(bb[:])) - if len(b) >= dlen { - return io.ReadFull(c.Conn, b[:dlen]) - } - buf := make([]byte, dlen) - _, err = io.ReadFull(c.Conn, buf) - n = copy(b, buf) - return -} - -func (c *relayConn) ReadFrom(b []byte) (n int, addr net.Addr, err error) { - n, err = c.Read(b) - addr = c.Conn.RemoteAddr() - return -} - -func (c *relayConn) Write(b []byte) (n int, err error) { - if len(b) > 0xFFFF { - err = errors.New("write: data maximum exceeded") - return - } - n = len(b) // force byte length consistent - if c.wbuf.Len() > 0 { - if c.udp { - var bb [2]byte - binary.BigEndian.PutUint16(bb[:2], uint16(len(b))) - c.wbuf.Write(bb[:]) - c.headerSent = true - } - c.wbuf.Write(b) // append the data to the cached header - // _, err = c.Conn.Write(c.wbuf.Bytes()) - // c.wbuf.Reset() - _, err = c.wbuf.WriteTo(c.Conn) - return - } - - if !c.udp { - return c.Conn.Write(b) - } - if !c.headerSent { - c.headerSent = true - b2 := make([]byte, len(b)+2) - copy(b2, b) - _, err = c.Conn.Write(b2) - return - } - nsize := 2 + len(b) - var buf []byte - if nsize <= mediumBufferSize { - buf = mPool.Get().([]byte) - defer mPool.Put(buf) - } else { - buf = make([]byte, nsize) - } - binary.BigEndian.PutUint16(buf[:2], uint16(len(b))) - n = copy(buf[2:], b) - _, err = c.Conn.Write(buf[:nsize]) - return -} - -func (c *relayConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - return c.Write(b) -} diff --git a/gost/reload.go b/gost/reload.go deleted file mode 100644 index 08d708a3..00000000 --- a/gost/reload.go +++ /dev/null @@ -1,65 +0,0 @@ -package gost - -import ( - "io" - "os" - "time" - - "github.com/go-log/log" -) - -// Reloader is the interface for objects that support live reloading. -type Reloader interface { - Reload(r io.Reader) error - Period() time.Duration -} - -// Stoppable is the interface that indicates a Reloader can be stopped. -type Stoppable interface { - Stop() - Stopped() bool -} - -// PeriodReload reloads the config configFile periodically according to the period of the Reloader r. -func PeriodReload(r Reloader, configFile string) error { - if r == nil || configFile == "" { - return nil - } - - var lastMod time.Time - for { - if r.Period() < 0 { - log.Log("[reload] stopped:", configFile) - return nil - } - - f, err := os.Open(configFile) - if err != nil { - return err - } - - mt := lastMod - if finfo, err := f.Stat(); err == nil { - mt = finfo.ModTime() - } - - if !lastMod.IsZero() && !mt.Equal(lastMod) { - log.Log("[reload]", configFile) - if err := r.Reload(f); err != nil { - log.Logf("[reload] %s: %s", configFile, err) - } - } - f.Close() - lastMod = mt - - period := r.Period() - if period == 0 { - log.Log("[reload] disabled:", configFile) - return nil - } - if period < time.Second { - period = time.Second - } - <-time.After(period) - } -} diff --git a/gost/resolver.go b/gost/resolver.go deleted file mode 100644 index 618d7249..00000000 --- a/gost/resolver.go +++ /dev/null @@ -1,914 +0,0 @@ -package gost - -import ( - "bufio" - "bytes" - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "io/ioutil" - "net" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/go-log/log" - "github.com/miekg/dns" -) - -var ( - // DefaultResolverTimeout is the default timeout for name resolution. - DefaultResolverTimeout = 5 * time.Second -) - -type nameServerOptions struct { - timeout time.Duration - chain *Chain -} - -// NameServerOption allows a common way to set name server options. -type NameServerOption func(*nameServerOptions) - -// TimeoutNameServerOption sets the timeout for name server. -func TimeoutNameServerOption(timeout time.Duration) NameServerOption { - return func(opts *nameServerOptions) { - opts.timeout = timeout - } -} - -// ChainNameServerOption sets the chain for name server. -func ChainNameServerOption(chain *Chain) NameServerOption { - return func(opts *nameServerOptions) { - opts.chain = chain - } -} - -// NameServer is a name server. -// Currently supported protocol: TCP, UDP and TLS. -type NameServer struct { - Addr string - Protocol string - Hostname string // for TLS handshake verification - exchanger Exchanger - options nameServerOptions -} - -// Init initializes the name server. -func (ns *NameServer) Init(opts ...NameServerOption) error { - for _, opt := range opts { - opt(&ns.options) - } - - options := []ExchangerOption{ - TimeoutExchangerOption(ns.options.timeout), - } - protocol := strings.ToLower(ns.Protocol) - switch protocol { - case "tcp", "tcp-chain": - if protocol == "tcp-chain" { - options = append(options, ChainExchangerOption(ns.options.chain)) - } - ns.exchanger = NewDNSTCPExchanger(ns.Addr, options...) - case "tls", "tls-chain": - if protocol == "tls-chain" { - options = append(options, ChainExchangerOption(ns.options.chain)) - } - cfg := &tls.Config{ - ServerName: ns.Hostname, - } - if cfg.ServerName == "" { - cfg.InsecureSkipVerify = true - } - ns.exchanger = NewDoTExchanger(ns.Addr, cfg, options...) - case "https", "https-chain": - if protocol == "https-chain" { - options = append(options, ChainExchangerOption(ns.options.chain)) - } - u, err := url.Parse(ns.Addr) - if err != nil { - return err - } - u.Scheme = "https" - cfg := &tls.Config{ServerName: ns.Hostname} - if cfg.ServerName == "" { - cfg.InsecureSkipVerify = true - } - ns.exchanger = NewDoHExchanger(u, cfg, options...) - case "udp", "udp-chain": - fallthrough - default: - if protocol == "udp-chain" { - options = append(options, ChainExchangerOption(ns.options.chain)) - } - ns.exchanger = NewDNSExchanger(ns.Addr, options...) - } - - return nil -} - -func (ns *NameServer) String() string { - addr := ns.Addr - prot := ns.Protocol - if prot == "" { - prot = "udp" - } - return fmt.Sprintf("%s/%s", addr, prot) -} - -type resolverOptions struct { - chain *Chain - timeout time.Duration - ttl time.Duration - prefer string - srcIP net.IP -} - -// ResolverOption allows a common way to set Resolver options. -type ResolverOption func(*resolverOptions) - -// ChainResolverOption sets the chain for Resolver. -func ChainResolverOption(chain *Chain) ResolverOption { - return func(opts *resolverOptions) { - opts.chain = chain - } -} - -// TimeoutResolverOption sets the timeout for Resolver. -func TimeoutResolverOption(timeout time.Duration) ResolverOption { - return func(opts *resolverOptions) { - opts.timeout = timeout - } -} - -// TTLResolverOption sets the timeout for Resolver. -func TTLResolverOption(ttl time.Duration) ResolverOption { - return func(opts *resolverOptions) { - opts.ttl = ttl - } -} - -// PreferResolverOption sets the prefer for Resolver. -func PreferResolverOption(prefer string) ResolverOption { - return func(opts *resolverOptions) { - opts.prefer = prefer - } -} - -// SrcIPResolverOption sets the source IP for Resolver. -func SrcIPResolverOption(ip net.IP) ResolverOption { - return func(opts *resolverOptions) { - opts.srcIP = ip - } -} - -// Resolver is a name resolver for domain name. -// It contains a list of name servers. -type Resolver interface { - // Init initializes the Resolver instance. - Init(opts ...ResolverOption) error - // Resolve returns a slice of that host's IPv4 and IPv6 addresses. - Resolve(host string) ([]net.IP, error) - // Exchange performs a synchronous query, - // It sends the message query and waits for a reply. - Exchange(ctx context.Context, query []byte) (reply []byte, err error) -} - -// ReloadResolver is resolover that support live reloading. -type ReloadResolver interface { - Resolver - Reloader - Stoppable -} - -type resolver struct { - servers []NameServer - ttl time.Duration - timeout time.Duration - period time.Duration - domain string - cache *resolverCache - stopped chan struct{} - mux sync.RWMutex - prefer string // ipv4 or ipv6 - srcIP net.IP // for edns0 subnet option - options resolverOptions -} - -// NewResolver create a new Resolver with the given name servers and resolution timeout. -func NewResolver(ttl time.Duration, servers ...NameServer) ReloadResolver { - r := newResolver(ttl, servers...) - return r -} - -func newResolver(ttl time.Duration, servers ...NameServer) *resolver { - return &resolver{ - servers: servers, - cache: newResolverCache(ttl), - stopped: make(chan struct{}), - } -} - -func (r *resolver) Init(opts ...ResolverOption) error { - if r == nil { - return nil - } - - r.mux.Lock() - defer r.mux.Unlock() - - for _, opt := range opts { - opt(&r.options) - } - - timeout := r.timeout - if r.options.timeout != 0 { - timeout = r.options.timeout - } - if timeout <= 0 { - timeout = DefaultResolverTimeout - } - - if r.options.ttl != 0 { - r.ttl = r.options.ttl - } - if r.options.prefer != "" { - r.prefer = r.options.prefer - } - if r.options.srcIP != nil { - r.srcIP = r.options.srcIP - } - - var nss []NameServer - for _, ns := range r.servers { - if err := ns.Init( // init all name servers - ChainNameServerOption(r.options.chain), - TimeoutNameServerOption(timeout), - ); err != nil { - continue // ignore invalid name servers - } - nss = append(nss, ns) - } - - r.servers = nss - - return nil -} - -func (r *resolver) copyServers() []NameServer { - r.mux.RLock() - defer r.mux.RUnlock() - - servers := make([]NameServer, len(r.servers)) - for i := range r.servers { - servers[i] = r.servers[i] - } - - return servers -} - -func (r *resolver) Resolve(host string) (ips []net.IP, err error) { - r.mux.RLock() - domain := r.domain - r.mux.RUnlock() - - if ip := net.ParseIP(host); ip != nil { - return []net.IP{ip}, nil - } - - if !strings.Contains(host, ".") && domain != "" { - host = host + "." + domain - } - - ctx := context.Background() - for _, ns := range r.copyServers() { - ips, err = r.resolve(ctx, ns.exchanger, host) - if err != nil { - log.Logf("[resolver] %s via %s : %s", host, ns.String(), err) - continue - } - - if Debug { - log.Logf("[resolver] %s via %s %v", host, ns.String(), ips) - } - if len(ips) > 0 { - break - } - } - - return -} - -func (r *resolver) resolve(ctx context.Context, ex Exchanger, host string) (ips []net.IP, err error) { - if ex == nil { - return - } - - r.mux.RLock() - prefer := r.prefer - r.mux.RUnlock() - - if prefer == "ipv6" { // prefer ipv6 - mq := &dns.Msg{} - mq.SetQuestion(dns.Fqdn(host), dns.TypeAAAA) - ips, err = r.resolveIPs(ctx, ex, mq) - if err != nil || len(ips) > 0 { - return - } - } - - mq := &dns.Msg{} - mq.SetQuestion(dns.Fqdn(host), dns.TypeA) - return r.resolveIPs(ctx, ex, mq) -} - -func (r *resolver) resolveIPs(ctx context.Context, ex Exchanger, mq *dns.Msg) (ips []net.IP, err error) { - key := newResolverCacheKey(&mq.Question[0]) - mr := r.cache.loadCache(key) - if mr == nil { - r.addSubnetOpt(mq) - mr, err = r.exchangeMsg(ctx, ex, mq) - if err != nil { - return - } - r.cache.storeCache(key, mr, r.TTL()) - } - - for _, ans := range mr.Answer { - if ar, _ := ans.(*dns.AAAA); ar != nil { - ips = append(ips, ar.AAAA) - } - if ar, _ := ans.(*dns.A); ar != nil { - ips = append(ips, ar.A) - } - } - - return -} - -func (r *resolver) addSubnetOpt(m *dns.Msg) { - if m == nil || r.srcIP == nil { - return - } - opt := new(dns.OPT) - opt.Hdr.Name = "." - opt.Hdr.Rrtype = dns.TypeOPT - e := new(dns.EDNS0_SUBNET) - e.Code = dns.EDNS0SUBNET - if ip := r.srcIP.To4(); ip != nil { - e.Family = 1 - e.SourceNetmask = 32 - e.Address = ip.To4() - } else { - e.Family = 2 - e.SourceNetmask = 128 - e.Address = r.srcIP - } - opt.Option = append(opt.Option, e) - m.Extra = append(m.Extra, opt) -} - -func (r *resolver) Exchange(ctx context.Context, query []byte) (reply []byte, err error) { - mq := &dns.Msg{} - if err = mq.Unpack(query); err != nil { - return - } - - if len(mq.Question) == 0 { - return nil, errors.New("empty question") - } - - var mr *dns.Msg - // Only cache for single question. - if len(mq.Question) == 1 { - key := newResolverCacheKey(&mq.Question[0]) - mr = r.cache.loadCache(key) - if mr != nil { - log.Logf("[dns] exchange message %d (cached): %s", mq.Id, mq.Question[0].String()) - mr.Id = mq.Id - return mr.Pack() - } - - defer func() { - if mr != nil { - r.cache.storeCache(key, mr, r.TTL()) - } - }() - } - - r.addSubnetOpt(mq) - - for _, ns := range r.copyServers() { - log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), mq.Question[0].String()) - mr, err = r.exchangeMsg(ctx, ns.exchanger, mq) - if err == nil { - break - } - log.Logf("[dns] exchange message %d via %s: %s", mq.Id, ns.String(), err) - } - if err != nil { - return - } - return mr.Pack() -} - -func (r *resolver) exchangeMsg(ctx context.Context, ex Exchanger, mq *dns.Msg) (mr *dns.Msg, err error) { - query, err := mq.Pack() - if err != nil { - return - } - reply, err := ex.Exchange(ctx, query) - if err != nil { - return - } - - mr = &dns.Msg{} - err = mr.Unpack(reply) - - return -} - -func (r *resolver) TTL() time.Duration { - r.mux.RLock() - defer r.mux.RUnlock() - return r.ttl -} - -func (r *resolver) Reload(rd io.Reader) error { - var ttl, timeout, period time.Duration - var domain, prefer string - var srcIP net.IP - var nss []NameServer - - if rd == nil || r.Stopped() { - return nil - } - - scanner := bufio.NewScanner(rd) - for scanner.Scan() { - line := scanner.Text() - ss := splitLine(line) - if len(ss) == 0 { - continue - } - - switch ss[0] { - case "timeout": // timeout option - if len(ss) > 1 { - timeout, _ = time.ParseDuration(ss[1]) - } - case "ttl": // ttl option - if len(ss) > 1 { - ttl, _ = time.ParseDuration(ss[1]) - } - case "reload": // reload option - if len(ss) > 1 { - period, _ = time.ParseDuration(ss[1]) - } - case "domain": - if len(ss) > 1 { - domain = ss[1] - } - case "search", "sortlist", "options": // we don't support these features in /etc/resolv.conf - case "prefer": - if len(ss) > 1 { - prefer = strings.ToLower(ss[1]) - } - case "ip": - if len(ss) > 1 { - srcIP = net.ParseIP(ss[1]) - } - case "nameserver": // nameserver option, compatible with /etc/resolv.conf - if len(ss) <= 1 { - break - } - ss = ss[1:] - fallthrough - default: - var ns NameServer - switch len(ss) { - case 0: - break - case 1: - ns.Addr = ss[0] - case 2: - ns.Addr = ss[0] - ns.Protocol = ss[1] - default: - ns.Addr = ss[0] - ns.Protocol = ss[1] - ns.Hostname = ss[2] - } - - if strings.HasPrefix(ns.Addr, "https") && ns.Protocol == "" { - ns.Protocol = "https" - } - nss = append(nss, ns) - } - } - - if err := scanner.Err(); err != nil { - return err - } - - r.mux.Lock() - r.ttl = ttl - r.timeout = timeout - r.domain = domain - r.period = period - r.prefer = prefer - r.srcIP = srcIP - r.servers = nss - r.mux.Unlock() - - r.Init() - - return nil -} - -func (r *resolver) Period() time.Duration { - if r.Stopped() { - return -1 - } - - r.mux.RLock() - defer r.mux.RUnlock() - - return r.period -} - -// Stop stops reloading. -func (r *resolver) Stop() { - select { - case <-r.stopped: - default: - close(r.stopped) - } -} - -// Stopped checks whether the reloader is stopped. -func (r *resolver) Stopped() bool { - select { - case <-r.stopped: - return true - default: - return false - } -} - -func (r *resolver) String() string { - if r == nil { - return "" - } - - r.mux.RLock() - defer r.mux.RUnlock() - - b := &bytes.Buffer{} - fmt.Fprintf(b, "TTL %v\n", r.ttl) - fmt.Fprintf(b, "Reload %v\n", r.period) - fmt.Fprintf(b, "Domain %v\n", r.domain) - for i := range r.servers { - fmt.Fprintln(b, r.servers[i]) - } - return b.String() -} - -type resolverCacheKey string - -// newResolverCacheKey generates resolver cache key from question of dns query. -func newResolverCacheKey(q *dns.Question) resolverCacheKey { - if q == nil { - return "" - } - key := fmt.Sprintf("%s%s.%s", q.Name, dns.Class(q.Qclass).String(), dns.Type(q.Qtype).String()) - return resolverCacheKey(key) -} - -type resolverCacheItem struct { - mr *dns.Msg - ts int64 - ttl time.Duration -} - -type resolverCache struct { - m sync.Map -} - -func newResolverCache(ttl time.Duration) *resolverCache { - return &resolverCache{} -} - -func (rc *resolverCache) loadCache(key resolverCacheKey) *dns.Msg { - v, ok := rc.m.Load(key) - if !ok { - return nil - } - - item, ok := v.(*resolverCacheItem) - if !ok { - return nil - } - - elapsed := time.Since(time.Unix(item.ts, 0)) - if item.ttl > 0 && elapsed > item.ttl { - rc.m.Delete(key) - return nil - } - for _, rr := range item.mr.Answer { - if elapsed > time.Duration(rr.Header().Ttl)*time.Second { - rc.m.Delete(key) - return nil - } - } - - if Debug { - log.Logf("[resolver] cache hit %s", key) - } - - return item.mr.Copy() -} - -func (rc *resolverCache) storeCache(key resolverCacheKey, mr *dns.Msg, ttl time.Duration) { - if key == "" || mr == nil || ttl < 0 { - return - } - - rc.m.Store(key, &resolverCacheItem{ - mr: mr.Copy(), - ts: time.Now().Unix(), - ttl: ttl, - }) - if Debug { - log.Logf("[resolver] cache store %s", key) - } -} - -// Exchanger is an interface for DNS synchronous query. -type Exchanger interface { - Exchange(ctx context.Context, query []byte) ([]byte, error) -} - -type exchangerOptions struct { - chain *Chain - timeout time.Duration -} - -// ExchangerOption allows a common way to set Exchanger options. -type ExchangerOption func(opts *exchangerOptions) - -// ChainExchangerOption sets the chain for Exchanger. -func ChainExchangerOption(chain *Chain) ExchangerOption { - return func(opts *exchangerOptions) { - opts.chain = chain - } -} - -// TimeoutExchangerOption sets the timeout for Exchanger. -func TimeoutExchangerOption(timeout time.Duration) ExchangerOption { - return func(opts *exchangerOptions) { - opts.timeout = timeout - } -} - -type dnsExchanger struct { - addr string - options exchangerOptions -} - -// NewDNSExchanger creates a DNS over UDP Exchanger -func NewDNSExchanger(addr string, opts ...ExchangerOption) Exchanger { - var options exchangerOptions - for _, opt := range opts { - opt(&options) - } - - if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "53") - } - - return &dnsExchanger{ - addr: addr, - options: options, - } -} - -func (ex *dnsExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { - t := time.Now() - c, err := ex.options.chain.DialContext(ctx, - "udp", ex.addr, - TimeoutChainOption(ex.options.timeout), - ) - if err != nil { - return nil, err - } - c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) - defer c.Close() - - conn := &dns.Conn{ - Conn: c, - } - if _, err = conn.Write(query); err != nil { - return nil, err - } - - mr, err := conn.ReadMsg() - if err != nil { - return nil, err - } - - return mr.Pack() -} - -type dnsTCPExchanger struct { - addr string - options exchangerOptions -} - -// NewDNSTCPExchanger creates a DNS over TCP Exchanger -func NewDNSTCPExchanger(addr string, opts ...ExchangerOption) Exchanger { - var options exchangerOptions - for _, opt := range opts { - opt(&options) - } - - if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "53") - } - - return &dnsTCPExchanger{ - addr: addr, - options: options, - } -} - -func (ex *dnsTCPExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { - t := time.Now() - c, err := ex.options.chain.DialContext(ctx, - "tcp", ex.addr, - TimeoutChainOption(ex.options.timeout), - ) - if err != nil { - return nil, err - } - c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) - defer c.Close() - - conn := &dns.Conn{ - Conn: c, - } - if _, err = conn.Write(query); err != nil { - return nil, err - } - - mr, err := conn.ReadMsg() - if err != nil { - return nil, err - } - - return mr.Pack() -} - -type dotExchanger struct { - addr string - tlsConfig *tls.Config - options exchangerOptions -} - -// NewDoTExchanger creates a DNS over TLS Exchanger -func NewDoTExchanger(addr string, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { - var options exchangerOptions - for _, opt := range opts { - opt(&options) - } - - if _, port, _ := net.SplitHostPort(addr); port == "" { - addr = net.JoinHostPort(addr, "53") - } - - if tlsConfig == nil { - tlsConfig = &tls.Config{ - InsecureSkipVerify: true, - } - } - return &dotExchanger{ - addr: addr, - tlsConfig: tlsConfig, - options: options, - } -} - -func (ex *dotExchanger) dial(ctx context.Context, network, address string) (conn net.Conn, err error) { - conn, err = ex.options.chain.DialContext(ctx, - network, address, - TimeoutChainOption(ex.options.timeout), - ) - if err != nil { - return - } - conn = tls.Client(conn, ex.tlsConfig) - - return -} - -func (ex *dotExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { - t := time.Now() - c, err := ex.dial(ctx, "tcp", ex.addr) - if err != nil { - return nil, err - } - c.SetDeadline(time.Now().Add(ex.options.timeout - time.Since(t))) - defer c.Close() - - conn := &dns.Conn{ - Conn: c, - } - if _, err = conn.Write(query); err != nil { - return nil, err - } - - mr, err := conn.ReadMsg() - if err != nil { - return nil, err - } - - return mr.Pack() -} - -type dohExchanger struct { - endpoint *url.URL - client *http.Client - options exchangerOptions -} - -// NewDoHExchanger creates a DNS over HTTPS Exchanger -func NewDoHExchanger(urlStr *url.URL, tlsConfig *tls.Config, opts ...ExchangerOption) Exchanger { - var options exchangerOptions - for _, opt := range opts { - opt(&options) - } - ex := &dohExchanger{ - endpoint: urlStr, - options: options, - } - - ex.client = &http.Client{ - Timeout: options.timeout, - Transport: &http.Transport{ - // Proxy: ProxyFromEnvironment, - TLSClientConfig: tlsConfig, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: options.timeout, - ExpectContinueTimeout: 1 * time.Second, - DialContext: ex.dialContext, - }, - } - - return ex -} - -func (ex *dohExchanger) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - return ex.options.chain.DialContext(ctx, - network, address, - TimeoutChainOption(ex.options.timeout), - ) -} - -func (ex *dohExchanger) Exchange(ctx context.Context, query []byte) ([]byte, error) { - req, err := http.NewRequestWithContext(ctx, "POST", ex.endpoint.String(), bytes.NewBuffer(query)) - if err != nil { - return nil, fmt.Errorf("failed to create an HTTPS request: %s", err) - } - - // req.Header.Add("Content-Type", "application/dns-udpwireformat") - req.Header.Add("Content-Type", "application/dns-message") - req.Host = ex.endpoint.Hostname() - - client := ex.client - if client == nil { - client = http.DefaultClient - } - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to perform an HTTPS request: %s", err) - } - - // Check response status code - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("returned status code %d", resp.StatusCode) - } - - // Read wireformat response from the body - buf, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("failed to read the response body: %s", err) - } - - return buf, nil -} diff --git a/gost/resolver_test.go b/gost/resolver_test.go deleted file mode 100644 index 79732ea9..00000000 --- a/gost/resolver_test.go +++ /dev/null @@ -1,270 +0,0 @@ -package gost - -import ( - "bytes" - "fmt" - "io" - "net" - "testing" - "time" -) - -var dnsTests = []struct { - ns NameServer - host string - pass bool -}{ - {NameServer{Addr: "1.1.1.1"}, "192.168.1.1", true}, - {NameServer{Addr: "1.1.1.1"}, "github", true}, - {NameServer{Addr: "1.1.1.1"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:53"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:53", Protocol: "tcp"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:853", Protocol: "tls"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "example.com"}, "github.com", false}, - {NameServer{Addr: "1.1.1.1:853", Protocol: "tls", Hostname: "cloudflare-dns.com"}, "github.com", true}, - {NameServer{Addr: "https://cloudflare-dns.com/dns-query", Protocol: "https"}, "github.com", true}, - {NameServer{Addr: "https://1.0.0.1/dns-query", Protocol: "https"}, "github.com", true}, - {NameServer{Addr: "1.1.1.1:12345"}, "github.com", false}, - {NameServer{Addr: "1.1.1.1:12345", Protocol: "tcp"}, "github.com", false}, - {NameServer{Addr: "1.1.1.1:12345", Protocol: "tls"}, "github.com", false}, - {NameServer{Addr: "https://1.0.0.1:12345/dns-query", Protocol: "https"}, "github.com", false}, -} - -func dnsResolverRoundtrip(t *testing.T, r Resolver, host string) error { - ips, err := r.Resolve(host) - t.Log(host, ips, err) - if err != nil { - return err - } - - return nil -} - -func TestDNSResolver(t *testing.T) { - for i, tc := range dnsTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - ns := tc.ns - t.Log(ns) - r := NewResolver(0, ns) - resolv := r.(*resolver) - resolv.domain = "com" - if err := r.Init(); err != nil { - t.Error("got error:", err) - } - err := dnsResolverRoundtrip(t, r, tc.host) - if err != nil { - if tc.pass { - t.Error("got error:", err) - } - } else { - if !tc.pass { - t.Error("should failed") - } - } - }) - } -} - -var resolverCacheTests = []struct { - name string - ips []net.IP - ttl time.Duration - result []net.IP -}{ - {"", nil, 0, nil}, - {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, - {"", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, nil}, - {"example.com", nil, 10 * time.Second, nil}, - {"example.com", []net.IP{}, 10 * time.Second, nil}, - {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 0, nil}, - {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, -1, nil}, - {"example.com", []net.IP{net.IPv4(192, 168, 1, 1)}, 10 * time.Second, - []net.IP{net.IPv4(192, 168, 1, 1)}}, - {"example.com", []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}, 10 * time.Second, - []net.IP{net.IPv4(192, 168, 1, 1), net.IPv4(192, 168, 1, 2)}}, -} - -/* -func TestResolverCache(t *testing.T) { - isEqual := func(a, b []net.IP) bool { - if a == nil && b == nil { - return true - } - - if a == nil || b == nil || len(a) != len(b) { - return false - } - - for i := range a { - if !a[i].Equal(b[i]) { - return false - } - } - return true - } - for i, tc := range resolverCacheTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - r := newResolver(tc.ttl) - r.cache.storeCache(tc.name, tc.ips, tc.ttl) - ips := r.cache.loadCache(tc.name, tc.ttl) - - if !isEqual(tc.result, ips) { - t.Error("unexpected cache value:", tc.name, ips, tc.ttl) - } - }) - } -} -*/ - -var resolverReloadTests = []struct { - r io.Reader - - timeout time.Duration - ttl time.Duration - domain string - period time.Duration - ns *NameServer - - stopped bool -}{ - { - r: nil, - }, - { - r: bytes.NewBufferString(""), - }, - { - r: bytes.NewBufferString("reload 10s"), - period: 10 * time.Second, - }, - { - r: bytes.NewBufferString("timeout 10s\nreload 10s\n"), - timeout: 10 * time.Second, - period: 10 * time.Second, - }, - { - r: bytes.NewBufferString("ttl 10s\ntimeout 10s\nreload 10s\n"), - timeout: 10 * time.Second, - period: 10 * time.Second, - ttl: 10 * time.Second, - }, - { - r: bytes.NewBufferString("domain example.com\nttl 10s\ntimeout 10s\nreload 10s\n"), - timeout: 10 * time.Second, - period: 10 * time.Second, - ttl: 10 * time.Second, - domain: "example.com", - }, - { - r: bytes.NewBufferString("1.1.1.1"), - ns: &NameServer{ - Addr: "1.1.1.1", - }, - stopped: true, - }, - { - r: bytes.NewBufferString("\n# comment\ntimeout 10s\nsearch\nnameserver \nnameserver 1.1.1.1 udp"), - ns: &NameServer{ - Protocol: "udp", - Addr: "1.1.1.1", - }, - timeout: 10 * time.Second, - stopped: true, - }, - { - r: bytes.NewBufferString("1.1.1.1 tcp"), - ns: &NameServer{ - Addr: "1.1.1.1", - Protocol: "tcp", - }, - stopped: true, - }, - { - r: bytes.NewBufferString("1.1.1.1:853 tls cloudflare-dns.com"), - ns: &NameServer{ - Addr: "1.1.1.1:853", - Protocol: "tls", - Hostname: "cloudflare-dns.com", - }, - stopped: true, - }, - { - r: bytes.NewBufferString("1.1.1.1:853 tls"), - ns: &NameServer{ - Addr: "1.1.1.1:853", - Protocol: "tls", - }, - stopped: true, - }, - { - r: bytes.NewBufferString("1.0.0.1:53 https"), - stopped: true, - }, - { - r: bytes.NewBufferString("https://1.0.0.1/dns-query"), - ns: &NameServer{ - Addr: "https://1.0.0.1/dns-query", - Protocol: "https", - }, - stopped: true, - }, -} - -func TestResolverReload(t *testing.T) { - for i, tc := range resolverReloadTests { - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - r := newResolver(0) - if err := r.Reload(tc.r); err != nil { - t.Error(err) - } - t.Log(r.String()) - if r.TTL() != tc.ttl { - t.Errorf("ttl value should be %v, got %v", - tc.ttl, r.TTL()) - } - if r.Period() != tc.period { - t.Errorf("period value should be %v, got %v", - tc.period, r.period) - } - if r.domain != tc.domain { - t.Errorf("domain value should be %v, got %v", - tc.domain, r.domain) - } - - var ns *NameServer - if len(r.servers) > 0 { - ns = &r.servers[0] - } - - if !compareNameServer(ns, tc.ns) { - t.Errorf("nameserver not equal, should be %v, got %v", - tc.ns, r.servers) - } - - if tc.stopped { - r.Stop() - if r.Period() >= 0 { - t.Errorf("period of the stopped reloader should be minus value") - } - } - if r.Stopped() != tc.stopped { - t.Errorf("stopped value should be %v, got %v", - tc.stopped, r.Stopped()) - } - }) - } -} - -func compareNameServer(n1, n2 *NameServer) bool { - if n1 == n2 { - return true - } - if n1 == nil || n2 == nil { - return false - } - return n1.Addr == n2.Addr && - n1.Hostname == n2.Hostname && - n1.Protocol == n2.Protocol -} diff --git a/gost/server.go b/gost/server.go index dd8d5565..b130382a 100644 --- a/gost/server.go +++ b/gost/server.go @@ -1,6 +1,7 @@ package gost import ( + "fmt" "io" "net" "time" @@ -56,7 +57,8 @@ func (s *Server) Serve(h Handler, opts ...ServerOption) error { h = s.Handler } if h == nil { - h = HTTPHandler() + fmt.Println("handler is nil =====================================================") + //h = HTTPHandler() } l := s.Listener diff --git a/gost/signal.go b/gost/signal.go deleted file mode 100644 index f12e9023..00000000 --- a/gost/signal.go +++ /dev/null @@ -1,5 +0,0 @@ -// +build windows - -package gost - -func kcpSigHandler() {} diff --git a/gost/signal_unix.go b/gost/signal_unix.go deleted file mode 100644 index f1e91404..00000000 --- a/gost/signal_unix.go +++ /dev/null @@ -1,24 +0,0 @@ -// +build !windows - -package gost - -import ( - "os" - "os/signal" - "syscall" - - "github.com/go-log/log" - "github.com/xtaci/kcp-go" -) - -func kcpSigHandler() { - ch := make(chan os.Signal, 1) - signal.Notify(ch, syscall.SIGUSR1) - - for { - switch <-ch { - case syscall.SIGUSR1: - log.Logf("[kcp] SNMP: %+v", kcp.DefaultSnmp.Copy()) - } - } -} diff --git a/gost/sni.go b/gost/sni.go deleted file mode 100644 index e57d4a0f..00000000 --- a/gost/sni.go +++ /dev/null @@ -1,350 +0,0 @@ -// SNI proxy based on https://github.com/bradfitz/tcpproxy - -package gost - -import ( - "bufio" - "bytes" - "context" - "encoding/base64" - "encoding/binary" - "errors" - "fmt" - "hash/crc32" - "io" - "net" - "net/http" - "strings" - "sync" - - "github.com/asaskevich/govalidator" - dissector "github.com/ginuerzh/tls-dissector" - "github.com/go-log/log" -) - -type sniConnector struct { - host string -} - -// SNIConnector creates a Connector for SNI proxy client. -func SNIConnector(host string) Connector { - return &sniConnector{host: host} -} - -func (c *sniConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return c.ConnectContext(context.Background(), conn, "tcp", address, options...) -} - -func (c *sniConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - return nil, fmt.Errorf("%s unsupported", network) - } - - return &sniClientConn{addr: address, host: c.host, Conn: conn}, nil -} - -type sniHandler struct { - options *HandlerOptions -} - -// SNIHandler creates a server Handler for SNI proxy server. -func SNIHandler(opts ...HandlerOption) Handler { - h := &sniHandler{} - h.Init(opts...) - - return h -} - -func (h *sniHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } -} - -func (h *sniHandler) Handle(conn net.Conn) { - defer conn.Close() - - br := bufio.NewReader(conn) - hdr, err := br.Peek(dissector.RecordHeaderLen) - if err != nil { - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - conn = &bufferdConn{br: br, Conn: conn} - - if hdr[0] != dissector.Handshake { - // We assume it is an HTTP request - req, err := http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - - if !req.URL.IsAbs() && govalidator.IsDNSName(req.Host) { - req.URL.Scheme = "http" - } - - handler := &httpHandler{options: h.options} - handler.Init() - handler.handleRequest(conn, req) - return - } - - b, host, err := readClientHelloRecord(conn, "", false) - if err != nil { - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - return - } - - _, sport, _ := net.SplitHostPort(h.options.Host) - if sport == "" { - sport = "443" - } - host = net.JoinHostPort(host, sport) - - log.Logf("[sni] %s -> %s -> %s", - conn.RemoteAddr(), h.options.Node.String(), host) - - if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[sni] %s -> %s : Unauthorized to tcp connect to %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - return - } - if h.options.Bypass.Contains(host) { - log.Log("[sni] %s - %s bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - return - } - - retries := 1 - if h.options.Chain != nil && h.options.Chain.Retries > 0 { - retries = h.options.Chain.Retries - } - if h.options.Retries > 0 { - retries = h.options.Retries - } - - var cc net.Conn - var route *Chain - for i := 0; i < retries; i++ { - route, err = h.options.Chain.selectRouteFor(host) - if err != nil { - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - continue - } - - buf := bytes.Buffer{} - fmt.Fprintf(&buf, "%s -> %s -> ", - conn.RemoteAddr(), h.options.Node.String()) - for _, nd := range route.route { - fmt.Fprintf(&buf, "%d@%s -> ", nd.ID, nd.String()) - } - fmt.Fprintf(&buf, "%s", host) - log.Log("[route]", buf.String()) - - cc, err = route.Dial(host, - TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), - ) - if err == nil { - break - } - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - } - - if err != nil { - return - } - defer cc.Close() - - if _, err := cc.Write(b); err != nil { - log.Logf("[sni] %s -> %s : %s", - conn.RemoteAddr(), conn.LocalAddr(), err) - } - - log.Logf("[sni] %s <-> %s", cc.LocalAddr(), host) - transport(conn, cc) - log.Logf("[sni] %s >-< %s", cc.LocalAddr(), host) -} - -// sniSniffConn is a net.Conn that reads from r, fails on Writes, -// and crashes otherwise. -type sniSniffConn struct { - r io.Reader - net.Conn // nil; crash on any unexpected use -} - -func (c sniSniffConn) Read(p []byte) (int, error) { return c.r.Read(p) } -func (sniSniffConn) Write(p []byte) (int, error) { return 0, io.EOF } - -type sniClientConn struct { - addr string - host string - mutex sync.Mutex - obfuscated bool - net.Conn -} - -func (c *sniClientConn) Write(p []byte) (int, error) { - b, err := c.obfuscate(p) - if err != nil { - return 0, err - } - if _, err = c.Conn.Write(b); err != nil { - return 0, err - } - return len(p), nil -} - -func (c *sniClientConn) obfuscate(p []byte) ([]byte, error) { - if c.host == "" { - return p, nil - } - c.mutex.Lock() - defer c.mutex.Unlock() - - if c.obfuscated { - return p, nil - } - - if p[0] == dissector.Handshake { - b, host, err := readClientHelloRecord(bytes.NewReader(p), c.host, true) - if err != nil { - return nil, err - } - if Debug { - log.Logf("[sni] obfuscate: %s -> %s", c.addr, host) - } - c.obfuscated = true - return b, nil - } - - buf := &bytes.Buffer{} - br := bufio.NewReader(bytes.NewReader(p)) - for { - s, err := br.ReadString('\n') - if err != nil { - if err != io.EOF { - return nil, err - } - if s != "" { - buf.Write([]byte(s)) - } - break - } - - // end of HTTP header - if s == "\r\n" { - buf.Write([]byte(s)) - // drain the remain bytes. - io.Copy(buf, br) - break - } - - if strings.HasPrefix(s, "Host") { - s = strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(s, "Host:"), "\r\n")) - host := encodeServerName(s) - if Debug { - log.Logf("[sni] obfuscate: %s -> %s", s, c.host) - } - buf.WriteString("Host: " + c.host + "\r\n") - buf.WriteString("Gost-Target: " + host + "\r\n") - // drain the remain bytes. - io.Copy(buf, br) - break - } - buf.Write([]byte(s)) - } - c.obfuscated = true - return buf.Bytes(), nil -} - -func readClientHelloRecord(r io.Reader, host string, isClient bool) ([]byte, string, error) { - record, err := dissector.ReadRecord(r) - if err != nil { - return nil, "", err - } - clientHello := &dissector.ClientHelloMsg{} - if err := clientHello.Decode(record.Opaque); err != nil { - return nil, "", err - } - - if !isClient { - var extensions []dissector.Extension - - for _, ext := range clientHello.Extensions { - if ext.Type() == 0xFFFE { - b, _ := ext.Encode() - if host, err = decodeServerName(string(b)); err == nil { - continue - } - } - extensions = append(extensions, ext) - } - clientHello.Extensions = extensions - } - - for _, ext := range clientHello.Extensions { - if ext.Type() == dissector.ExtServerName { - snExtension := ext.(*dissector.ServerNameExtension) - if host == "" { - host = snExtension.Name - } - if isClient { - e, _ := dissector.NewExtension(0xFFFE, []byte(encodeServerName(snExtension.Name))) - clientHello.Extensions = append(clientHello.Extensions, e) - } - if host != "" { - snExtension.Name = host - } - break - } - } - record.Opaque, err = clientHello.Encode() - if err != nil { - return nil, "", err - } - - buf := &bytes.Buffer{} - if _, err := record.WriteTo(buf); err != nil { - return nil, "", err - } - - return buf.Bytes(), host, nil -} - -func encodeServerName(name string) string { - buf := &bytes.Buffer{} - binary.Write(buf, binary.BigEndian, crc32.ChecksumIEEE([]byte(name))) - buf.WriteString(base64.RawURLEncoding.EncodeToString([]byte(name))) - return base64.RawURLEncoding.EncodeToString(buf.Bytes()) -} - -func decodeServerName(s string) (string, error) { - b, err := base64.RawURLEncoding.DecodeString(s) - if err != nil { - return "", err - } - if len(b) < 4 { - return "", errors.New("invalid name") - } - v, err := base64.RawURLEncoding.DecodeString(string(b[4:])) - if err != nil { - return "", err - } - if crc32.ChecksumIEEE(v) != binary.BigEndian.Uint32(b[:4]) { - return "", errors.New("invalid name") - } - return string(v), nil -} diff --git a/gost/sni_test.go b/gost/sni_test.go deleted file mode 100644 index 0bfa9941..00000000 --- a/gost/sni_test.go +++ /dev/null @@ -1,148 +0,0 @@ -package gost - -import ( - "bufio" - "bytes" - "crypto/rand" - "crypto/tls" - "errors" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "net/url" - "testing" - "time" -) - -func sniRoundtrip(client *Client, server *Server, targetURL string, data []byte) (err error) { - conn, err := client.Dial(server.Addr().String()) - if err != nil { - return - } - - conn, err = client.Handshake(conn, AddrHandshakeOption(server.Addr().String())) - if err != nil { - return - } - defer conn.Close() - - u, err := url.Parse(targetURL) - if err != nil { - return - } - - conn.SetDeadline(time.Now().Add(3 * time.Second)) - defer conn.SetDeadline(time.Time{}) - - conn, err = client.Connect(conn, u.Host) - if err != nil { - return - } - - if u.Scheme == "https" { - conn = tls.Client(conn, - &tls.Config{ - InsecureSkipVerify: true, - // ServerName: u.Hostname(), - }) - u.Scheme = "http" - } - req, err := http.NewRequest( - http.MethodGet, - u.String(), - bytes.NewReader(data), - ) - if err != nil { - return - } - if err = req.WriteProxy(conn); err != nil { - return - } - resp, err := http.ReadResponse(bufio.NewReader(conn), req) - if err != nil { - return - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errors.New(resp.Status) - } - - recv, err := ioutil.ReadAll(resp.Body) - if err != nil { - return - } - - if !bytes.Equal(data, recv) { - return fmt.Errorf("data not equal") - } - - return -} - -func sniProxyRoundtrip(targetURL string, data []byte, host string) error { - ln, err := TCPListener("") - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: TCPTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIProxy(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniProxyRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} diff --git a/gost/socks.go b/gost/socks.go index 75682052..9e5ca154 100644 --- a/gost/socks.go +++ b/gost/socks.go @@ -888,29 +888,6 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { log.Logf("[socks5] %s -> %s -> %s", conn.RemoteAddr(), h.options.Node.String(), host) - if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5] %s - %s : Unauthorized to tcp connect to %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), rep) - } - return - } - if h.options.Bypass.Contains(host) { - log.Logf("[socks5] %s - %s : Bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), rep) - } - return - } - retries := 1 if h.options.Chain != nil && h.options.Chain.Retries > 0 { retries = h.options.Chain.Retries @@ -941,8 +918,6 @@ func (h *socks5Handler) handleConnect(conn net.Conn, req *gosocks5.Request) { cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), ) if err == nil { break @@ -984,11 +959,6 @@ func (h *socks5Handler) handleBind(conn net.Conn, req *gosocks5.Request) { conn.RemoteAddr(), h.options.Node.String(), addr) if h.options.Chain.IsEmpty() { - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-bind] %s - %s : Unauthorized to tcp bind to %s", - conn.RemoteAddr(), conn.LocalAddr(), addr) - return - } h.bindOn(conn, addr) return } @@ -1113,16 +1083,7 @@ func (h *socks5Handler) bindOn(conn net.Conn, addr string) { } func (h *socks5Handler) handleUDPRelay(conn net.Conn, req *gosocks5.Request) { - addr := req.Addr.String() - if !Can("udp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5-udp] Unauthorized to udp connect to %s", addr) - rep := gosocks5.NewReply(gosocks5.NotAllowed, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks5-udp] %s <- %s\n%s", conn.RemoteAddr(), req.Addr, rep) - } - return - } + //addr := req.Addr.String() relay, err := net.ListenUDP("udp", nil) if err != nil { @@ -1260,10 +1221,6 @@ func (h *socks5Handler) transportUDP(relay, peer net.PacketConn) (err error) { if err != nil { continue // drop silently } - if h.options.Bypass.Contains(raddr.String()) { - log.Log("[socks5-udp] [bypass] write to", raddr) - continue // bypass - } if _, err := peer.WriteTo(dgram.Data, raddr); err != nil { errc <- err return @@ -1287,10 +1244,6 @@ func (h *socks5Handler) transportUDP(relay, peer net.PacketConn) (err error) { if clientAddr == nil { continue } - if h.options.Bypass.Contains(raddr.String()) { - log.Log("[socks5-udp] [bypass] read from", raddr) - continue // bypass - } buf := bytes.Buffer{} dgram := gosocks5.NewUDPDatagram(gosocks5.NewUDPHeader(0, 0, toSocksAddr(raddr)), b[:n]) dgram.Write(&buf) @@ -1339,11 +1292,7 @@ func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error if clientAddr == nil { clientAddr = addr } - raddr := dgram.Header.Addr.String() - if h.options.Bypass.Contains(raddr) { - log.Log("[udp-tun] [bypass] write to", raddr) - continue // bypass - } + //raddr := dgram.Header.Addr.String() dgram.Header.Rsv = uint16(len(dgram.Data)) if err := dgram.Write(cc); err != nil { errc <- err @@ -1368,11 +1317,7 @@ func (h *socks5Handler) tunnelClientUDP(uc *net.UDPConn, cc net.Conn) (err error if clientAddr == nil { continue } - raddr := dgram.Header.Addr.String() - if h.options.Bypass.Contains(raddr) { - log.Log("[udp-tun] [bypass] read from", raddr) - continue // bypass - } + //raddr := dgram.Header.Addr.String() dgram.Header.Rsv = 0 buf := bytes.Buffer{} @@ -1399,11 +1344,6 @@ func (h *socks5Handler) handleUDPTunnel(conn net.Conn, req *gosocks5.Request) { if h.options.Chain.IsEmpty() { addr := req.Addr.String() - if !Can("rudp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks5] udp-tun Unauthorized to udp bind to %s", addr) - return - } - bindAddr, _ := net.ResolveUDPAddr("udp", addr) uc, err := net.ListenUDP("udp", bindAddr) if err != nil { @@ -1468,10 +1408,6 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err errc <- err return } - if h.options.Bypass.Contains(addr.String()) { - log.Log("[socks5] udp-tun bypass read from", addr) - continue // bypass - } // pipe from peer to tunnel dgram := gosocks5.NewUDPDatagram( @@ -1501,10 +1437,6 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err if err != nil { continue // drop silently } - if h.options.Bypass.Contains(addr.String()) { - log.Log("[socks5] udp-tun bypass write to", addr) - continue // bypass - } if _, err := pc.WriteTo(dgram.Data, addr); err != nil { log.Logf("[socks5] udp-tun %s -> %s : %s", cc.RemoteAddr(), addr, err) errc <- err @@ -1526,10 +1458,6 @@ func (h *socks5Handler) tunnelServerUDP(cc net.Conn, pc net.PacketConn) (err err func (h *socks5Handler) handleMuxBind(conn net.Conn, req *gosocks5.Request) { if h.options.Chain.IsEmpty() { addr := req.Addr.String() - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("Unauthorized to tcp mbind to %s", addr) - return - } h.muxBindOn(conn, addr) return } @@ -1703,29 +1631,6 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { log.Logf("[socks4] %s -> %s -> %s", conn.RemoteAddr(), h.options.Node.String(), addr) - if !Can("tcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[socks4] %s - %s : Unauthorized to tcp connect to %s", - conn.RemoteAddr(), conn.LocalAddr(), addr) - rep := gosocks4.NewReply(gosocks4.Rejected, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks4] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), rep) - } - return - } - if h.options.Bypass.Contains(addr) { - log.Log("[socks4] %s - %s : Bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), addr) - rep := gosocks4.NewReply(gosocks4.Rejected, nil) - rep.Write(conn) - if Debug { - log.Logf("[socks4] %s <- %s\n%s", - conn.RemoteAddr(), conn.LocalAddr(), rep) - } - return - } - retries := 1 if h.options.Chain != nil && h.options.Chain.Retries > 0 { retries = h.options.Chain.Retries @@ -1756,8 +1661,6 @@ func (h *socks4Handler) handleConnect(conn net.Conn, req *gosocks4.Request) { cc, err = route.Dial(addr, TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), ) if err == nil { break diff --git a/gost/ss.go b/gost/ss.go index e3532fdf..c3c16d42 100644 --- a/gost/ss.go +++ b/gost/ss.go @@ -142,18 +142,6 @@ func (h *shadowHandler) Handle(conn net.Conn) { log.Logf("[ss] %s -> %s", conn.RemoteAddr(), host) - if !Can("tcp", host, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ss] %s - %s : Unauthorized to tcp connect to %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - return - } - - if h.options.Bypass.Contains(host) { - log.Logf("[ss] %s - %s : Bypass %s", - conn.RemoteAddr(), conn.LocalAddr(), host) - return - } - retries := 1 if h.options.Chain != nil && h.options.Chain.Retries > 0 { retries = h.options.Chain.Retries @@ -183,8 +171,6 @@ func (h *shadowHandler) Handle(conn net.Conn) { cc, err = route.Dial(host, TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), ) if err == nil { break @@ -434,10 +420,6 @@ func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error if err != nil { return } - if h.options.Bypass.Contains(addr.String()) { - log.Log("[ssu] bypass", addr) - return // bypass - } _, err = cc.WriteTo(dgram.Data, addr) return }() @@ -462,10 +444,6 @@ func (h *shadowUDPHandler) transportUDP(conn net.Conn, cc net.PacketConn) error if Debug { log.Logf("[ssu] %s <<< %s length: %d", conn.RemoteAddr(), addr, n) } - if h.options.Bypass.Contains(addr.String()) { - log.Log("[ssu] bypass", addr) - return // bypass - } dgram := gosocks5.NewUDPDatagram( gosocks5.NewUDPHeader(uint16(n), 0, toSocksAddr(addr)), b[:n]) buf := bytes.Buffer{} diff --git a/gost/ssh.go b/gost/ssh.go deleted file mode 100644 index f772503a..00000000 --- a/gost/ssh.go +++ /dev/null @@ -1,982 +0,0 @@ -package gost - -import ( - "context" - "crypto/tls" - "encoding/binary" - "errors" - "fmt" - "io/ioutil" - "net" - "strconv" - "strings" - "sync" - "time" - - "github.com/go-log/log" - "golang.org/x/crypto/ssh" -) - -// Applicable SSH Request types for Port Forwarding - RFC 4254 7.X -const ( - DirectForwardRequest = "direct-tcpip" // RFC 4254 7.2 - RemoteForwardRequest = "tcpip-forward" // RFC 4254 7.1 - ForwardedTCPReturnRequest = "forwarded-tcpip" // RFC 4254 7.2 - CancelRemoteForwardRequest = "cancel-tcpip-forward" // RFC 4254 7.1 - - GostSSHTunnelRequest = "gost-tunnel" // extended request type for ssh tunnel -) - -var ( - errSessionDead = errors.New("session is dead") -) - -// ParseSSHKeyFile parses ssh key file. -func ParseSSHKeyFile(fp string) (ssh.Signer, error) { - key, err := ioutil.ReadFile(fp) - if err != nil { - return nil, err - } - return ssh.ParsePrivateKey(key) -} - -// ParseSSHAuthorizedKeysFile parses ssh Authorized Keys file. -func ParseSSHAuthorizedKeysFile(fp string) (map[string]bool, error) { - authorizedKeysBytes, err := ioutil.ReadFile(fp) - if err != nil { - return nil, err - } - authorizedKeysMap := make(map[string]bool) - for len(authorizedKeysBytes) > 0 { - pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) - if err != nil { - return nil, err - } - authorizedKeysMap[string(pubKey.Marshal())] = true - authorizedKeysBytes = rest - } - - return authorizedKeysMap, nil -} - -type sshDirectForwardConnector struct { -} - -// SSHDirectForwardConnector creates a Connector for SSH TCP direct port forwarding. -func SSHDirectForwardConnector() Connector { - return &sshDirectForwardConnector{} -} - -func (c *sshDirectForwardConnector) Connect(conn net.Conn, raddr string, options ...ConnectOption) (net.Conn, error) { - return c.ConnectContext(context.Background(), conn, "tcp", raddr, options...) -} - -func (c *sshDirectForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, raddr string, options ...ConnectOption) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - return nil, fmt.Errorf("%s unsupported", network) - } - - opts := &ConnectOptions{} - for _, option := range options { - option(opts) - } - - cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. - if !ok { - return nil, errors.New("ssh: wrong connection type") - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = ConnectTimeout - } - - cc.session.conn.SetDeadline(time.Now().Add(timeout)) - defer cc.session.conn.SetDeadline(time.Time{}) - - conn, err := cc.session.client.Dial("tcp", raddr) - if err != nil { - log.Logf("[ssh-tcp] %s -> %s : %s", cc.session.addr, raddr, err) - return nil, err - } - return conn, nil -} - -type sshRemoteForwardConnector struct { -} - -// SSHRemoteForwardConnector creates a Connector for SSH TCP remote port forwarding. -func SSHRemoteForwardConnector() Connector { - return &sshRemoteForwardConnector{} -} - -func (c *sshRemoteForwardConnector) Connect(conn net.Conn, address string, options ...ConnectOption) (net.Conn, error) { - return c.ConnectContext(context.Background(), conn, "tcp", address, options...) -} - -func (c *sshRemoteForwardConnector) ConnectContext(ctx context.Context, conn net.Conn, network, address string, options ...ConnectOption) (net.Conn, error) { - switch network { - case "udp", "udp4", "udp6": - return nil, fmt.Errorf("%s unsupported", network) - } - - cc, ok := conn.(*sshNopConn) // TODO: this is an ugly type assertion, need to find a better solution. - if !ok { - return nil, errors.New("ssh: wrong connection type") - } - - cc.session.once.Do(func() { - go func() { - defer log.Log("ssh-rtcp: session is closed") - defer close(cc.session.connChan) - - if cc.session == nil || cc.session.client == nil { - return - } - if strings.HasPrefix(address, ":") { - address = "0.0.0.0" + address - } - ln, err := cc.session.client.Listen("tcp", address) - if err != nil { - return - } - log.Log("[ssh-rtcp] listening on", ln.Addr()) - - for { - rc, err := ln.Accept() - if err != nil { - log.Logf("[ssh-rtcp] %s <-> %s accpet : %s", ln.Addr(), address, err) - return - } - // log.Log("[ssh-rtcp] accept", rc.LocalAddr(), rc.RemoteAddr()) - select { - case cc.session.connChan <- rc: - default: - rc.Close() - log.Logf("[ssh-rtcp] %s - %s: connection queue is full", ln.Addr(), address) - } - } - }() - }) - - sc, ok := <-cc.session.connChan - if !ok { - return nil, errors.New("ssh-rtcp: connection is closed") - } - return sc, nil -} - -type sshForwardTransporter struct { - sessions map[string]*sshSession - sessionMutex sync.Mutex -} - -// SSHForwardTransporter creates a Transporter that is used by SSH port forwarding server. -func SSHForwardTransporter() Transporter { - return &sshForwardTransporter{ - sessions: make(map[string]*sshSession), - } -} - -func (tr *sshForwardTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - session, ok := tr.sessions[addr] - if !ok || session.Closed() { - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &sshSession{ - addr: addr, - conn: conn, - } - tr.sessions[addr] = session - } - - return session.conn, nil -} - -func (tr *sshForwardTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - - config := ssh.ClientConfig{ - Timeout: timeout, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - if opts.User != nil { - config.User = opts.User.Username() - if password, _ := opts.User.Password(); password != "" { - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), - } - } - } - if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { - config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - session, ok := tr.sessions[opts.Addr] - if !ok || session.client == nil { - sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) - if err != nil { - log.Log("ssh", err) - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - session = &sshSession{ - addr: opts.Addr, - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - deaded: make(chan struct{}), - connChan: make(chan net.Conn, 1024), - } - tr.sessions[opts.Addr] = session - go session.Ping(opts.Interval, opts.Timeout, opts.Retry) - go session.waitServer() - go session.waitClose() - } - if session.Closed() { - delete(tr.sessions, opts.Addr) - return nil, errSessionDead - } - - return &sshNopConn{session: session}, nil -} - -func (tr *sshForwardTransporter) Multiplex() bool { - return true -} - -type sshTunnelTransporter struct { - sessions map[string]*sshSession - sessionMutex sync.Mutex -} - -// SSHTunnelTransporter creates a Transporter that is used by SSH tunnel client. -func SSHTunnelTransporter() Transporter { - return &sshTunnelTransporter{ - sessions: make(map[string]*sshSession), - } -} - -func (tr *sshTunnelTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - session, ok := tr.sessions[addr] - if !ok || session.Closed() { - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &sshSession{ - addr: addr, - conn: conn, - } - tr.sessions[addr] = session - } - - return session.conn, nil -} - -func (tr *sshTunnelTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - - config := ssh.ClientConfig{ - Timeout: timeout, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - if opts.User != nil { - config.User = opts.User.Username() - if password, _ := opts.User.Password(); password != "" { - config.Auth = []ssh.AuthMethod{ - ssh.Password(password), - } - } - } - if opts.SSHConfig != nil && opts.SSHConfig.Key != nil { - config.Auth = append(config.Auth, ssh.PublicKeys(opts.SSHConfig.Key)) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - session, ok := tr.sessions[opts.Addr] - if !ok || session.client == nil { - sshConn, chans, reqs, err := ssh.NewClientConn(conn, opts.Addr, &config) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - - session = &sshSession{ - addr: opts.Addr, - conn: conn, - client: ssh.NewClient(sshConn, chans, reqs), - closed: make(chan struct{}), - deaded: make(chan struct{}), - } - tr.sessions[opts.Addr] = session - go session.Ping(opts.Interval, opts.Timeout, opts.Retry) - go session.waitServer() - go session.waitClose() - } - - if session.Closed() { - delete(tr.sessions, opts.Addr) - return nil, errSessionDead - } - - channel, reqs, err := session.client.OpenChannel(GostSSHTunnelRequest, nil) - if err != nil { - return nil, err - } - go ssh.DiscardRequests(reqs) - return &sshConn{channel: channel, conn: conn}, nil -} - -func (tr *sshTunnelTransporter) Multiplex() bool { - return true -} - -type sshSession struct { - addr string - conn net.Conn - client *ssh.Client - closed chan struct{} - deaded chan struct{} - once sync.Once - connChan chan net.Conn -} - -func (s *sshSession) Ping(interval, timeout time.Duration, retries int) { - if interval <= 0 { - return - } - if timeout <= 0 { - timeout = PingTimeout - } - - if retries == 0 { - retries = 1 - } - - defer close(s.deaded) - - log.Logf("[ssh] ping is enabled, interval: %v, timeout: %v, retry: %d", interval, timeout, retries) - baseCtx := context.Background() - t := time.NewTicker(interval) - defer t.Stop() - - count := retries + 1 - for { - select { - case <-t.C: - start := time.Now() - if Debug { - log.Log("[ssh] sending ping") - } - ctx, cancel := context.WithTimeout(baseCtx, timeout) - var err error - select { - case err = <-s.sendPing(): - case <-ctx.Done(): - err = errors.New("Timeout") - } - cancel() - if err != nil { - log.Log("[ssh] ping:", err) - count-- - if count == 0 { - return - } - continue - } - if Debug { - log.Log("[ssh] ping OK, RTT:", time.Since(start)) - } - count = retries + 1 - case <-s.closed: - return - } - } -} - -func (s *sshSession) sendPing() <-chan error { - ch := make(chan error, 1) - go func() { - if _, _, err := s.client.SendRequest("ping", true, nil); err != nil { - ch <- err - } - close(ch) - }() - return ch -} - -func (s *sshSession) waitServer() error { - defer close(s.closed) - return s.client.Wait() -} - -func (s *sshSession) waitClose() { - defer s.client.Close() - - select { - case <-s.deaded: - case <-s.closed: - } -} - -func (s *sshSession) Closed() bool { - select { - case <-s.deaded: - return true - case <-s.closed: - return true - default: - } - return false -} - -type sshForwardHandler struct { - options *HandlerOptions - config *ssh.ServerConfig -} - -// SSHForwardHandler creates a server Handler for SSH port forwarding server. -func SSHForwardHandler(opts ...HandlerOption) Handler { - h := &sshForwardHandler{} - h.Init(opts...) - - return h -} - -func (h *sshForwardHandler) Init(options ...HandlerOption) { - if h.options == nil { - h.options = &HandlerOptions{} - } - - for _, opt := range options { - opt(h.options) - } - h.config = &ssh.ServerConfig{} - - h.config.PasswordCallback = defaultSSHPasswordCallback(h.options.Authenticator) - if h.options.Authenticator == nil { - h.config.NoClientAuth = true - } - tlsConfig := h.options.TLSConfig - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - if tlsConfig != nil && len(tlsConfig.Certificates) > 0 { - signer, err := ssh.NewSignerFromKey(tlsConfig.Certificates[0].PrivateKey) - if err != nil { - log.Log("[ssh-forward]", err) - } - h.config.AddHostKey(signer) - } -} - -func (h *sshForwardHandler) Handle(conn net.Conn) { - sshConn, chans, reqs, err := ssh.NewServerConn(conn, h.config) - if err != nil { - log.Logf("[ssh-forward] %s -> %s : %s", conn.RemoteAddr(), h.options.Node.Addr, err) - conn.Close() - return - } - defer sshConn.Close() - - log.Logf("[ssh-forward] %s <-> %s", conn.RemoteAddr(), h.options.Node.Addr) - h.handleForward(sshConn, chans, reqs) - log.Logf("[ssh-forward] %s >-< %s", conn.RemoteAddr(), h.options.Node.Addr) -} - -func (h *sshForwardHandler) handleForward(conn ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request) { - quit := make(chan struct{}) - defer close(quit) // quit signal - - go func() { - for req := range reqs { - switch req.Type { - case RemoteForwardRequest: - go h.tcpipForwardRequest(conn, req, quit) - default: - // log.Log("[ssh] unknown request type:", req.Type, req.WantReply) - if req.WantReply { - req.Reply(false, nil) - } - } - } - }() - - go func() { - for newChannel := range chans { - // Check the type of channel - t := newChannel.ChannelType() - switch t { - case DirectForwardRequest: - channel, requests, err := newChannel.Accept() - if err != nil { - log.Log("[ssh] Could not accept channel:", err) - continue - } - p := directForward{} - ssh.Unmarshal(newChannel.ExtraData(), &p) - - if p.Host1 == "" { - p.Host1 = "" - } - - go ssh.DiscardRequests(requests) - go h.directPortForwardChannel(channel, fmt.Sprintf("%s:%d", p.Host1, p.Port1)) - default: - log.Log("[ssh] Unknown channel type:", t) - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - } - } - }() - - conn.Wait() -} - -func (h *sshForwardHandler) directPortForwardChannel(channel ssh.Channel, raddr string) { - defer channel.Close() - - log.Logf("[ssh-tcp] %s - %s", h.options.Node.Addr, raddr) - - if !Can("tcp", raddr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-tcp] Unauthorized to tcp connect to %s", raddr) - return - } - - if h.options.Bypass.Contains(raddr) { - log.Logf("[ssh-tcp] [bypass] %s", raddr) - return - } - - conn, err := h.options.Chain.Dial(raddr, - RetryChainOption(h.options.Retries), - TimeoutChainOption(h.options.Timeout), - HostsChainOption(h.options.Hosts), - ResolverChainOption(h.options.Resolver), - ) - if err != nil { - log.Logf("[ssh-tcp] %s - %s : %s", h.options.Node.Addr, raddr, err) - return - } - defer conn.Close() - - log.Logf("[ssh-tcp] %s <-> %s", h.options.Node.Addr, raddr) - transport(conn, channel) - log.Logf("[ssh-tcp] %s >-< %s", h.options.Node.Addr, raddr) -} - -// tcpipForward is structure for RFC 4254 7.1 "tcpip-forward" request -type tcpipForward struct { - Host string - Port uint32 -} - -func (h *sshForwardHandler) tcpipForwardRequest(sshConn ssh.Conn, req *ssh.Request, quit <-chan struct{}) { - t := tcpipForward{} - ssh.Unmarshal(req.Payload, &t) - - addr := fmt.Sprintf("%s:%d", t.Host, t.Port) - - if !Can("rtcp", addr, h.options.Whitelist, h.options.Blacklist) { - log.Logf("[ssh-rtcp] Unauthorized to tcp bind to %s", addr) - req.Reply(false, nil) - return - } - - ln, err := net.Listen("tcp", addr) //tie to the client connection - if err != nil { - log.Log("[ssh-rtcp]", err) - req.Reply(false, nil) - return - } - defer ln.Close() - - log.Log("[ssh-rtcp] listening on tcp", ln.Addr()) - - replyFunc := func() error { - if t.Port == 0 && req.WantReply { // Client sent port 0. let them know which port is actually being used - _, port, err := getHostPortFromAddr(ln.Addr()) - if err != nil { - return err - } - var b [4]byte - binary.BigEndian.PutUint32(b[:], uint32(port)) - t.Port = uint32(port) - return req.Reply(true, b[:]) - } - return req.Reply(true, nil) - } - if err := replyFunc(); err != nil { - log.Log("[ssh-rtcp]", err) - return - } - - go func() { - for { - conn, err := ln.Accept() - if err != nil { // Unable to accept new connection - listener is likely closed - return - } - - go func(conn net.Conn) { - defer conn.Close() - - p := directForward{} - var err error - - var portnum int - p.Host1 = t.Host - p.Port1 = t.Port - p.Host2, portnum, err = getHostPortFromAddr(conn.RemoteAddr()) - if err != nil { - return - } - - p.Port2 = uint32(portnum) - ch, reqs, err := sshConn.OpenChannel(ForwardedTCPReturnRequest, ssh.Marshal(p)) - if err != nil { - log.Log("[ssh-rtcp] open forwarded channel:", err) - return - } - defer ch.Close() - go ssh.DiscardRequests(reqs) - - log.Logf("[ssh-rtcp] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - transport(ch, conn) - log.Logf("[ssh-rtcp] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) - }(conn) - } - }() - - <-quit -} - -// SSHConfig holds the SSH tunnel server config -type SSHConfig struct { - Authenticator Authenticator - TLSConfig *tls.Config - Key ssh.Signer - AuthorizedKeys map[string]bool -} - -type sshTunnelListener struct { - net.Listener - config *ssh.ServerConfig - connChan chan net.Conn - errChan chan error -} - -// SSHTunnelListener creates a Listener for SSH tunnel server. -func SSHTunnelListener(addr string, config *SSHConfig) (Listener, error) { - ln, err := net.Listen("tcp", addr) - if err != nil { - return nil, err - } - - if config == nil { - config = &SSHConfig{} - } - - sshConfig := &ssh.ServerConfig{ - PasswordCallback: defaultSSHPasswordCallback(config.Authenticator), - PublicKeyCallback: defaultSSHPublicKeyCallback(config.AuthorizedKeys), - } - - if config.Authenticator == nil && len(config.AuthorizedKeys) == 0 { - sshConfig.NoClientAuth = true - } - - signer := config.Key - if signer == nil { - signer, err = ssh.NewSignerFromKey(DefaultTLSConfig.Certificates[0].PrivateKey) - if err != nil { - ln.Close() - return nil, err - } - } - sshConfig.AddHostKey(signer) - - l := &sshTunnelListener{ - Listener: tcpKeepAliveListener{ln.(*net.TCPListener)}, - config: sshConfig, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - go l.listenLoop() - - return l, nil -} - -func (l *sshTunnelListener) listenLoop() { - for { - conn, err := l.Listener.Accept() - if err != nil { - log.Log("[ssh] accept:", err) - l.errChan <- err - close(l.errChan) - return - } - go l.serveConn(conn) - } -} - -func (l *sshTunnelListener) serveConn(conn net.Conn) { - sc, chans, reqs, err := ssh.NewServerConn(conn, l.config) - if err != nil { - log.Logf("[ssh] %s -> %s : %s", conn.RemoteAddr(), conn.LocalAddr(), err) - conn.Close() - return - } - defer sc.Close() - - go ssh.DiscardRequests(reqs) - go func() { - for newChannel := range chans { - // Check the type of channel - t := newChannel.ChannelType() - switch t { - case GostSSHTunnelRequest: - channel, requests, err := newChannel.Accept() - if err != nil { - log.Log("[ssh] Could not accept channel:", err) - continue - } - go ssh.DiscardRequests(requests) - cc := &sshConn{conn: conn, channel: channel} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[ssh] %s - %s: connection queue is full", conn.RemoteAddr(), l.Addr()) - } - - default: - log.Log("[ssh] Unknown channel type:", t) - newChannel.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", t)) - } - } - }() - - log.Logf("[ssh] %s <-> %s", conn.RemoteAddr(), conn.LocalAddr()) - sc.Wait() - log.Logf("[ssh] %s >-< %s", conn.RemoteAddr(), conn.LocalAddr()) -} - -func (l *sshTunnelListener) Accept() (conn net.Conn, err error) { - var ok bool - select { - case conn = <-l.connChan: - case err, ok = <-l.errChan: - if !ok { - err = errors.New("accpet on closed listener") - } - } - return -} - -// directForward is structure for RFC 4254 7.2 - can be used for "forwarded-tcpip" and "direct-tcpip" -type directForward struct { - Host1 string - Port1 uint32 - Host2 string - Port2 uint32 -} - -func (p directForward) String() string { - return fmt.Sprintf("%s:%d -> %s:%d", p.Host2, p.Port2, p.Host1, p.Port1) -} - -func getHostPortFromAddr(addr net.Addr) (host string, port int, err error) { - host, portString, err := net.SplitHostPort(addr.String()) - if err != nil { - return - } - port, err = strconv.Atoi(portString) - return -} - -// PasswordCallbackFunc is a callback function used by SSH server. -// It authenticates user using a password. -type PasswordCallbackFunc func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) - -func defaultSSHPasswordCallback(au Authenticator) PasswordCallbackFunc { - if au == nil { - return nil - } - return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { - if au.Authenticate(conn.User(), string(password)) { - return nil, nil - } - log.Logf("[ssh] %s -> %s : password rejected for %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User()) - return nil, fmt.Errorf("password rejected for %s", conn.User()) - } -} - -// PublicKeyCallbackFunc is a callback function used by SSH server. -// It offers a public key for authentication. -type PublicKeyCallbackFunc func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) - -func defaultSSHPublicKeyCallback(keys map[string]bool) PublicKeyCallbackFunc { - if len(keys) == 0 { - return nil - } - - return func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { - if keys[string(pubKey.Marshal())] { - return &ssh.Permissions{ - // Record the public key used for authentication. - Extensions: map[string]string{ - "pubkey-fp": ssh.FingerprintSHA256(pubKey), - }, - }, nil - } - return nil, fmt.Errorf("unknown public key for %q", c.User()) - } -} - -type sshNopConn struct { - session *sshSession -} - -func (c *sshNopConn) Read(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "read", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("read not supported")} -} - -func (c *sshNopConn) Write(b []byte) (n int, err error) { - return 0, &net.OpError{Op: "write", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("write not supported")} -} - -func (c *sshNopConn) Close() error { - return nil -} - -func (c *sshNopConn) LocalAddr() net.Addr { - return &net.TCPAddr{ - IP: net.IPv4zero, - Port: 0, - } -} - -func (c *sshNopConn) RemoteAddr() net.Addr { - return &net.TCPAddr{ - IP: net.IPv4zero, - Port: 0, - } -} - -func (c *sshNopConn) SetDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshNopConn) SetReadDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -func (c *sshNopConn) SetWriteDeadline(t time.Time) error { - return &net.OpError{Op: "set", Net: "ssh", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} -} - -type sshConn struct { - channel ssh.Channel - conn net.Conn -} - -func (c *sshConn) Read(b []byte) (n int, err error) { - return c.channel.Read(b) -} - -func (c *sshConn) Write(b []byte) (n int, err error) { - return c.channel.Write(b) -} - -func (c *sshConn) Close() error { - return c.channel.Close() -} - -func (c *sshConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *sshConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *sshConn) SetDeadline(t time.Time) error { - return c.conn.SetDeadline(t) -} - -func (c *sshConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *sshConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/gost/ssh_test.go b/gost/ssh_test.go deleted file mode 100644 index 20d24f02..00000000 --- a/gost/ssh_test.go +++ /dev/null @@ -1,581 +0,0 @@ -package gost - -import ( - "crypto/rand" - "crypto/tls" - "fmt" - "net" - "net/http/httptest" - "net/url" - "testing" -) - -func sshDirectForwardRoundtrip(targetURL string, data []byte) error { - ln, err := TCPListener("") - if err != nil { - return err - } - - client := &Client{ - Connector: SSHDirectForwardConnector(), - Transporter: SSHForwardTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SSHForwardHandler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSHDirectForward(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := sshDirectForwardRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func BenchmarkSSHDirectForward(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: SSHDirectForwardConnector(), - Transporter: SSHForwardTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SSHForwardHandler(), - } - - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkSSHDirectForwardParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TCPListener("") - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: SSHDirectForwardConnector(), - Transporter: SSHForwardTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SSHForwardHandler(), - } - - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func sshRemoteForwardRoundtrip(t *testing.T, targetURL string, data []byte) (err error) { - ln, err := TCPListener("") - if err != nil { - return - } - - client := &Client{ - Connector: SSHRemoteForwardConnector(), - Transporter: SSHForwardTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SSHForwardHandler(), - } - - go server.Run() - defer server.Close() - - conn, err := proxyConn(client, server) - if err != nil { - return - } - defer conn.Close() - - go func() { - conn, err = client.Connect(conn, ":0") - if err != nil { - return - } - }() - - c, err := net.Dial("tcp", conn.LocalAddr().String()) - if err != nil { - return - } - defer c.Close() - - u, err := url.Parse(targetURL) - if err != nil { - return - } - - cc, err := net.Dial("tcp", u.Host) - if err != nil { - return - } - defer cc.Close() - - go transport(conn, cc) - - t.Log("httpRoundtrip") - return httpRoundtrip(c, targetURL, data) -} - -// TODO: fix this test -func _TestSSHRemoteForward(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := sshRemoteForwardRoundtrip(t, httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func httpOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverSSHTunnel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := SSHTunnelListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverSSHTunnelParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := SSHTunnelListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverSSHTunnelRoundtrip(httpSrv.URL, sendData, - nil, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config) error { - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverSSHTunnelRoundtrip(httpSrv.URL, sendData, nil) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverSSHTunnelRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverSSHTunnelRoundtrip(httpSrv.URL, sendData, - nil, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverSSHTunnelRoundtrip(targetURL string, data []byte, host string) error { - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverSSHTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverSSHTunnelRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func sshForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := SSHTunnelListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: SSHTunnelTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSHForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := sshForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} diff --git a/gost/tls_test.go b/gost/tls_test.go index 1d74ce86..cfca4b28 100644 --- a/gost/tls_test.go +++ b/gost/tls_test.go @@ -9,89 +9,6 @@ import ( "testing" ) -func httpOverTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := TLSListener("", tlsConfig) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: TLSTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverTLS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverTLSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverTLS(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := TLSListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: TLSTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - func BenchmarkHTTPOverTLSParallel(b *testing.B) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -303,33 +220,6 @@ func TestSSOverTLS(t *testing.T) { } } -func sniOverTLSRoundtrip(targetURL string, data []byte, host string) error { - ln, err := TLSListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: TLSTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - func TestSNIOverTLS(t *testing.T) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() @@ -409,56 +299,6 @@ func TestTLSForwardTunnel(t *testing.T) { } } -func httpOverMTLSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := MTLSListener("", tlsConfig) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: MTLSTransporter(), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverMTLS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverMTLSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - func BenchmarkHTTPOverMTLS(b *testing.B) { httpSrv := httptest.NewServer(httpTestHandler) defer httpSrv.Close() diff --git a/gost/ws.go b/gost/ws.go deleted file mode 100644 index 9dc8f0dc..00000000 --- a/gost/ws.go +++ /dev/null @@ -1,806 +0,0 @@ -package gost - -import ( - "crypto/rand" - "crypto/sha1" - "crypto/tls" - "encoding/base64" - "io" - "net" - "net/http" - "net/http/httputil" - "sync" - "time" - - "net/url" - - "github.com/go-log/log" - "github.com/gorilla/websocket" - smux "github.com/xtaci/smux" -) - -const ( - defaultWSPath = "/ws" -) - -// WSOptions describes the options for websocket. -type WSOptions struct { - ReadBufferSize int - WriteBufferSize int - HandshakeTimeout time.Duration - EnableCompression bool - UserAgent string - Path string -} - -type wsTransporter struct { - tcpTransporter - options *WSOptions -} - -// WSTransporter creates a Transporter that is used by websocket proxy client. -func WSTransporter(opts *WSOptions) Transporter { - return &wsTransporter{ - options: opts, - } -} - -func (tr *wsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - if wsOptions == nil { - wsOptions = &WSOptions{} - } - - path := wsOptions.Path - if path == "" { - path = defaultWSPath - } - url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} - return websocketClientConn(url.String(), conn, nil, wsOptions) -} - -type mwsTransporter struct { - tcpTransporter - options *WSOptions - sessions map[string]*muxSession - sessionMutex sync.Mutex -} - -// MWSTransporter creates a Transporter that is used by multiplex-websocket proxy client. -func MWSTransporter(opts *WSOptions) Transporter { - return &mwsTransporter{ - options: opts, - sessions: make(map[string]*muxSession), - } -} - -func (tr *mwsTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if session != nil && session.IsClosed() { - delete(tr.sessions, addr) - ok = false - } - if !ok { - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &muxSession{conn: conn} - tr.sessions[addr] = session - } - return session.conn, nil -} - -func (tr *mwsTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - session, ok := tr.sessions[opts.Addr] - if !ok || session.session == nil { - s, err := tr.initSession(opts.Addr, conn, opts) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - session = s - tr.sessions[opts.Addr] = session - } - - cc, err := session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - return cc, nil -} - -func (tr *mwsTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { - if opts == nil { - opts = &HandshakeOptions{} - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - if wsOptions == nil { - wsOptions = &WSOptions{} - } - - path := wsOptions.Path - if path == "" { - path = defaultWSPath - } - url := url.URL{Scheme: "ws", Host: opts.Host, Path: path} - conn, err := websocketClientConn(url.String(), conn, nil, wsOptions) - if err != nil { - return nil, err - } - // stream multiplex - smuxConfig := smux.DefaultConfig() - session, err := smux.Client(conn, smuxConfig) - if err != nil { - return nil, err - } - return &muxSession{conn: conn, session: session}, nil -} - -func (tr *mwsTransporter) Multiplex() bool { - return true -} - -type wssTransporter struct { - tcpTransporter - options *WSOptions -} - -// WSSTransporter creates a Transporter that is used by websocket secure proxy client. -func WSSTransporter(opts *WSOptions) Transporter { - return &wssTransporter{ - options: opts, - } -} - -func (tr *wssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - if wsOptions == nil { - wsOptions = &WSOptions{} - } - - if opts.TLSConfig == nil { - opts.TLSConfig = &tls.Config{InsecureSkipVerify: true} - } - path := wsOptions.Path - if path == "" { - path = defaultWSPath - } - url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} - return websocketClientConn(url.String(), conn, opts.TLSConfig, wsOptions) -} - -type mwssTransporter struct { - tcpTransporter - options *WSOptions - sessions map[string]*muxSession - sessionMutex sync.Mutex -} - -// MWSSTransporter creates a Transporter that is used by multiplex-websocket secure proxy client. -func MWSSTransporter(opts *WSOptions) Transporter { - return &mwssTransporter{ - options: opts, - sessions: make(map[string]*muxSession), - } -} - -func (tr *mwssTransporter) Dial(addr string, options ...DialOption) (conn net.Conn, err error) { - opts := &DialOptions{} - for _, option := range options { - option(opts) - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - session, ok := tr.sessions[addr] - if session != nil && session.IsClosed() { - delete(tr.sessions, addr) - ok = false - } - if !ok { - timeout := opts.Timeout - if timeout <= 0 { - timeout = DialTimeout - } - - if opts.Chain == nil { - conn, err = net.DialTimeout("tcp", addr, timeout) - } else { - conn, err = opts.Chain.Dial(addr) - } - if err != nil { - return - } - session = &muxSession{conn: conn} - tr.sessions[addr] = session - } - return session.conn, nil -} - -func (tr *mwssTransporter) Handshake(conn net.Conn, options ...HandshakeOption) (net.Conn, error) { - opts := &HandshakeOptions{} - for _, option := range options { - option(opts) - } - - timeout := opts.Timeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - - tr.sessionMutex.Lock() - defer tr.sessionMutex.Unlock() - - conn.SetDeadline(time.Now().Add(timeout)) - defer conn.SetDeadline(time.Time{}) - - session, ok := tr.sessions[opts.Addr] - if !ok || session.session == nil { - s, err := tr.initSession(opts.Addr, conn, opts) - if err != nil { - conn.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - session = s - tr.sessions[opts.Addr] = session - } - cc, err := session.GetConn() - if err != nil { - session.Close() - delete(tr.sessions, opts.Addr) - return nil, err - } - return cc, nil -} - -func (tr *mwssTransporter) initSession(addr string, conn net.Conn, opts *HandshakeOptions) (*muxSession, error) { - if opts == nil { - opts = &HandshakeOptions{} - } - wsOptions := tr.options - if opts.WSOptions != nil { - wsOptions = opts.WSOptions - } - if wsOptions == nil { - wsOptions = &WSOptions{} - } - - tlsConfig := opts.TLSConfig - if tlsConfig == nil { - tlsConfig = &tls.Config{InsecureSkipVerify: true} - } - path := wsOptions.Path - if path == "" { - path = defaultWSPath - } - url := url.URL{Scheme: "wss", Host: opts.Host, Path: path} - conn, err := websocketClientConn(url.String(), conn, tlsConfig, wsOptions) - if err != nil { - return nil, err - } - // stream multiplex - smuxConfig := smux.DefaultConfig() - session, err := smux.Client(conn, smuxConfig) - if err != nil { - return nil, err - } - return &muxSession{conn: conn, session: session}, nil -} - -func (tr *mwssTransporter) Multiplex() bool { - return true -} - -type wsListener struct { - addr net.Addr - upgrader *websocket.Upgrader - srv *http.Server - connChan chan net.Conn - errChan chan error -} - -// WSListener creates a Listener for websocket proxy server. -func WSListener(addr string, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &wsListener{ - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - path := options.Path - if path == "" { - path = defaultWSPath - } - mux := http.NewServeMux() - mux.Handle(path, http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{ - Addr: addr, - Handler: mux, - ReadHeaderTimeout: 30 * time.Second, - } - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - l.addr = ln.Addr() - - go func() { - err := l.srv.Serve(tcpKeepAliveListener{ln}) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} - -func (l *wsListener) upgrade(w http.ResponseWriter, r *http.Request) { - log.Logf("[ws] %s -> %s", r.RemoteAddr, l.addr) - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Log(string(dump)) - } - conn, err := l.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Logf("[ws] %s - %s : %s", r.RemoteAddr, l.addr, err) - return - } - select { - case l.connChan <- websocketServerConn(conn): - default: - conn.Close() - log.Logf("[ws] %s - %s: connection queue is full", r.RemoteAddr, l.addr) - } -} - -func (l *wsListener) Accept() (conn net.Conn, err error) { - select { - case conn = <-l.connChan: - case err = <-l.errChan: - } - return -} - -func (l *wsListener) Close() error { - return l.srv.Close() -} - -func (l *wsListener) Addr() net.Addr { - return l.addr -} - -type mwsListener struct { - addr net.Addr - upgrader *websocket.Upgrader - srv *http.Server - connChan chan net.Conn - errChan chan error -} - -// MWSListener creates a Listener for multiplex-websocket proxy server. -func MWSListener(addr string, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &mwsListener{ - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - } - - path := options.Path - if path == "" { - path = defaultWSPath - } - - mux := http.NewServeMux() - mux.Handle(path, http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{ - Addr: addr, - Handler: mux, - ReadHeaderTimeout: 30 * time.Second, - } - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - l.addr = ln.Addr() - - go func() { - err := l.srv.Serve(tcpKeepAliveListener{ln}) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} - -func (l *mwsListener) upgrade(w http.ResponseWriter, r *http.Request) { - log.Logf("[mws] %s -> %s", r.RemoteAddr, l.addr) - if Debug { - dump, _ := httputil.DumpRequest(r, false) - log.Log(string(dump)) - } - conn, err := l.upgrader.Upgrade(w, r, nil) - if err != nil { - log.Logf("[mws] %s - %s : %s", r.RemoteAddr, l.addr, err) - return - } - - l.mux(websocketServerConn(conn)) -} - -func (l *mwsListener) mux(conn net.Conn) { - smuxConfig := smux.DefaultConfig() - mux, err := smux.Server(conn, smuxConfig) - if err != nil { - log.Logf("[mws] %s - %s : %s", conn.RemoteAddr(), l.Addr(), err) - return - } - defer mux.Close() - - log.Logf("[mws] %s <-> %s", conn.RemoteAddr(), l.Addr()) - defer log.Logf("[mws] %s >-< %s", conn.RemoteAddr(), l.Addr()) - - for { - stream, err := mux.AcceptStream() - if err != nil { - log.Log("[mws] accept stream:", err) - return - } - - cc := &muxStreamConn{Conn: conn, stream: stream} - select { - case l.connChan <- cc: - default: - cc.Close() - log.Logf("[mws] %s - %s: connection queue is full", conn.RemoteAddr(), conn.LocalAddr()) - } - } -} - -func (l *mwsListener) Accept() (conn net.Conn, err error) { - select { - case conn = <-l.connChan: - case err = <-l.errChan: - } - return -} - -func (l *mwsListener) Close() error { - return l.srv.Close() -} - -func (l *mwsListener) Addr() net.Addr { - return l.addr -} - -type wssListener struct { - *wsListener -} - -// WSSListener creates a Listener for websocket secure proxy server. -func WSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &wssListener{ - wsListener: &wsListener{ - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - }, - } - - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - - path := options.Path - if path == "" { - path = defaultWSPath - } - - mux := http.NewServeMux() - mux.Handle(path, http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{ - Addr: addr, - TLSConfig: tlsConfig, - Handler: mux, - ReadHeaderTimeout: 30 * time.Second, - } - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - l.addr = ln.Addr() - - go func() { - err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} - -type mwssListener struct { - *mwsListener -} - -// MWSSListener creates a Listener for multiplex-websocket secure proxy server. -func MWSSListener(addr string, tlsConfig *tls.Config, options *WSOptions) (Listener, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", addr) - if err != nil { - return nil, err - } - if options == nil { - options = &WSOptions{} - } - l := &mwssListener{ - mwsListener: &mwsListener{ - upgrader: &websocket.Upgrader{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - CheckOrigin: func(r *http.Request) bool { return true }, - EnableCompression: options.EnableCompression, - }, - connChan: make(chan net.Conn, 1024), - errChan: make(chan error, 1), - }, - } - - if tlsConfig == nil { - tlsConfig = DefaultTLSConfig - } - - path := options.Path - if path == "" { - path = defaultWSPath - } - - mux := http.NewServeMux() - mux.Handle(path, http.HandlerFunc(l.upgrade)) - l.srv = &http.Server{ - Addr: addr, - TLSConfig: tlsConfig, - Handler: mux, - ReadHeaderTimeout: 30 * time.Second, - } - - ln, err := net.ListenTCP("tcp", tcpAddr) - if err != nil { - return nil, err - } - l.addr = ln.Addr() - - go func() { - err := l.srv.Serve(tls.NewListener(tcpKeepAliveListener{ln}, tlsConfig)) - if err != nil { - l.errChan <- err - } - close(l.errChan) - }() - select { - case err := <-l.errChan: - return nil, err - default: - } - - return l, nil -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func computeAcceptKey(challengeKey string) string { - h := sha1.New() - h.Write([]byte(challengeKey)) - h.Write(keyGUID) - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func generateChallengeKey() (string, error) { - p := make([]byte, 16) - if _, err := io.ReadFull(rand.Reader, p); err != nil { - return "", err - } - return base64.StdEncoding.EncodeToString(p), nil -} - -// TODO: due to the concurrency control in the websocket.Conn, -// a data race may be met when using with multiplexing. -// See: https://godoc.org/gopkg.in/gorilla/websocket.v1#hdr-Concurrency -type websocketConn struct { - conn *websocket.Conn - rb []byte -} - -func websocketClientConn(url string, conn net.Conn, tlsConfig *tls.Config, options *WSOptions) (net.Conn, error) { - if options == nil { - options = &WSOptions{} - } - - timeout := options.HandshakeTimeout - if timeout <= 0 { - timeout = HandshakeTimeout - } - - dialer := websocket.Dialer{ - ReadBufferSize: options.ReadBufferSize, - WriteBufferSize: options.WriteBufferSize, - TLSClientConfig: tlsConfig, - HandshakeTimeout: timeout, - EnableCompression: options.EnableCompression, - NetDial: func(net, addr string) (net.Conn, error) { - return conn, nil - }, - } - header := http.Header{} - header.Set("User-Agent", DefaultUserAgent) - if options.UserAgent != "" { - header.Set("User-Agent", options.UserAgent) - } - c, resp, err := dialer.Dial(url, header) - if err != nil { - return nil, err - } - resp.Body.Close() - return &websocketConn{conn: c}, nil -} - -func websocketServerConn(conn *websocket.Conn) net.Conn { - // conn.EnableWriteCompression(true) - return &websocketConn{ - conn: conn, - } -} - -func (c *websocketConn) Read(b []byte) (n int, err error) { - if len(c.rb) == 0 { - _, c.rb, err = c.conn.ReadMessage() - } - n = copy(b, c.rb) - c.rb = c.rb[n:] - return -} - -func (c *websocketConn) Write(b []byte) (n int, err error) { - err = c.conn.WriteMessage(websocket.BinaryMessage, b) - n = len(b) - return -} - -func (c *websocketConn) Close() error { - return c.conn.Close() -} - -func (c *websocketConn) LocalAddr() net.Addr { - return c.conn.LocalAddr() -} - -func (c *websocketConn) RemoteAddr() net.Addr { - return c.conn.RemoteAddr() -} - -func (c *websocketConn) SetDeadline(t time.Time) error { - if err := c.SetReadDeadline(t); err != nil { - return err - } - return c.SetWriteDeadline(t) -} -func (c *websocketConn) SetReadDeadline(t time.Time) error { - return c.conn.SetReadDeadline(t) -} - -func (c *websocketConn) SetWriteDeadline(t time.Time) error { - return c.conn.SetWriteDeadline(t) -} diff --git a/gost/ws_test.go b/gost/ws_test.go deleted file mode 100644 index 11899c92..00000000 --- a/gost/ws_test.go +++ /dev/null @@ -1,808 +0,0 @@ -package gost - -import ( - "crypto/rand" - "fmt" - "net/http/httptest" - "net/url" - "testing" -) - -func httpOverWSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := WSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverWSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverWS(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := WSListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverWSParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := WSListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverWSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := WSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverWSRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverWSRoundtrip(targetURL string, data []byte) error { - ln, err := WSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverWSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverWSRoundtrip(targetURL string, data []byte) error { - ln, err := WSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverWS(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverWSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverWSRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := WSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverWSRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverWSRoundtrip(targetURL string, data []byte, host string) error { - ln, err := WSListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverWSRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func wsForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := WSListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: WSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestWSForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := wsForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func httpOverMWSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverMWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverMWSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverMWS(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := MWSListener("", nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverMWSParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := MWSListener("", nil) - if err != nil { - b.Error(err) - } - - b.Log(ln.Addr()) - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverMWSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverMWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverMWSRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverMWSRoundtrip(targetURL string, data []byte) error { - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverMWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverMWSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverMWSRoundtrip(targetURL string, data []byte) error { - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverMWS(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverMWSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverMWSRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverMWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverMWSRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverMWSRoundtrip(targetURL string, data []byte, host string) error { - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverMWS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverMWSRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func mwsForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := MWSListener("", nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: MWSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestMWSForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := mwsForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} diff --git a/gost/wss_test.go b/gost/wss_test.go deleted file mode 100644 index a2134807..00000000 --- a/gost/wss_test.go +++ /dev/null @@ -1,809 +0,0 @@ -package gost - -import ( - "crypto/rand" - "crypto/tls" - "fmt" - "net/http/httptest" - "net/url" - "testing" -) - -func httpOverWSSRoundtrip(targetURL string, data []byte, tlsConfig *tls.Config, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := WSSListener("", tlsConfig, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverWSSRoundtrip(httpSrv.URL, sendData, nil, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverWSS(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := WSSListener("", nil, nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverWSSParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := WSSListener("", nil, nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverWSSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverWSSRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverWSSRoundtrip(targetURL string, data []byte) error { - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverWSSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverWSSRoundtrip(targetURL string, data []byte) error { - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverWSS(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverWSSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverWSSRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverWSSRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverWSSRoundtrip(targetURL string, data []byte, host string) error { - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverWSSRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func wssForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := WSSListener("", nil, nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: WSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestWSSForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := wssForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} - -func httpOverMWSSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: HTTPConnector(clientInfo), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestHTTPOverMWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range httpProxyTests { - err := httpOverMWSSRoundtrip(httpSrv.URL, sendData, tc.cliUser, tc.srvUsers) - if err == nil { - if tc.errStr != "" { - t.Errorf("#%d should failed with error %s", i, tc.errStr) - } - } else { - if tc.errStr == "" { - t.Errorf("#%d got error %v", i, err) - } - if err.Error() != tc.errStr { - t.Errorf("#%d got error %v, want %v", i, err, tc.errStr) - } - } - } -} - -func BenchmarkHTTPOverMWSS(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := MWSSListener("", nil, nil) - if err != nil { - b.Error(err) - } - - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - for i := 0; i < b.N; i++ { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } -} - -func BenchmarkHTTPOverMWSSParallel(b *testing.B) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - ln, err := MWSSListener("", nil, nil) - if err != nil { - b.Error(err) - } - - b.Log(ln.Addr()) - client := &Client{ - Connector: HTTPConnector(url.UserPassword("admin", "123456")), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: HTTPHandler( - UsersHandlerOption(url.UserPassword("admin", "123456")), - ), - } - go server.Run() - defer server.Close() - - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - if err := proxyRoundtrip(client, server, httpSrv.URL, sendData); err != nil { - b.Error(err) - } - } - }) -} - -func socks5OverMWSSRoundtrip(targetURL string, data []byte, - clientInfo *url.Userinfo, serverInfo []*url.Userinfo) error { - - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS5Connector(clientInfo), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS5Handler( - UsersHandlerOption(serverInfo...), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS5OverMWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range socks5ProxyTests { - err := socks5OverMWSSRoundtrip(httpSrv.URL, sendData, - tc.cliUser, - tc.srvUsers, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func socks4OverMWSSRoundtrip(targetURL string, data []byte) error { - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4Connector(), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4OverMWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4OverMWSSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func socks4aOverMWSSRoundtrip(targetURL string, data []byte) error { - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: SOCKS4AConnector(), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SOCKS4Handler(), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSOCKS4AOverMWSS(t *testing.T) { - - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := socks4aOverMWSSRoundtrip(httpSrv.URL, sendData) - // t.Logf("#%d %v", i, err) - if err != nil { - t.Errorf("got error: %v", err) - } -} - -func ssOverMWSSRoundtrip(targetURL string, data []byte, - clientInfo, serverInfo *url.Userinfo) error { - - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - client := &Client{ - Connector: ShadowConnector(clientInfo), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: ShadowHandler( - UsersHandlerOption(serverInfo), - ), - } - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestSSOverMWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - for i, tc := range ssProxyTests { - err := ssOverMWSSRoundtrip(httpSrv.URL, sendData, - tc.clientCipher, - tc.serverCipher, - ) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - } -} - -func sniOverMWSSRoundtrip(targetURL string, data []byte, host string) error { - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: SNIConnector(host), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: SNIHandler(HostHandlerOption(u.Host)), - } - - go server.Run() - defer server.Close() - - return sniRoundtrip(client, server, targetURL, data) -} - -func TestSNIOverMWSS(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - httpsSrv := httptest.NewTLSServer(httpTestHandler) - defer httpsSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - var sniProxyTests = []struct { - targetURL string - host string - pass bool - }{ - {httpSrv.URL, "", true}, - {httpSrv.URL, "example.com", true}, - {httpsSrv.URL, "", true}, - {httpsSrv.URL, "example.com", true}, - } - - for i, tc := range sniProxyTests { - tc := tc - t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { - err := sniOverMWSSRoundtrip(tc.targetURL, sendData, tc.host) - if err == nil { - if !tc.pass { - t.Errorf("#%d should failed", i) - } - } else { - // t.Logf("#%d %v", i, err) - if tc.pass { - t.Errorf("#%d got error: %v", i, err) - } - } - }) - } -} - -func mwssForwardTunnelRoundtrip(targetURL string, data []byte) error { - ln, err := MWSSListener("", nil, nil) - if err != nil { - return err - } - - u, err := url.Parse(targetURL) - if err != nil { - return err - } - - client := &Client{ - Connector: ForwardConnector(), - Transporter: MWSSTransporter(nil), - } - - server := &Server{ - Listener: ln, - Handler: TCPDirectForwardHandler(u.Host), - } - server.Handler.Init() - - go server.Run() - defer server.Close() - - return proxyRoundtrip(client, server, targetURL, data) -} - -func TestMWSSForwardTunnel(t *testing.T) { - httpSrv := httptest.NewServer(httpTestHandler) - defer httpSrv.Close() - - sendData := make([]byte, 128) - rand.Read(sendData) - - err := mwssForwardTunnelRoundtrip(httpSrv.URL, sendData) - if err != nil { - t.Error(err) - } -} diff --git a/pkg/route.go b/pkg/route.go index 3a88e78a..b2060ea2 100644 --- a/pkg/route.go +++ b/pkg/route.go @@ -62,30 +62,6 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { } timeout := node.GetDuration("timeout") - var tr gost.Transporter - switch node.Transport { - case "ssh": - if node.Protocol == "direct" || node.Protocol == "remote" { - tr = gost.SSHForwardTransporter() - } else { - tr = gost.SSHTunnelTransporter() - } - default: - tr = gost.TCPTransporter() - } - - var connector gost.Connector - switch node.Protocol { - case "ssu": - connector = gost.ShadowUDPConnector(node.User) - case "direct": - connector = gost.SSHDirectForwardConnector() - case "remote": - connector = gost.SSHRemoteForwardConnector() - default: - connector = gost.AutoConnector(node.User) - } - host := node.Get("host") if host == "" { host = node.Host @@ -111,8 +87,8 @@ func parseChainNode(ns string) (nodes []gost.Node, err error) { } node.Client = &gost.Client{ - Connector: connector, - Transporter: tr, + Connector: gost.AutoConnector(node.User), + Transporter: gost.TCPTransporter(), } ips := parseIP(node.Get("ip"), sport) @@ -157,10 +133,6 @@ func (r *route) GenRouters() (*router, error) { switch node.Transport { case "tcp": // Directly use SSH port forwarding if the last chain node is forward+ssh - if chain.LastNode().Protocol == "forward" && chain.LastNode().Transport == "ssh" { - chain.Nodes()[len(chain.Nodes())-1].Client.Connector = gost.SSHDirectForwardConnector() - chain.Nodes()[len(chain.Nodes())-1].Client.Transporter = gost.SSHForwardTransporter() - } ln, err = gost.TCPListener(node.Addr) case "udp": ln, err = gost.UDPListener(node.Addr, &gost.UDPListenConfig{ @@ -204,8 +176,6 @@ func (r *route) GenRouters() (*router, error) { handler = gost.TunHandler() case "tap": handler = gost.TapHandler() - case "dns": - handler = gost.DNSHandler(node.Remote) default: // start from 2.5, if remote is not empty, then we assume that it is a forward tunnel. if node.Remote != "" { @@ -236,12 +206,10 @@ func (r *route) GenRouters() (*router, error) { } type router struct { - node gost.Node - server *gost.Server - handler gost.Handler - chain *gost.Chain - resolver gost.Resolver - hosts *gost.Hosts + node gost.Node + server *gost.Server + handler gost.Handler + chain *gost.Chain } func (r *router) Serve() error {