A simple example of cascading proxy

This commit is contained in:
telanflow
2020-08-12 14:30:00 +08:00
parent 602d96cadb
commit e81a17de66
21 changed files with 486 additions and 177 deletions

View File

@@ -13,9 +13,11 @@ It support HTTP, HTTPS, Websocket, ForwardProxy, ReverseProxy, MitmProxy
- [X] Http Proxy
- [X] Https Proxy
- [X] Forward Proxy
- [X] Reverse Proxy
- [X] Tunnel Proxy
- [ ] Reverse Proxy
- [X] Mitm Proxy (Man-in-the-middle)
- [ ] WekSocket Proxy
- [ ] Socks5 Proxy
## 🧰 Install

View File

@@ -13,9 +13,11 @@ MPS 是一个中间代理服务扩展库。
- [X] Http代理
- [X] Https代理
- [X] 正向代理
- [X] 反向代理
- [X] 隧道代理
- [ ] 反向代理
- [X] 中间人代理 (MITM)
- [ ] WekSocket代理
- [ ] Socks5代理
## 🧰 安装

View File

@@ -6,6 +6,8 @@ import (
"strings"
)
var DefaultMemProvider = NewMemProvider()
// MemProvider A simple in-memory certificate cache
type MemProvider struct {
cache map[string]*tls.Certificate

View File

@@ -3,11 +3,15 @@ package mps
import (
"context"
"crypto/tls"
"errors"
"net"
"net/http"
"time"
)
// Http method not support
var MethodNotSupportErr = errors.New("request method not support")
// Context for the request
// which contains Middleware, Transport, and other values
type Context struct {
@@ -23,9 +27,12 @@ type Context struct {
// Transport is used for global HTTP requests, and it will be reused.
Transport *http.Transport
// In some cases it is not always necessary to remove the Proxy Header.
// In some cases it is not always necessary to remove the proxy headers.
// For example, cascade proxy
KeepHeader bool
KeepProxyHeaders bool
// In some cases it is not always necessary to reset the headers.
KeepClientHeaders bool
// KeepDestinationHeaders indicates the proxy should retain any headers
// present in the http.Response before proxying
@@ -39,9 +46,11 @@ type Context struct {
middlewares []Middleware
}
// Create a Context
func NewContext() *Context {
return &Context{
Context: context.Background(),
// Cannot reuse one Transport because multiple proxy can collide with each other
Transport: &http.Transport{
DialContext: (&net.Dialer{
Timeout: 15 * time.Second,
@@ -58,31 +67,44 @@ func NewContext() *Context {
},
Request: nil,
Response: nil,
KeepHeader: false,
KeepProxyHeaders: false,
KeepClientHeaders: false,
KeepDestinationHeaders: false,
mi: -1,
middlewares: make([]Middleware, 0),
}
}
// Use registers an Middleware to proxy
func (ctx *Context) Use(middleware ...Middleware) {
if ctx.middlewares == nil {
ctx.middlewares = make([]Middleware, 0)
}
ctx.middlewares = append(ctx.middlewares, middleware...)
}
// UseFunc registers an MiddlewareFunc to proxy
func (ctx *Context) UseFunc(fns ...MiddlewareFunc) {
if ctx.middlewares == nil {
ctx.middlewares = make([]Middleware, 0)
}
for _, fn := range fns {
ctx.middlewares = append(ctx.middlewares, fn)
}
}
// Next to exec middlewares
// Execute the next middleware as a linked list. "ctx.Next(req)"
// eg:
// func Handle(req *http.Request, ctx *Context) (*http.Response, error) {
// // You can do anything to modify the http.Request ...
// resp, err := ctx.Next(req)
// // You can do anything to modify the http.Response ...
// return resp, err
// }
//
// Alternatively, you can simply return the response without executing `ctx.Next()`,
// which will interrupt subsequent middleware execution.
func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
var (
total = len(ctx.middlewares)
@@ -91,7 +113,12 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
ctx.mi++
if ctx.mi >= total {
ctx.mi = -1
return ctx.Transport.RoundTrip(req)
// To make the middleware available to the tunnel proxy,
// no response is obtained when the request method is equal to Connect
if req.Method == http.MethodConnect {
return nil, MethodNotSupportErr
}
return ctx.RoundTrip(req)
}
middleware := ctx.middlewares[ctx.mi]
@@ -100,15 +127,81 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) {
return ctx.Response, err
}
func (ctx *Context) Copy() *Context {
// RoundTrip implements the RoundTripper interface.
//
// For higher-level HTTP client support (such as handling of cookies
// and redirects), see Get, Post, and the Client type.
//
// Like the RoundTripper interface, the error types returned
// by RoundTrip are unspecified.
func (ctx *Context) RoundTrip(req *http.Request) (*http.Response, error) {
// These Headers must be reset when a client Request is issued to reuse a Request
if !ctx.KeepClientHeaders {
ResetClientHeaders(req)
}
// In some cases it is not always necessary to remove the Proxy Header.
// For example, cascade proxy
if !ctx.KeepProxyHeaders {
RemoveProxyHeaders(req)
}
if ctx.Transport != nil {
return ctx.Transport.RoundTrip(req)
}
return DefaultTransport.RoundTrip(req)
}
// WithRequest get the Context of the request
func (ctx *Context) WithRequest(req *http.Request) *Context {
return &Context{
Context: context.Background(),
Request: nil,
Request: req,
Response: nil,
KeepHeader: false,
KeepDestinationHeaders: false,
KeepProxyHeaders: ctx.KeepProxyHeaders,
KeepClientHeaders: ctx.KeepClientHeaders,
KeepDestinationHeaders: ctx.KeepDestinationHeaders,
Transport: ctx.Transport,
mi: -1,
middlewares: ctx.middlewares,
}
}
// ResetClientHeaders These Headers must be reset when a client Request is issued to reuse a Request
func ResetClientHeaders(r *http.Request) {
// this must be reset when serving a request with the client
r.RequestURI = ""
// If no Accept-Encoding header exists, Transport will add the headers it can accept
// and would wrap the response body with the relevant reader.
r.Header.Del("Accept-Encoding")
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
func RemoveProxyHeaders(r *http.Request) {
// RFC 2616 (section 13.5.1)
// https://www.ietf.org/rfc/rfc2616.txt
r.Header.Del("Proxy-Connection")
r.Header.Del("Proxy-Authenticate")
r.Header.Del("Proxy-Authorization")
// Connection, Authenticate and Authorization are single hop Header:
// http://www.w3.org/Protocols/rfc2616/rfc2616.txt
// 14.10 Connection
// The Connection general-header field allows the sender to specify
// options that are desired for that particular connection and MUST NOT
// be communicated by proxies over further connections.
// When server reads http request it sets req.Close to true if
// "Connection" header contains "close".
// https://github.com/golang/go/blob/master/src/net/http/request.go#L1080
// Later, transfer.go adds "Connection: close" back when req.Close is true
// https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275
// That's why tests that checks "Connection: close" removal fail
if r.Header.Get("Connection") == "close" {
r.Close = false
}
r.Header.Del("Connection")
}

View File

@@ -0,0 +1,64 @@
package main
import (
"github.com/telanflow/mps"
"github.com/telanflow/mps/middleware"
"io/ioutil"
"log"
"net/http"
"net/url"
"time"
)
// A simple example of cascading proxy.
// It implements BasicAuth
func main() {
// endPoint server
go http.ListenAndServe("localhost:9990", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte("successful endPoint server"))
}))
// proxy server 1
proxy1 := mps.NewHttpProxy()
proxy1.Ctx.KeepProxyHeaders = true
proxy1.Use(middleware.BasicAuth("mps_realm_1", func(username, password string) bool {
return username == "foo_1" && password == "bar_1"
}))
go http.ListenAndServe("localhost:9991", proxy1)
// proxy server 2
proxy2 := mps.NewHttpProxy()
proxy2.Ctx.KeepProxyHeaders = true
proxy2.Use(middleware.BasicAuth("mps_realm_2", func(username, password string) bool {
return username == "foo_2" && password == "bar_2"
}))
proxy2.Transport().Proxy = func(req *http.Request) (*url.URL, error) {
middleware.SetBasicAuth(req, "foo_1", "bar_1")
return url.Parse("http://localhost:9991")
}
go http.ListenAndServe("localhost:9992", proxy2)
// wait proxy server run
time.Sleep(2 * time.Second)
// send request
// request ==> proxy2 ==> proxy1 ==> http://localhost:9990
// response <== proxy2 <== proxy1 <== http://localhost:9990
req, _ := http.NewRequest(http.MethodGet, "http://localhost:9990/", nil)
http.DefaultClient.Transport = &http.Transport{
Proxy: func(r *http.Request) (*url.URL, error) {
middleware.SetBasicAuth(r, "foo_2", "bar_2")
return url.Parse("http://localhost:9992")
},
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
log.Fatal(err)
}
body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
log.Println(resp.Header)
log.Println(string(body))
}

View File

@@ -26,7 +26,7 @@ func NewForwardHandler() *ForwardHandler {
// Create a ForwardHandler with Context
func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
return &ForwardHandler{
Ctx: ctx,
Ctx: ctx,
BufferPool: pool.DefaultBuffer,
}
}
@@ -34,15 +34,7 @@ func NewForwardHandlerWithContext(ctx *Context) *ForwardHandler {
// Standard net/http function. You can use it alone
func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Copying a Context preserves the Transport, Middleware
ctx := forward.Ctx.Copy()
ctx.Request = req
// In some cases it is not always necessary to remove the Proxy Header.
// For example, cascade proxy
if !forward.Ctx.KeepHeader {
removeProxyHeaders(req)
}
ctx := forward.Ctx.WithRequest(req)
resp, err := ctx.Next(req)
if err != nil {
http.Error(rw, err.Error(), 502)
@@ -80,6 +72,26 @@ func (forward *ForwardHandler) ServeHTTP(rw http.ResponseWriter, req *http.Reque
}
}
// Use registers an Middleware to proxy
func (forward *ForwardHandler) Use(middleware ...Middleware) {
forward.Ctx.Use(middleware...)
}
// UseFunc registers an MiddlewareFunc to proxy
func (forward *ForwardHandler) UseFunc(fus ...MiddlewareFunc) {
forward.Ctx.UseFunc(fus...)
}
// OnRequest filter requests through Filters
func (forward *ForwardHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
return &ReqFilterGroup{ctx: forward.Ctx, filters: filters}
}
// OnResponse filter response through Filters
func (forward *ForwardHandler) OnResponse(filters ...Filter) *RespFilterGroup {
return &RespFilterGroup{ctx: forward.Ctx, filters: filters}
}
// Transport
func (forward *ForwardHandler) Transport() *http.Transport {
return forward.Ctx.Transport

27
handle.go Normal file
View File

@@ -0,0 +1,27 @@
package mps
import "net/http"
type RequestHandle interface {
Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response)
}
// A wrapper that would convert a function to a RequestHandle interface type
type RequestHandleFunc func(req *http.Request, ctx *Context) (*http.Request, *http.Response)
// RequestHandle.Handle(req, ctx) <=> RequestHandleFunc(req, ctx)
func (f RequestHandleFunc) Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response) {
return f(req, ctx)
}
type ResponseHandle interface {
Handle(resp *http.Response, ctx *Context) (*http.Response, error)
}
// A wrapper that would convert a function to a ResponseHandle interface type
type ResponseHandleFunc func(resp *http.Response, ctx *Context) (*http.Response, error)
// ResponseHandle.Handle(resp, ctx) <=> ResponseHandleFunc(resp, ctx)
func (f ResponseHandleFunc) Handle(resp *http.Response, ctx *Context) (*http.Response, error) {
return f(resp, ctx)
}

View File

@@ -3,7 +3,6 @@ package mps
import (
"errors"
"fmt"
"github.com/telanflow/mps/pool"
"net"
"net/http"
)
@@ -19,25 +18,21 @@ type HttpProxy struct {
// HTTP requests use the ReverseHandler proxy by default
ReverseHandler http.Handler
// Client request Context
Ctx *Context
ConnContainer pool.ConnContainer
}
func NewHttpProxy() *HttpProxy {
// default Context with Proxy
ctx := NewContext()
// conn pool
connPool := pool.NewConnProvider(pool.DefaultConnOptions)
return &HttpProxy{
Ctx: ctx,
// default HTTP proxy
HttpHandler: &ForwardHandler{Ctx: ctx},
// default HTTPS proxy
HttpsHandler: &TunnelHandler{Ctx: ctx, ConnContainer: connPool},
HttpsHandler: &TunnelHandler{Ctx: ctx},
// default Reverse proxy
ReverseHandler: &ReverseHandler{Ctx: ctx},
ConnContainer: connPool,
}
}
@@ -45,6 +40,7 @@ func NewHttpProxy() *HttpProxy {
func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if req.Method == http.MethodConnect {
proxy.HttpsHandler.ServeHTTP(rw, req)
return
}
if !req.URL.IsAbs() {
@@ -65,13 +61,13 @@ func (proxy *HttpProxy) UseFunc(fus ...MiddlewareFunc) {
}
// OnRequest filter requests through Filters
func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqCondition {
return &ReqCondition{ctx: proxy.Ctx, filters: filters}
func (proxy *HttpProxy) OnRequest(filters ...Filter) *ReqFilterGroup {
return &ReqFilterGroup{ctx: proxy.Ctx, filters: filters}
}
// OnResponse filter response through Filters
func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespCondition {
return &RespCondition{ctx: proxy.Ctx, filters: filters}
func (proxy *HttpProxy) OnResponse(filters ...Filter) *RespFilterGroup {
return &RespFilterGroup{ctx: proxy.Ctx, filters: filters}
}
// Transport get http.Transport instance
@@ -94,40 +90,6 @@ func hijacker(rw http.ResponseWriter) (conn net.Conn, err error) {
return
}
// Hop-by-hop headers. These are removed when sent to the backend.
// As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
func removeProxyHeaders(r *http.Request) {
r.RequestURI = "" // this must be reset when serving a request with the client
// If no Accept-Encoding header exists, Transport will add the headers it can accept
// and would wrap the response body with the relevant reader.
r.Header.Del("Accept-Encoding")
// RFC 2616 (section 13.5.1)
// https://www.ietf.org/rfc/rfc2616.txt
r.Header.Del("Proxy-Connection")
r.Header.Del("Proxy-Authenticate")
r.Header.Del("Proxy-Authorization")
// Connection, Authenticate and Authorization are single hop Header:
// http://www.w3.org/Protocols/rfc2616/rfc2616.txt
// 14.10 Connection
// The Connection general-header field allows the sender to specify
// options that are desired for that particular connection and MUST NOT
// be communicated by proxies over further connections.
// When server reads http request it sets req.Close to true if
// "Connection" header contains "close".
// https://github.com/golang/go/blob/master/src/net/http/request.go#L1080
// Later, transfer.go adds "Connection: close" back when req.Close is true
// https://github.com/golang/go/blob/master/src/net/http/transfer.go#L275
// That's why tests that checks "Connection: close" removal fail
if r.Header.Get("Connection") == "close" {
r.Close = false
}
r.Header.Del("Connection")
}
func copyHeaders(dst, src http.Header, keepDestHeaders bool) {
if !keepDestHeaders {
for k := range dst {

View File

@@ -1,9 +1,9 @@
package mps
import (
"bytes"
"github.com/stretchr/testify/assert"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/url"
@@ -12,8 +12,14 @@ import (
func NewTestServer() *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
query := req.URL.Query()
text := []byte("hello world")
if query.Get("text") != "" {
text = []byte(query.Get("text"))
}
rw.Header().Set("Server", "MPS proxy server")
rw.Write([]byte("hello world"))
_, _ = rw.Write(text)
}))
}
@@ -57,8 +63,15 @@ func TestMiddlewareFunc(t *testing.T) {
proxy := NewHttpProxy()
// use Middleware
proxy.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
log.Println(req.URL.String())
return ctx.Next(req)
resp, err := ctx.Next(req)
if err != nil {
return nil, err
}
var buf bytes.Buffer
buf.WriteString("middleware")
resp.Body = ioutil.NopCloser(&buf)
return resp, nil
})
proxySrv := httptest.NewServer(proxy)
defer proxySrv.Close()
@@ -77,5 +90,5 @@ func TestMiddlewareFunc(t *testing.T) {
asserts := assert.New(t)
asserts.Equal(resp.StatusCode, 200)
asserts.Equal(int64(len(body)), resp.ContentLength)
log.Println(string(body))
asserts.Equal(string(body), "middleware")
}

View File

@@ -21,7 +21,7 @@ type Middleware interface {
// A wrapper that would convert a function to a Middleware interface type
type MiddlewareFunc func(req *http.Request, ctx *Context) (*http.Response, error)
// MiddlewareFunc.Handle(req, ctx) <=> MiddlewareFunc(req, ctx)
// Middleware.Handle(req, ctx) <=> MiddlewareFunc(req, ctx)
func (f MiddlewareFunc) Handle(req *http.Request, ctx *Context) (*http.Response, error) {
return f(req, ctx)
}

94
middleware/basicAuth.go Normal file
View File

@@ -0,0 +1,94 @@
package middleware
import (
"bytes"
"encoding/base64"
"github.com/telanflow/mps"
"io/ioutil"
"net/http"
"strings"
)
// proxy Authorization header
const proxyAuthorization = "Proxy-Authorization"
// BasicAuth returns a HTTP Basic Authentication middleware for requests
// You probably want to use mps.BasicAuth(proxy) to enable authentication for all proxy activities
func BasicAuth(realm string, fn func(username, password string) bool) mps.MiddlewareFunc {
return func(req *http.Request, ctx *mps.Context) (*http.Response, error) {
auth := req.Header.Get(proxyAuthorization)
if auth == "" {
return BasicUnauthorized(req, realm), nil
}
// parses an Basic Authentication string.
usr, pwd, ok := parseBasicAuth(auth)
if !ok {
return BasicUnauthorized(req, realm), nil
}
if !fn(usr, pwd) {
return BasicUnauthorized(req, realm), nil
}
// Authorization passed
return ctx.Next(req)
}
}
// SetBasicAuth sets the request's Authorization header to use HTTP
// Basic Authentication with the provided username and password.
//
// With HTTP Basic Authentication the provided username and password
// are not encrypted.
//
// Some protocols may impose additional requirements on pre-escaping the
// username and password. For instance, when used with OAuth2, both arguments
// must be URL encoded first with url.QueryEscape.
func SetBasicAuth(req *http.Request, username, password string) {
req.Header.Set(proxyAuthorization, "Basic "+basicAuth(username, password))
}
// See 2 (end of page 4) https://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64
// encoded string in the credentials."
// It is not meant to be urlencoded.
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}
// parseBasicAuth parses an HTTP Basic Authentication string.
// "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==" returns ("Aladdin", "open sesame", true).
func parseBasicAuth(auth string) (username, password string, ok bool) {
const prefix = "Basic "
// Case insensitive prefix match. See Issue 22736.
if len(auth) < len(prefix) || !strings.EqualFold(auth[:len(prefix)], prefix) {
return
}
c, err := base64.StdEncoding.DecodeString(auth[len(prefix):])
if err != nil {
return
}
cs := string(c)
s := strings.IndexByte(cs, ':')
if s < 0 {
return
}
return cs[:s], cs[s+1:], true
}
func BasicUnauthorized(req *http.Request, realm string) *http.Response {
const unauthorizedMsg = "407 Proxy Authentication Required"
// verify realm is well formed
return &http.Response{
StatusCode: 407,
ProtoMajor: 1,
ProtoMinor: 1,
Request: req,
Header: http.Header{
"Proxy-Authenticate": []string{"Basic realm=" + realm},
"Proxy-Connection": []string{"close"},
},
Body: ioutil.NopCloser(bytes.NewBuffer([]byte(unauthorizedMsg))),
ContentLength: int64(len(unauthorizedMsg)),
}
}

View File

@@ -66,11 +66,25 @@ func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) {
}
// Standard net/http function. You can use it alone
func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// Execution middleware
ctx := mitm.Ctx.WithRequest(r)
resp, err := ctx.Next(r)
if err != nil && err != MethodNotSupportErr {
if resp != nil {
copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders)
rw.WriteHeader(resp.StatusCode)
buf := mitm.buffer().Get()
_, err = io.CopyBuffer(rw, resp.Body, buf)
mitm.buffer().Put(buf)
}
return
}
// get hijacker connection
proxyClient, err := hijacker(w)
proxyClient, err := hijacker(rw)
if err != nil {
http.Error(w, err.Error(), 502)
http.Error(rw, err.Error(), 502)
return
}
@@ -91,7 +105,7 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rawClientTls := tls.Server(proxyClient, tlsConfig)
if err := rawClientTls.Handshake(); err != nil {
ConnError(proxyClient)
//ctx.Warnf("Cannot handshake client %v %v", r.Host, err)
_ = rawClientTls.Close()
return
}
defer rawClientTls.Close()
@@ -100,7 +114,6 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for !isEof(clientTlsReader) {
req, err := http.ReadRequest(clientTlsReader)
if err != nil {
//ctx.Warnf("Cannot read TLS request from mitm'd client %v %v", r.Host, err)
break
}
@@ -115,23 +128,14 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
// Copying a Context preserves the Transport, Middleware
ctx := mitm.Ctx.Copy()
ctx.Request = req
// In some cases it is not always necessary to remove the Proxy Header.
// For example, cascade proxy
if !mitm.Ctx.KeepHeader {
removeProxyHeaders(req)
}
var resp *http.Response
// Copying a Context preserves the Transport, Middleware
ctx := mitm.Ctx.WithRequest(req)
resp, err = ctx.Next(req)
if err != nil {
//ctx.Warnf("Cannot read TLS response from mitm'd server %v", err)
return
}
defer resp.Body.Close()
status := resp.Status
statusCode := strconv.Itoa(resp.StatusCode) + " "
@@ -155,12 +159,12 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
err = resp.Header.Write(rawClientTls)
if err != nil {
//ctx.Warnf("Cannot write TLS response header from mitm'd client: %v", err)
resp.Body.Close()
return
}
_, err = io.WriteString(rawClientTls, "\r\n")
if err != nil {
//ctx.Warnf("Cannot write TLS response header end from mitm'd client: %v", err)
resp.Body.Close()
return
}
@@ -170,15 +174,17 @@ func (mitm *MitmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
_, err = io.CopyBuffer(chunked, resp.Body, buf)
mitm.buffer().Put(buf)
if err != nil {
//ctx.Warnf("Cannot write TLS response body from mitm'd client: %v", err)
resp.Body.Close()
return
}
// closed response body
resp.Body.Close()
if err := chunked.Close(); err != nil {
//ctx.Warnf("Cannot write TLS chunked EOF from mitm'd client: %v", err)
return
}
if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil {
//ctx.Warnf("Cannot write TLS response chunked trailer from mitm'd client: %v", err)
return
}
}
@@ -196,13 +202,13 @@ func (mitm *MitmHandler) UseFunc(fus ...MiddlewareFunc) {
}
// OnRequest filter requests through Filters
func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqCondition {
return &ReqCondition{ctx: mitm.Ctx, filters: filters}
func (mitm *MitmHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
return &ReqFilterGroup{ctx: mitm.Ctx, filters: filters}
}
// OnResponse filter response through Filters
func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespCondition {
return &RespCondition{ctx: mitm.Ctx, filters: filters}
func (mitm *MitmHandler) OnResponse(filters ...Filter) *RespFilterGroup {
return &RespFilterGroup{ctx: mitm.Ctx, filters: filters}
}
// Get buffer pool
@@ -213,6 +219,14 @@ func (mitm *MitmHandler) buffer() httputil.BufferPool {
return pool.DefaultBuffer
}
// Get cert.Container instance
func (mitm *MitmHandler) certContainer() cert.Container {
if mitm.CertContainer != nil {
return mitm.CertContainer
}
return cert.DefaultMemProvider
}
// Transport
func (mitm *MitmHandler) Transport() *http.Transport {
return mitm.Ctx.Transport
@@ -222,7 +236,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
host = stripPort(host)
// Returned existing certificate for the host
crt, err := mitm.CertContainer.Get(host)
crt, err := mitm.certContainer().Get(host)
if err == nil && crt != nil {
return &tls.Config{
InsecureSkipVerify: true,
@@ -238,7 +252,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
}
// Set certificate to container
mitm.CertContainer.Set(host, crt)
_ = mitm.certContainer().Set(host, crt)
return &tls.Config{
InsecureSkipVerify: true,
@@ -246,6 +260,7 @@ func (mitm *MitmHandler) TLSConfigFromCA(host string) (*tls.Config, error) {
}, nil
}
// sign host
func signHost(ca tls.Certificate, hosts []string) (cert *tls.Certificate, err error) {
// Use the provided ca for certificate generation.
var x509ca *x509.Certificate

View File

@@ -3,11 +3,15 @@ package pool
import "time"
var DefaultConnOptions = &ConnOptions{
IdleMaxCap: 20,
Timeout: time.Minute,
IdleMaxCap: 30,
Timeout: 90 * time.Second,
}
// ConnOptions is ConnProvider options
type ConnOptions struct {
// IdleMaxCap is max connection capacity for a single net.Addr
IdleMaxCap int
Timeout time.Duration
// Timeout specifies how long the connection will timeout
Timeout time.Duration
}

View File

@@ -10,6 +10,9 @@ import (
"time"
)
var DefaultConnProvider = NewConnProvider(DefaultConnOptions)
// ConnProvider is a connection pool, it implements ConnContainer
type ConnProvider struct {
mu sync.RWMutex
idleConnMap map[string]chan net.Conn
@@ -17,6 +20,7 @@ type ConnProvider struct {
closed int32
}
// Create a ConnProvider
func NewConnProvider(opt *ConnOptions) *ConnProvider {
return &ConnProvider{
options: opt,
@@ -42,6 +46,7 @@ func (p *ConnProvider) Get(addr string) (net.Conn, error) {
RETRY:
select {
case conn := <-p.idleConnMap[addr]:
// Getting a net.Conn requires verifying that the net.Conn is valid
_, err := conn.Read([]byte{})
if err != nil || err == io.EOF {
// conn is close Or timeout
@@ -58,6 +63,8 @@ RETRY:
func (p *ConnProvider) Put(conn net.Conn) error {
closed := atomic.LoadInt32(&p.closed)
if closed == 1 {
// pool is closed, this conn must be closed
conn.Close()
return errors.New("pool is closed")
}
@@ -70,13 +77,13 @@ func (p *ConnProvider) Put(conn net.Conn) error {
p.mu.Unlock()
// set conn timeout
// The timeout will be verified at the next `Get()`
err := conn.SetDeadline(time.Now().Add(p.options.Timeout))
if err != nil {
_ = conn.Close()
return err
}
// set idle conn
select {
case p.idleConnMap[addr] <- conn:
return nil
@@ -86,6 +93,7 @@ func (p *ConnProvider) Put(conn net.Conn) error {
}
}
// Release connection pool
func (p *ConnProvider) Release() error {
closed := atomic.LoadInt32(&p.closed)
if closed == 1 {
@@ -93,7 +101,6 @@ func (p *ConnProvider) Release() error {
}
atomic.StoreInt32(&p.closed, 1)
for _, connChan := range p.idleConnMap {
close(connChan)
for conn, ok := <-connChan; ok; {

View File

@@ -4,16 +4,17 @@ import (
"net/http"
)
type ReqCondition struct {
// ReqCondition is a request condition group
type ReqFilterGroup struct {
ctx *Context
filters []Filter
}
func (cond *ReqCondition) DoFunc(fn func(req *http.Request) (*http.Request, *http.Response)) {
func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) {
cond.Do(RequestHandleFunc(fn))
}
func (cond *ReqCondition) Do(fn RequestHandle) {
func (cond *ReqFilterGroup) Do(h RequestHandle) {
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
@@ -22,7 +23,7 @@ func (cond *ReqCondition) Do(fn RequestHandle) {
}
}
req, resp := fn.Handle(req)
req, resp := h.Handle(req, ctx)
if resp != nil {
return resp, nil
}

View File

@@ -1,13 +0,0 @@
package mps
import "net/http"
type RequestHandle interface {
Handle(req *http.Request) (*http.Request, *http.Response)
}
type RequestHandleFunc func(req *http.Request) (*http.Request, *http.Response)
func (f RequestHandleFunc) Handle(req *http.Request) (*http.Request, *http.Response) {
return f(req)
}

View File

@@ -4,16 +4,16 @@ import (
"net/http"
)
type RespCondition struct {
type RespFilterGroup struct {
ctx *Context
filters []Filter
}
func (cond *RespCondition) DoFunc(fn func(resp *http.Response) (*http.Response, error)) {
func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, ctx *Context) (*http.Response, error)) {
cond.Do(ResponseHandleFunc(fn))
}
func (cond *RespCondition) Do(fn ResponseHandle) {
func (cond *RespFilterGroup) Do(h ResponseHandle) {
cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) {
total := len(cond.filters)
for i := 0; i < total; i++ {
@@ -27,6 +27,6 @@ func (cond *RespCondition) Do(fn ResponseHandle) {
return nil, err
}
return fn.Handle(resp)
return h.Handle(resp, ctx)
})
}

View File

@@ -1,13 +0,0 @@
package mps
import "net/http"
type ResponseHandle interface {
Handle(resp *http.Response) (*http.Response, error)
}
type ResponseHandleFunc func(resp *http.Response) (*http.Response, error)
func (f ResponseHandleFunc) Handle(resp *http.Response) (*http.Response, error) {
return f(resp)
}

View File

@@ -15,6 +15,7 @@ type ReverseHandler struct {
BufferPool httputil.BufferPool
}
// Create a ReverseHandler
func NewReverseHandler() *ReverseHandler {
return &ReverseHandler{
Ctx: NewContext(),
@@ -25,9 +26,7 @@ func NewReverseHandler() *ReverseHandler {
// Standard net/http function. You can use it alone
func (reverse *ReverseHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Copying a Context preserves the Transport, Middleware
ctx := reverse.Ctx.Copy()
ctx.Request = req
ctx := reverse.Ctx.WithRequest(req)
resp, err := ctx.Next(req)
if err != nil {
http.Error(rw, err.Error(), 502)
@@ -76,13 +75,13 @@ func (reverse *ReverseHandler) UseFunc(fus ...MiddlewareFunc) {
}
// OnRequest filter requests through Filters
func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqCondition {
return &ReqCondition{ctx: reverse.Ctx, filters: filters}
func (reverse *ReverseHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
return &ReqFilterGroup{ctx: reverse.Ctx, filters: filters}
}
// OnResponse filter response through Filters
func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespCondition {
return &RespCondition{ctx: reverse.Ctx, filters: filters}
func (reverse *ReverseHandler) OnResponse(filters ...Filter) *RespFilterGroup {
return &RespFilterGroup{ctx: reverse.Ctx, filters: filters}
}
// Get buffer pool

24
transport.go Normal file
View File

@@ -0,0 +1,24 @@
package mps
import (
"crypto/tls"
"net"
"net/http"
"time"
)
// Default http.Transport option
var DefaultTransport = &http.Transport{
DialContext: (&net.Dialer{
Timeout: 15 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
Proxy: http.ProxyFromEnvironment,
}

View File

@@ -9,6 +9,7 @@ import (
"net/http/httputil"
"net/url"
"regexp"
"time"
)
var (
@@ -27,23 +28,35 @@ type TunnelHandler struct {
// Create a tunnel handler
func NewTunnelHandler() *TunnelHandler {
return &TunnelHandler{
Ctx: NewContext(),
BufferPool: pool.DefaultBuffer,
ConnContainer: pool.NewConnProvider(pool.DefaultConnOptions),
Ctx: NewContext(),
BufferPool: pool.DefaultBuffer,
}
}
// Create a tunnel handler with Context
func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler {
return &TunnelHandler{
Ctx: ctx,
BufferPool: pool.DefaultBuffer,
ConnContainer: pool.NewConnProvider(pool.DefaultConnOptions),
Ctx: ctx,
BufferPool: pool.DefaultBuffer,
}
}
// Standard net/http function. You can use it alone
func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
// Execution middleware
ctx := tunnel.Ctx.WithRequest(req)
resp, err := ctx.Next(req)
if err != nil && err != MethodNotSupportErr {
if resp != nil {
copyHeaders(rw.Header(), resp.Header, tunnel.Ctx.KeepDestinationHeaders)
rw.WriteHeader(resp.StatusCode)
buf := tunnel.buffer().Get()
_, err = io.CopyBuffer(rw, resp.Body, buf)
tunnel.buffer().Put(buf)
}
return
}
// hijacker connection
proxyClient, err := hijacker(rw)
if err != nil {
@@ -71,27 +84,19 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
}
// connect to targetAddr
if tunnel.ConnContainer != nil {
targetConn, err = tunnel.ConnContainer.Get(targetAddr)
if err != nil {
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
}
} else {
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
}
targetConn, err = tunnel.connContainer().Get(targetAddr)
if err != nil {
ConnError(proxyClient)
return
targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
if err != nil {
ConnError(proxyClient)
return
}
}
// If the ConnContainer is exists,
// When io.CopyBuffer is complete,
// put the idle connection into the ConnContainer so can reuse it next time
if tunnel.ConnContainer != nil {
defer tunnel.ConnContainer.Put(targetConn)
} else {
defer targetConn.Close()
}
defer tunnel.connContainer().Put(targetConn)
// The cascade proxy needs to forward the request
if isCascadeProxy {
@@ -115,16 +120,9 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request
func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) {
if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil {
return tunnel.Ctx.Transport.DialContext(tunnel.Context(), network, addr)
return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr)
}
return net.Dial(network, addr)
}
func (tunnel *TunnelHandler) Context() context.Context {
if tunnel.Ctx.Context != nil {
return tunnel.Ctx.Context
}
return context.Background()
return net.DialTimeout(network, addr, 30*time.Second)
}
// Transport
@@ -132,6 +130,14 @@ func (tunnel *TunnelHandler) Transport() *http.Transport {
return tunnel.Ctx.Transport
}
// get a context.Context
func (tunnel *TunnelHandler) context() context.Context {
if tunnel.Ctx.Context != nil {
return tunnel.Ctx.Context
}
return context.Background()
}
// Get buffer pool
func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
if tunnel.BufferPool != nil {
@@ -140,6 +146,14 @@ func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
return pool.DefaultBuffer
}
// Get a conn pool
func (tunnel *TunnelHandler) connContainer() pool.ConnContainer {
if tunnel.ConnContainer != nil {
return tunnel.ConnContainer
}
return pool.DefaultConnProvider
}
func hostAndPort(addr string) string {
if !hasPort.MatchString(addr) {
addr += ":80"