mirror of
https://github.com/telanflow/mps.git
synced 2025-09-26 20:41:25 +08:00
A simple example of cascading proxy
This commit is contained in:
@@ -13,9 +13,11 @@ It support HTTP, HTTPS, Websocket, ForwardProxy, ReverseProxy, MitmProxy
|
||||
- [X] Http Proxy
|
||||
- [X] Https Proxy
|
||||
- [X] Forward Proxy
|
||||
- [X] Reverse Proxy
|
||||
- [X] Tunnel Proxy
|
||||
- [ ] Reverse Proxy
|
||||
- [X] Mitm Proxy (Man-in-the-middle)
|
||||
- [ ] WekSocket Proxy
|
||||
- [ ] Socks5 Proxy
|
||||
|
||||
## 🧰 Install
|
||||
|
||||
|
@@ -13,9 +13,11 @@ MPS 是一个中间代理服务扩展库。
|
||||
- [X] Http代理
|
||||
- [X] Https代理
|
||||
- [X] 正向代理
|
||||
- [X] 反向代理
|
||||
- [X] 隧道代理
|
||||
- [ ] 反向代理
|
||||
- [X] 中间人代理 (MITM)
|
||||
- [ ] WekSocket代理
|
||||
- [ ] Socks5代理
|
||||
|
||||
## 🧰 安装
|
||||
|
||||
|
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
var DefaultMemProvider = NewMemProvider()
|
||||
|
||||
// MemProvider A simple in-memory certificate cache
|
||||
type MemProvider struct {
|
||||
cache map[string]*tls.Certificate
|
||||
|
113
context.go
113
context.go
@@ -3,11 +3,15 @@ package mps
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Http method not support
|
||||
var MethodNotSupportErr = errors.New("request method not support")
|
||||
|
||||
// Context for the request
|
||||
// which contains Middleware, Transport, and other values
|
||||
type Context struct {
|
||||
@@ -23,9 +27,12 @@ type Context struct {
|
||||
// Transport is used for global HTTP requests, and it will be reused.
|
||||
Transport *http.Transport
|
||||
|
||||
// In some cases it is not always necessary to remove the Proxy Header.
|
||||
// In some cases it is not always necessary to remove the proxy headers.
|
||||
// For example, cascade proxy
|
||||
KeepHeader bool
|
||||
KeepProxyHeaders bool
|
||||
|
||||
// In some cases it is not always necessary to reset the headers.
|
||||
KeepClientHeaders bool
|
||||
|
||||
// KeepDestinationHeaders indicates the proxy should retain any headers
|
||||
// present in the http.Response before proxying
|
||||
@@ -39,9 +46,11 @@ type Context struct {
|
||||
middlewares []Middleware
|
||||
}
|
||||
|
||||
// Create a Context
|
||||
func NewContext() *Context {
|
||||
return &Context{
|
||||
Context: context.Background(),
|
||||
// Cannot reuse one Transport because multiple proxy can collide with each other
|
||||
Transport: &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
@@ -58,31 +67,44 @@ func NewContext() *Context {
|
||||
},
|
||||
Request: nil,
|
||||
Response: nil,
|
||||
KeepHeader: false,
|
||||
KeepProxyHeaders: false,
|
||||
KeepClientHeaders: false,
|
||||
KeepDestinationHeaders: false,
|
||||
mi: -1,
|
||||
middlewares: make([]Middleware, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Use registers an Middleware to proxy
|
||||
func (ctx *Context) Use(middleware ...Middleware) {
|
||||
if ctx.middlewares == nil {
|
||||
ctx.middlewares = make([]Middleware, 0)
|
||||
}
|
||||
|
||||
ctx.middlewares = append(ctx.middlewares, middleware...)
|
||||
}
|
||||
|
||||
// UseFunc registers an MiddlewareFunc to proxy
|
||||
func (ctx *Context) UseFunc(fns ...MiddlewareFunc) {
|
||||
if ctx.middlewares == nil {
|
||||
ctx.middlewares = make([]Middleware, 0)
|
||||
}
|
||||
|
||||
for _, fn := range fns {
|
||||
ctx.middlewares = append(ctx.middlewares, fn)
|
||||
}
|
||||
}
|
||||
|
||||
// Next to exec middlewares
|
||||
// Execute the next middleware as a linked list. "ctx.Next(req)"
|
||||
// eg:
|
||||
// func Handle(req *http.Request, ctx *Context) (*http.Response, error) {
|
||||
// // You can do anything to modify the http.Request ...
|
||||
// resp, err := ctx.Next(req)
|
||||
// // You can do anything to modify the http.Response ...
|
||||
// return resp, err
|
||||
// }
|
||||
//
|
||||
// Alternatively, you can simply return the response without executing `ctx.Next()`,
|
||||
// which will interrupt subsequent middleware execution.
|
||||
func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
|
||||
var (
|
||||
total = len(ctx.middlewares)
|
||||
@@ -91,7 +113,12 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
|
||||
ctx.mi++
|
||||
if ctx.mi >= total {
|
||||
ctx.mi = -1
|
||||
return ctx.Transport.RoundTrip(req)
|
||||
// To make the middleware available to the tunnel proxy,
|
||||
// no response is obtained when the request method is equal to Connect
|
||||
if req.Method == http.MethodConnect {
|
||||
return nil, MethodNotSupportErr
|
||||
}
|
||||
return ctx.RoundTrip(req)
|
||||
}
|
||||
|
||||
middleware := ctx.middlewares[ctx.mi]
|
||||
@@ -100,15 +127,81 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
|
||||
return ctx.Response, err
|
||||
}
|
||||
|
||||
func (ctx *Context) Copy() *Context {
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
//
|
||||
// For higher-level HTTP client support (such as handling of cookies
|
||||
// and redirects), see Get, Post, and the Client type.
|
||||
//
|
||||
// Like the RoundTripper interface, the error types returned
|
||||
// by RoundTrip are unspecified.
|
||||
func (ctx *Context) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// These Headers must be reset when a client Request is issued to reuse a Request
|
||||
if !ctx.KeepClientHeaders {
|
||||
ResetClientHeaders(req)
|
||||
}
|
||||
|
||||
// In some cases it is not always necessary to remove the Proxy Header.
|
||||
// For example, cascade proxy
|
||||
if !ctx.KeepProxyHeaders {
|
||||
RemoveProxyHeaders(req)
|
||||
}
|
||||
|
||||
if ctx.Transport != nil {
|
||||
return ctx.Transport.RoundTrip(req)
|
||||
}
|
||||
return DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
|
||||
// WithRequest get the Context of the request
|
||||
func (ctx *Context) WithRequest(req *http.Request) *Context {
|
||||
return &Context{
|
||||
Context: context.Background(),
|
||||
Request: nil,
|
||||
Request: req,
|
||||
Response: nil,
|
||||
KeepHeader: false,
|
||||
KeepDestinationHeaders: false,
|
||||
KeepProxyHeaders: ctx.KeepProxyHeaders,
|
||||
KeepClientHeaders: ctx.KeepClientHeaders,
|
||||
KeepDestinationHeaders: ctx.KeepDestinationHeaders,
|
||||
Transport: ctx.Transport,
|
||||
mi: -1,
|
||||
middlewares: ctx.middlewares,
|
||||
}
|
||||
}
|
||||
|
||||
// ResetClientHeaders These Headers must be reset when a client Request is issued to reuse a Request
|
||||
func ResetClientHeaders(r *http.Request) {
|
||||
// this must be reset when serving a request with the client
|
||||
r.RequestURI = ""
|
||||
// If no Accept-Encoding header exists, Transport will add the headers it can accept
|
||||
// and would wrap the response body with the relevant reader.
|
||||
r.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
func RemoveProxyHeaders(r *http.Request) {
|
||||
// RFC 2616 (section 13.5.1)
|
||||
// https://www.ietf.org/rfc/rfc2616.txt
|
||||
r.Header.Del("Proxy-Connection")
|
||||
r.Header.Del("Proxy-Authenticate")
|
||||
r.Header.Del("Proxy-Authorization")
|
||||
// Connection, Authenticate and Authorization are single hop Header:
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616.txt
|
||||
// 14.10 Connection
|
||||
// The Connection general-header field allows the sender to specify
|
||||
// options that are desired for that particular connection and MUST NOT
|
||||
// be communicated by proxies over further connections.
|
||||
|
||||
// When server reads http request it sets req.Close to true if
|
||||
// "Connection" header contains "close".
|
||||
// https://github.com/golang/go/blob/master/src/net/http/request.go#L1080
|
||||
// Later, transfer.go adds "Connection: close" back when req.Close is true
|
||||
// https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275
|
||||
// That's why tests that checks "Connection: close" removal fail
|
||||
if r.Header.Get("Connection") == "close" {
|
||||
r.Close = false
|
||||
}
|
||||
r.Header.Del("Connection")
|
||||
}
|
||||
|
64
examples/cascade-proxy/main.go
Normal file
64
examples/cascade-proxy/main.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/telanflow/mps"
|
||||
"github.com/telanflow/mps/middleware"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A simple example of cascading proxy.
|
||||
// It implements BasicAuth
|
||||
func main() {
|
||||
// endPoint server
|
||||
go http.ListenAndServe("localhost:9990", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = w.Write([]byte("successful endPoint server"))
|
||||
}))
|
||||
|
||||
// proxy server 1
|
||||
proxy1 := mps.NewHttpProxy()
|
||||
proxy1.Ctx.KeepProxyHeaders = true
|
||||
proxy1.Use(middleware.BasicAuth("mps_realm_1", func(username, password string) bool {
|
||||
return username == "foo_1" && password == "bar_1"
|
||||
}))
|
||||
go http.ListenAndServe("localhost:9991", proxy1)
|
||||
|
||||
// proxy server 2
|
||||
proxy2 := mps.NewHttpProxy()
|
||||
proxy2.Ctx.KeepProxyHeaders = true
|
||||
proxy2.Use(middleware.BasicAuth("mps_realm_2", func(username, password string) bool {
|
||||
return username == "foo_2" && password == "bar_2"
|
||||
}))
|
||||
proxy2.Transport().Proxy = func(req *http.Request) (*url.URL, error) {
|
||||
middleware.SetBasicAuth(req, "foo_1", "bar_1")
|
||||
return url.Parse("http://localhost:9991")
|
||||
}
|
||||
go http.ListenAndServe("localhost:9992", proxy2)
|
||||
|
||||
// wait proxy server run
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// send request
|
||||
// request ==> proxy2 ==> proxy1 ==> http://localhost:9990
|
||||
// response <== proxy2 <== proxy1 <== http://localhost:9990
|
||||
req, _ := http.NewRequest(http.MethodGet, "http://localhost:9990/", nil)
|
||||
http.DefaultClient.Transport = &http.Transport{
|
||||
Proxy: func(r *http.Request) (*url.URL, error) {
|
||||
middleware.SetBasicAuth(r, "foo_2", "bar_2")
|
||||
return url.Parse("http://localhost:9992")
|
||||
},
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
body, _ := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
log.Println(resp.Header)
|
||||
log.Println(string(body))
|
||||
}
|
@@ -26,7 +26,7 @@ func NewForwardHandler() *ForwardHandler {
|
||||
// Create a ForwardHandler with Context
|
||||
func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
|
||||
return &ForwardHandler{
|
||||
Ctx: ctx,
|
||||
Ctx: ctx,
|
||||
BufferPool: pool.DefaultBuffer,
|
||||
}
|
||||
}
|
||||
@@ -34,15 +34,7 @@ func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
|
||||
// Standard net/http function. You can use it alone
|
||||
func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Copying a Context preserves the Transport, Middleware
|
||||
ctx := forward.Ctx.Copy()
|
||||
ctx.Request = req
|
||||
|
||||
// In some cases it is not always necessary to remove the Proxy Header.
|
||||
// For example, cascade proxy
|
||||
if !forward.Ctx.KeepHeader {
|
||||
removeProxyHeaders(req)
|
||||
}
|
||||
|
||||
ctx := forward.Ctx.WithRequest(req)
|
||||
resp, err := ctx.Next(req)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), 502)
|
||||
@@ -80,6 +72,26 @@ func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// Use registers an Middleware to proxy
|
||||
func (forward *ForwardHandler) Use(middleware ...Middleware) {
|
||||
forward.Ctx.Use(middleware...)
|
||||
}
|
||||
|
||||
// UseFunc registers an MiddlewareFunc to proxy
|
||||
func (forward *ForwardHandler) UseFunc(fus ...MiddlewareFunc) {
|
||||
forward.Ctx.UseFunc(fus...)
|
||||
}
|
||||
|
||||
// OnRequest filter requests through Filters
|
||||
func (forward *ForwardHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
|
||||
return &ReqFilterGroup{ctx: forward.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// OnResponse filter response through Filters
|
||||
func (forward *ForwardHandler) OnResponse(filters ...Filter) *RespFilterGroup {
|
||||
return &RespFilterGroup{ctx: forward.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// Transport
|
||||
func (forward *ForwardHandler) Transport() *http.Transport {
|
||||
return forward.Ctx.Transport
|
||||
|
27
handle.go
Normal file
27
handle.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package mps
|
||||
|
||||
import "net/http"
|
||||
|
||||
type RequestHandle interface {
|
||||
Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
// A wrapper that would convert a function to a RequestHandle interface type
|
||||
type RequestHandleFunc func(req *http.Request, ctx *Context) (*http.Request, *http.Response)
|
||||
|
||||
// RequestHandle.Handle(req, ctx) <=> RequestHandleFunc(req, ctx)
|
||||
func (f RequestHandleFunc) Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response) {
|
||||
return f(req, ctx)
|
||||
}
|
||||
|
||||
type ResponseHandle interface {
|
||||
Handle(resp *http.Response, ctx *Context) (*http.Response, error)
|
||||
}
|
||||
|
||||
// A wrapper that would convert a function to a ResponseHandle interface type
|
||||
type ResponseHandleFunc func(resp *http.Response, ctx *Context) (*http.Response, error)
|
||||
|
||||
// ResponseHandle.Handle(resp, ctx) <=> ResponseHandleFunc(resp, ctx)
|
||||
func (f ResponseHandleFunc) Handle(resp *http.Response, ctx *Context) (*http.Response, error) {
|
||||
return f(resp, ctx)
|
||||
}
|
@@ -3,7 +3,6 @@ package mps
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/telanflow/mps/pool"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
@@ -19,25 +18,21 @@ type HttpProxy struct {
|
||||
// HTTP requests use the ReverseHandler proxy by default
|
||||
ReverseHandler http.Handler
|
||||
|
||||
// Client request Context
|
||||
Ctx *Context
|
||||
|
||||
ConnContainer pool.ConnContainer
|
||||
}
|
||||
|
||||
func NewHttpProxy() *HttpProxy {
|
||||
// default Context with Proxy
|
||||
ctx := NewContext()
|
||||
// conn pool
|
||||
connPool := pool.NewConnProvider(pool.DefaultConnOptions)
|
||||
return &HttpProxy{
|
||||
Ctx: ctx,
|
||||
// default HTTP proxy
|
||||
HttpHandler: &ForwardHandler{Ctx: ctx},
|
||||
// default HTTPS proxy
|
||||
HttpsHandler: &TunnelHandler{Ctx: ctx, ConnContainer: connPool},
|
||||
HttpsHandler: &TunnelHandler{Ctx: ctx},
|
||||
// default Reverse proxy
|
||||
ReverseHandler: &ReverseHandler{Ctx: ctx},
|
||||
ConnContainer: connPool,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +40,7 @@ func NewHttpProxy() *HttpProxy {
|
||||
func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if req.Method == http.MethodConnect {
|
||||
proxy.HttpsHandler.ServeHTTP(rw, req)
|
||||
return
|
||||
}
|
||||
|
||||
if !req.URL.IsAbs() {
|
||||
@@ -65,13 +61,13 @@ func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) {
|
||||
}
|
||||
|
||||
// OnRequest filter requests through Filters
|
||||
func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqCondition {
|
||||
return &ReqCondition{ctx: proxy.Ctx, filters: filters}
|
||||
func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqFilterGroup {
|
||||
return &ReqFilterGroup{ctx: proxy.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// OnResponse filter response through Filters
|
||||
func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespCondition {
|
||||
return &RespCondition{ctx: proxy.Ctx, filters: filters}
|
||||
func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespFilterGroup {
|
||||
return &RespFilterGroup{ctx: proxy.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// Transport get http.Transport instance
|
||||
@@ -94,40 +90,6 @@ func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
// Hop-by-hop headers. These are removed when sent to the backend.
|
||||
// As of RFC 7230, hop-by-hop headers are required to appear in the
|
||||
// Connection header field. These are the headers defined by the
|
||||
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
|
||||
// compatibility.
|
||||
func removeProxyHeaders(r *http.Request) {
|
||||
r.RequestURI = "" // this must be reset when serving a request with the client
|
||||
// If no Accept-Encoding header exists, Transport will add the headers it can accept
|
||||
// and would wrap the response body with the relevant reader.
|
||||
r.Header.Del("Accept-Encoding")
|
||||
// RFC 2616 (section 13.5.1)
|
||||
// https://www.ietf.org/rfc/rfc2616.txt
|
||||
r.Header.Del("Proxy-Connection")
|
||||
r.Header.Del("Proxy-Authenticate")
|
||||
r.Header.Del("Proxy-Authorization")
|
||||
// Connection, Authenticate and Authorization are single hop Header:
|
||||
// http://www.w3.org/Protocols/rfc2616/rfc2616.txt
|
||||
// 14.10 Connection
|
||||
// The Connection general-header field allows the sender to specify
|
||||
// options that are desired for that particular connection and MUST NOT
|
||||
// be communicated by proxies over further connections.
|
||||
|
||||
// When server reads http request it sets req.Close to true if
|
||||
// "Connection" header contains "close".
|
||||
// https://github.com/golang/go/blob/master/src/net/http/request.go#L1080
|
||||
// Later, transfer.go adds "Connection: close" back when req.Close is true
|
||||
// https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275
|
||||
// That's why tests that checks "Connection: close" removal fail
|
||||
if r.Header.Get("Connection") == "close" {
|
||||
r.Close = false
|
||||
}
|
||||
r.Header.Del("Connection")
|
||||
}
|
||||
|
||||
func copyHeaders(dst, src http.Header, keepDestHeaders bool) {
|
||||
if !keepDestHeaders {
|
||||
for k := range dst {
|
||||
|
@@ -1,9 +1,9 @@
|
||||
package mps
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
@@ -12,8 +12,14 @@ import (
|
||||
|
||||
func NewTestServer() *httptest.Server {
|
||||
return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
||||
query := req.URL.Query()
|
||||
text := []byte("hello world")
|
||||
if query.Get("text") != "" {
|
||||
text = []byte(query.Get("text"))
|
||||
}
|
||||
|
||||
rw.Header().Set("Server", "MPS proxy server")
|
||||
rw.Write([]byte("hello world"))
|
||||
_, _ = rw.Write(text)
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -57,8 +63,15 @@ func TestMiddlewareFunc(t *testing.T) {
|
||||
proxy := NewHttpProxy()
|
||||
// use Middleware
|
||||
proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
|
||||
log.Println(req.URL.String())
|
||||
return ctx.Next(req)
|
||||
resp, err := ctx.Next(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("middleware")
|
||||
resp.Body = ioutil.NopCloser(&buf)
|
||||
return resp, nil
|
||||
})
|
||||
proxySrv := httptest.NewServer(proxy)
|
||||
defer proxySrv.Close()
|
||||
@@ -77,5 +90,5 @@ func TestMiddlewareFunc(t *testing.T) {
|
||||
asserts := assert.New(t)
|
||||
asserts.Equal(resp.StatusCode, 200)
|
||||
asserts.Equal(int64(len(body)), resp.ContentLength)
|
||||
log.Println(string(body))
|
||||
asserts.Equal(string(body), "middleware")
|
||||
}
|
||||
|
@@ -21,7 +21,7 @@ type Middleware interface {
|
||||
// A wrapper that would convert a function to a Middleware interface type
|
||||
type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error)
|
||||
|
||||
// MiddlewareFunc.Handle(req, ctx) <=> MiddlewareFunc(req, ctx)
|
||||
// Middleware.Handle(req, ctx) <=> MiddlewareFunc(req, ctx)
|
||||
func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) {
|
||||
return f(req, ctx)
|
||||
}
|
||||
|
94
middleware/basicAuth.go
Normal file
94
middleware/basicAuth.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"github.com/telanflow/mps"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// proxy Authorization header
|
||||
const proxyAuthorization = "Proxy-Authorization"
|
||||
|
||||
// BasicAuth returns a HTTP Basic Authentication middleware for requests
|
||||
// You probably want to use mps.BasicAuth(proxy) to enable authentication for all proxy activities
|
||||
func BasicAuth(realm string, fn func(username, password string) bool) mps.MiddlewareFunc {
|
||||
return func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
|
||||
auth := req.Header.Get(proxyAuthorization)
|
||||
if auth == "" {
|
||||
return BasicUnauthorized(req, realm), nil
|
||||
}
|
||||
// parses an Basic Authentication string.
|
||||
usr, pwd, ok := parseBasicAuth(auth)
|
||||
if !ok {
|
||||
return BasicUnauthorized(req, realm), nil
|
||||
}
|
||||
if !fn(usr, pwd) {
|
||||
return BasicUnauthorized(req, realm), nil
|
||||
}
|
||||
// Authorization passed
|
||||
return ctx.Next(req)
|
||||
}
|
||||
}
|
||||
|
||||
// SetBasicAuth sets the request's Authorization header to use HTTP
|
||||
// Basic Authentication with the provided username and password.
|
||||
//
|
||||
// With HTTP Basic Authentication the provided username and password
|
||||
// are not encrypted.
|
||||
//
|
||||
// Some protocols may impose additional requirements on pre-escaping the
|
||||
// username and password. For instance, when used with OAuth2, both arguments
|
||||
// must be URL encoded first with url.QueryEscape.
|
||||
func SetBasicAuth(req *http.Request, username, password string) {
|
||||
req.Header.Set(proxyAuthorization, "Basic "+basicAuth(username, password))
|
||||
}
|
||||
|
||||
// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
|
||||
// "To receive authorization, the client sends the userid and password,
|
||||
// separated by a single colon (":") character, within a base64
|
||||
// encoded string in the credentials."
|
||||
// It is not meant to be urlencoded.
|
||||
func basicAuth(username, password string) string {
|
||||
auth := username + ":" + password
|
||||
return base64.StdEncoding.EncodeToString([]byte(auth))
|
||||
}
|
||||
|
||||
// parseBasicAuth parses an HTTP Basic Authentication string.
|
||||
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
|
||||
func parseBasicAuth(auth string) (username, password string, ok bool) {
|
||||
const prefix = "Basic "
|
||||
// Case insensitive prefix match. See Issue 22736.
|
||||
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
|
||||
return
|
||||
}
|
||||
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
cs := string(c)
|
||||
s := strings.IndexByte(cs, ':')
|
||||
if s < 0 {
|
||||
return
|
||||
}
|
||||
return cs[:s], cs[s+1:], true
|
||||
}
|
||||
|
||||
func BasicUnauthorized(req *http.Request, realm string) *http.Response {
|
||||
const unauthorizedMsg = "407 Proxy Authentication Required"
|
||||
// verify realm is well formed
|
||||
return &http.Response{
|
||||
StatusCode: 407,
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Request: req,
|
||||
Header: http.Header{
|
||||
"Proxy-Authenticate": []string{"Basic realm=" + realm},
|
||||
"Proxy-Connection": []string{"close"},
|
||||
},
|
||||
Body: ioutil.NopCloser(bytes.NewBuffer([]byte(unauthorizedMsg))),
|
||||
ContentLength: int64(len(unauthorizedMsg)),
|
||||
}
|
||||
}
|
@@ -66,11 +66,25 @@ func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) {
|
||||
}
|
||||
|
||||
// Standard net/http function. You can use it alone
|
||||
func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
|
||||
// Execution middleware
|
||||
ctx := mitm.Ctx.WithRequest(r)
|
||||
resp, err := ctx.Next(r)
|
||||
if err != nil && err != MethodNotSupportErr {
|
||||
if resp != nil {
|
||||
copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
buf := mitm.buffer().Get()
|
||||
_, err = io.CopyBuffer(rw, resp.Body, buf)
|
||||
mitm.buffer().Put(buf)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// get hijacker connection
|
||||
proxyClient, err := hijacker(w)
|
||||
proxyClient, err := hijacker(rw)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), 502)
|
||||
http.Error(rw, err.Error(), 502)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -91,7 +105,7 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rawClientTls := tls.Server(proxyClient, tlsConfig)
|
||||
if err := rawClientTls.Handshake(); err != nil {
|
||||
ConnError(proxyClient)
|
||||
//ctx.Warnf("Cannot handshake client %v %v", r.Host, err)
|
||||
_ = rawClientTls.Close()
|
||||
return
|
||||
}
|
||||
defer rawClientTls.Close()
|
||||
@@ -100,7 +114,6 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
for !isEof(clientTlsReader) {
|
||||
req, err := http.ReadRequest(clientTlsReader)
|
||||
if err != nil {
|
||||
//ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -115,23 +128,14 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// Copying a Context preserves the Transport, Middleware
|
||||
ctx := mitm.Ctx.Copy()
|
||||
ctx.Request = req
|
||||
|
||||
// In some cases it is not always necessary to remove the Proxy Header.
|
||||
// For example, cascade proxy
|
||||
if !mitm.Ctx.KeepHeader {
|
||||
removeProxyHeaders(req)
|
||||
}
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
// Copying a Context preserves the Transport, Middleware
|
||||
ctx := mitm.Ctx.WithRequest(req)
|
||||
resp, err = ctx.Next(req)
|
||||
if err != nil {
|
||||
//ctx.Warnf("Cannot read TLS response from mitm'd server %v", err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
status := resp.Status
|
||||
statusCode := strconv.Itoa(resp.StatusCode) + " "
|
||||
@@ -155,12 +159,12 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
err = resp.Header.Write(rawClientTls)
|
||||
if err != nil {
|
||||
//ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err)
|
||||
resp.Body.Close()
|
||||
return
|
||||
}
|
||||
_, err = io.WriteString(rawClientTls, "\r\n")
|
||||
if err != nil {
|
||||
//ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err)
|
||||
resp.Body.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,15 +174,17 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
_, err = io.CopyBuffer(chunked, resp.Body, buf)
|
||||
mitm.buffer().Put(buf)
|
||||
if err != nil {
|
||||
//ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err)
|
||||
resp.Body.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// closed response body
|
||||
resp.Body.Close()
|
||||
|
||||
if err := chunked.Close(); err != nil {
|
||||
//ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil {
|
||||
//ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -196,13 +202,13 @@ func (mitm *MitmHandler) UseFunc(fus ...MiddlewareFunc) {
|
||||
}
|
||||
|
||||
// OnRequest filter requests through Filters
|
||||
func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqCondition {
|
||||
return &ReqCondition{ctx: mitm.Ctx, filters: filters}
|
||||
func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
|
||||
return &ReqFilterGroup{ctx: mitm.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// OnResponse filter response through Filters
|
||||
func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespCondition {
|
||||
return &RespCondition{ctx: mitm.Ctx, filters: filters}
|
||||
func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespFilterGroup {
|
||||
return &RespFilterGroup{ctx: mitm.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// Get buffer pool
|
||||
@@ -213,6 +219,14 @@ func (mitm *MitmHandler) buffer() httputil.BufferPool {
|
||||
return pool.DefaultBuffer
|
||||
}
|
||||
|
||||
// Get cert.Container instance
|
||||
func (mitm *MitmHandler) certContainer() cert.Container {
|
||||
if mitm.CertContainer != nil {
|
||||
return mitm.CertContainer
|
||||
}
|
||||
return cert.DefaultMemProvider
|
||||
}
|
||||
|
||||
// Transport
|
||||
func (mitm *MitmHandler) Transport() *http.Transport {
|
||||
return mitm.Ctx.Transport
|
||||
@@ -222,7 +236,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
|
||||
host = stripPort(host)
|
||||
|
||||
// Returned existing certificate for the host
|
||||
crt, err := mitm.CertContainer.Get(host)
|
||||
crt, err := mitm.certContainer().Get(host)
|
||||
if err == nil && crt != nil {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
@@ -238,7 +252,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
|
||||
}
|
||||
|
||||
// Set certificate to container
|
||||
mitm.CertContainer.Set(host, crt)
|
||||
_ = mitm.certContainer().Set(host, crt)
|
||||
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
@@ -246,6 +260,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// sign host
|
||||
func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
|
||||
// Use the provided ca for certificate generation.
|
||||
var x509ca *x509.Certificate
|
||||
|
@@ -3,11 +3,15 @@ package pool
|
||||
import "time"
|
||||
|
||||
var DefaultConnOptions = &ConnOptions{
|
||||
IdleMaxCap: 20,
|
||||
Timeout: time.Minute,
|
||||
IdleMaxCap: 30,
|
||||
Timeout: 90 * time.Second,
|
||||
}
|
||||
|
||||
// ConnOptions is ConnProvider options
|
||||
type ConnOptions struct {
|
||||
// IdleMaxCap is max connection capacity for a single net.Addr
|
||||
IdleMaxCap int
|
||||
Timeout time.Duration
|
||||
|
||||
// Timeout specifies how long the connection will timeout
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
@@ -10,6 +10,9 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
var DefaultConnProvider = NewConnProvider(DefaultConnOptions)
|
||||
|
||||
// ConnProvider is a connection pool, it implements ConnContainer
|
||||
type ConnProvider struct {
|
||||
mu sync.RWMutex
|
||||
idleConnMap map[string]chan net.Conn
|
||||
@@ -17,6 +20,7 @@ type ConnProvider struct {
|
||||
closed int32
|
||||
}
|
||||
|
||||
// Create a ConnProvider
|
||||
func NewConnProvider(opt *ConnOptions) *ConnProvider {
|
||||
return &ConnProvider{
|
||||
options: opt,
|
||||
@@ -42,6 +46,7 @@ func (p *ConnProvider) Get(addr string) (net.Conn, error) {
|
||||
RETRY:
|
||||
select {
|
||||
case conn := <-p.idleConnMap[addr]:
|
||||
// Getting a net.Conn requires verifying that the net.Conn is valid
|
||||
_, err := conn.Read([]byte{})
|
||||
if err != nil || err == io.EOF {
|
||||
// conn is close Or timeout
|
||||
@@ -58,6 +63,8 @@ RETRY:
|
||||
func (p *ConnProvider) Put(conn net.Conn) error {
|
||||
closed := atomic.LoadInt32(&p.closed)
|
||||
if closed == 1 {
|
||||
// pool is closed, this conn must be closed
|
||||
conn.Close()
|
||||
return errors.New("pool is closed")
|
||||
}
|
||||
|
||||
@@ -70,13 +77,13 @@ func (p *ConnProvider) Put(conn net.Conn) error {
|
||||
p.mu.Unlock()
|
||||
|
||||
// set conn timeout
|
||||
// The timeout will be verified at the next `Get()`
|
||||
err := conn.SetDeadline(time.Now().Add(p.options.Timeout))
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// set idle conn
|
||||
select {
|
||||
case p.idleConnMap[addr] <- conn:
|
||||
return nil
|
||||
@@ -86,6 +93,7 @@ func (p *ConnProvider) Put(conn net.Conn) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Release connection pool
|
||||
func (p *ConnProvider) Release() error {
|
||||
closed := atomic.LoadInt32(&p.closed)
|
||||
if closed == 1 {
|
||||
@@ -93,7 +101,6 @@ func (p *ConnProvider) Release() error {
|
||||
}
|
||||
|
||||
atomic.StoreInt32(&p.closed, 1)
|
||||
|
||||
for _, connChan := range p.idleConnMap {
|
||||
close(connChan)
|
||||
for conn, ok := <-connChan; ok; {
|
||||
|
@@ -4,16 +4,17 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ReqCondition struct {
|
||||
// ReqCondition is a request condition group
|
||||
type ReqFilterGroup struct {
|
||||
ctx *Context
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *http.Response)) {
|
||||
func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) {
|
||||
cond.Do(RequestHandleFunc(fn))
|
||||
}
|
||||
|
||||
func (cond *ReqCondition) Do(fn RequestHandle) {
|
||||
func (cond *ReqFilterGroup) Do(h RequestHandle) {
|
||||
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
|
||||
total := len(cond.filters)
|
||||
for i := 0; i < total; i++ {
|
||||
@@ -22,7 +23,7 @@ func (cond *ReqCondition) Do(fn RequestHandle) {
|
||||
}
|
||||
}
|
||||
|
||||
req, resp := fn.Handle(req)
|
||||
req, resp := h.Handle(req, ctx)
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
@@ -1,13 +0,0 @@
|
||||
package mps
|
||||
|
||||
import "net/http"
|
||||
|
||||
type RequestHandle interface {
|
||||
Handle(req *http.Request) (*http.Request, *http.Response)
|
||||
}
|
||||
|
||||
type RequestHandleFunc func(req *http.Request) (*http.Request, *http.Response)
|
||||
|
||||
func (f RequestHandleFunc) Handle(req *http.Request) (*http.Request, *http.Response) {
|
||||
return f(req)
|
||||
}
|
@@ -4,16 +4,16 @@ import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type RespCondition struct {
|
||||
type RespFilterGroup struct {
|
||||
ctx *Context
|
||||
filters []Filter
|
||||
}
|
||||
|
||||
func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response, error)) {
|
||||
func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, ctx *Context) (*http.Response, error)) {
|
||||
cond.Do(ResponseHandleFunc(fn))
|
||||
}
|
||||
|
||||
func (cond *RespCondition) Do(fn ResponseHandle) {
|
||||
func (cond *RespFilterGroup) Do(h ResponseHandle) {
|
||||
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
|
||||
total := len(cond.filters)
|
||||
for i := 0; i < total; i++ {
|
||||
@@ -27,6 +27,6 @@ func (cond *RespCondition) Do(fn ResponseHandle) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fn.Handle(resp)
|
||||
return h.Handle(resp, ctx)
|
||||
})
|
||||
}
|
@@ -1,13 +0,0 @@
|
||||
package mps
|
||||
|
||||
import "net/http"
|
||||
|
||||
type ResponseHandle interface {
|
||||
Handle(resp *http.Response) (*http.Response, error)
|
||||
}
|
||||
|
||||
type ResponseHandleFunc func(resp *http.Response) (*http.Response, error)
|
||||
|
||||
func (f ResponseHandleFunc) Handle(resp *http.Response) (*http.Response, error) {
|
||||
return f(resp)
|
||||
}
|
@@ -15,6 +15,7 @@ type ReverseHandler struct {
|
||||
BufferPool httputil.BufferPool
|
||||
}
|
||||
|
||||
// Create a ReverseHandler
|
||||
func NewReverseHandler() *ReverseHandler {
|
||||
return &ReverseHandler{
|
||||
Ctx: NewContext(),
|
||||
@@ -25,9 +26,7 @@ func NewReverseHandler() *ReverseHandler {
|
||||
// Standard net/http function. You can use it alone
|
||||
func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Copying a Context preserves the Transport, Middleware
|
||||
ctx := reverse.Ctx.Copy()
|
||||
ctx.Request = req
|
||||
|
||||
ctx := reverse.Ctx.WithRequest(req)
|
||||
resp, err := ctx.Next(req)
|
||||
if err != nil {
|
||||
http.Error(rw, err.Error(), 502)
|
||||
@@ -76,13 +75,13 @@ func (reverse *ReverseHandler) UseFunc(fus ...MiddlewareFunc) {
|
||||
}
|
||||
|
||||
// OnRequest filter requests through Filters
|
||||
func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqCondition {
|
||||
return &ReqCondition{ctx: reverse.Ctx, filters: filters}
|
||||
func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
|
||||
return &ReqFilterGroup{ctx: reverse.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// OnResponse filter response through Filters
|
||||
func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespCondition {
|
||||
return &RespCondition{ctx: reverse.Ctx, filters: filters}
|
||||
func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespFilterGroup {
|
||||
return &RespFilterGroup{ctx: reverse.Ctx, filters: filters}
|
||||
}
|
||||
|
||||
// Get buffer pool
|
||||
|
24
transport.go
Normal file
24
transport.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package mps
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Default http.Transport option
|
||||
var DefaultTransport = &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: 15 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
DualStack: true,
|
||||
}).DialContext,
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
}
|
@@ -9,6 +9,7 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -27,23 +28,35 @@ type TunnelHandler struct {
|
||||
// Create a tunnel handler
|
||||
func NewTunnelHandler() *TunnelHandler {
|
||||
return &TunnelHandler{
|
||||
Ctx: NewContext(),
|
||||
BufferPool: pool.DefaultBuffer,
|
||||
ConnContainer: pool.NewConnProvider(pool.DefaultConnOptions),
|
||||
Ctx: NewContext(),
|
||||
BufferPool: pool.DefaultBuffer,
|
||||
}
|
||||
}
|
||||
|
||||
// Create a tunnel handler with Context
|
||||
func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler {
|
||||
return &TunnelHandler{
|
||||
Ctx: ctx,
|
||||
BufferPool: pool.DefaultBuffer,
|
||||
ConnContainer: pool.NewConnProvider(pool.DefaultConnOptions),
|
||||
Ctx: ctx,
|
||||
BufferPool: pool.DefaultBuffer,
|
||||
}
|
||||
}
|
||||
|
||||
// Standard net/http function. You can use it alone
|
||||
func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
// Execution middleware
|
||||
ctx := tunnel.Ctx.WithRequest(req)
|
||||
resp, err := ctx.Next(req)
|
||||
if err != nil && err != MethodNotSupportErr {
|
||||
if resp != nil {
|
||||
copyHeaders(rw.Header(), resp.Header, tunnel.Ctx.KeepDestinationHeaders)
|
||||
rw.WriteHeader(resp.StatusCode)
|
||||
buf := tunnel.buffer().Get()
|
||||
_, err = io.CopyBuffer(rw, resp.Body, buf)
|
||||
tunnel.buffer().Put(buf)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// hijacker connection
|
||||
proxyClient, err := hijacker(rw)
|
||||
if err != nil {
|
||||
@@ -71,27 +84,19 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
|
||||
}
|
||||
|
||||
// connect to targetAddr
|
||||
if tunnel.ConnContainer != nil {
|
||||
targetConn, err = tunnel.ConnContainer.Get(targetAddr)
|
||||
if err != nil {
|
||||
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
|
||||
}
|
||||
} else {
|
||||
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
|
||||
}
|
||||
targetConn, err = tunnel.connContainer().Get(targetAddr)
|
||||
if err != nil {
|
||||
ConnError(proxyClient)
|
||||
return
|
||||
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
|
||||
if err != nil {
|
||||
ConnError(proxyClient)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// If the ConnContainer is exists,
|
||||
// When io.CopyBuffer is complete,
|
||||
// put the idle connection into the ConnContainer so can reuse it next time
|
||||
if tunnel.ConnContainer != nil {
|
||||
defer tunnel.ConnContainer.Put(targetConn)
|
||||
} else {
|
||||
defer targetConn.Close()
|
||||
}
|
||||
defer tunnel.connContainer().Put(targetConn)
|
||||
|
||||
// The cascade proxy needs to forward the request
|
||||
if isCascadeProxy {
|
||||
@@ -115,16 +120,9 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
|
||||
|
||||
func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) {
|
||||
if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil {
|
||||
return tunnel.Ctx.Transport.DialContext(tunnel.Context(), network, addr)
|
||||
return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr)
|
||||
}
|
||||
return net.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (tunnel *TunnelHandler) Context() context.Context {
|
||||
if tunnel.Ctx.Context != nil {
|
||||
return tunnel.Ctx.Context
|
||||
}
|
||||
return context.Background()
|
||||
return net.DialTimeout(network, addr, 30*time.Second)
|
||||
}
|
||||
|
||||
// Transport
|
||||
@@ -132,6 +130,14 @@ func (tunnel *TunnelHandler) Transport() *http.Transport {
|
||||
return tunnel.Ctx.Transport
|
||||
}
|
||||
|
||||
// get a context.Context
|
||||
func (tunnel *TunnelHandler) context() context.Context {
|
||||
if tunnel.Ctx.Context != nil {
|
||||
return tunnel.Ctx.Context
|
||||
}
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// Get buffer pool
|
||||
func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
|
||||
if tunnel.BufferPool != nil {
|
||||
@@ -140,6 +146,14 @@ func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
|
||||
return pool.DefaultBuffer
|
||||
}
|
||||
|
||||
// Get a conn pool
|
||||
func (tunnel *TunnelHandler) connContainer() pool.ConnContainer {
|
||||
if tunnel.ConnContainer != nil {
|
||||
return tunnel.ConnContainer
|
||||
}
|
||||
return pool.DefaultConnProvider
|
||||
}
|
||||
|
||||
func hostAndPort(addr string) string {
|
||||
if !hasPort.MatchString(addr) {
|
||||
addr += ":80"
|
||||
|
Reference in New Issue
Block a user