From a8616bbe9858a0085a75420b3fa3871b719f2722 Mon Sep 17 00:00:00 2001 From: telanflow Date: Thu, 13 Aug 2020 16:17:33 +0800 Subject: [PATCH] mitm proxy examples --- context.go | 2 + examples/mitm-proxy/main.go | 69 ++++++++++++++++++++++++++++++ examples/reverse-proxy/main.go | 8 ++-- examples/simple-http-proxy/main.go | 10 ++--- filter_group.go | 58 +++++++++++++++++++++++++ handle.go | 13 ++++-- http_proxy.go | 27 ++++++++---- mitm_handler.go | 28 +++++++++++- req_filter_group.go | 33 -------------- resp_filter_group.go | 27 ------------ tunnel_handler.go | 2 +- 11 files changed, 193 insertions(+), 84 deletions(-) create mode 100644 examples/mitm-proxy/main.go create mode 100644 filter_group.go delete mode 100644 req_filter_group.go delete mode 100644 resp_filter_group.go diff --git a/context.go b/context.go index 7b970f7..74c425e 100644 --- a/context.go +++ b/context.go @@ -113,6 +113,8 @@ func (ctx *Context) Next(req *http.Request) (*http.Response, error) { ctx.mi++ if ctx.mi >= total { ctx.mi = -1 + // Final request coverage + ctx.Request = req // To make the middleware available to the tunnel proxy, // no response is obtained when the request method is equal to Connect if req.Method == http.MethodConnect { diff --git a/examples/mitm-proxy/main.go b/examples/mitm-proxy/main.go new file mode 100644 index 0000000..2a5cedc --- /dev/null +++ b/examples/mitm-proxy/main.go @@ -0,0 +1,69 @@ +package main + +import ( + "errors" + "github.com/telanflow/mps" + "log" + "net/http" + "os" + "os/signal" + "regexp" + "syscall" +) + +// A simple mitm proxy server +func main() { + quitSignChan := make(chan os.Signal) + + // create proxy server + proxy := mps.NewHttpProxy() + + // The Connect request is processed using MitmHandler + proxy.HandleConnect = mps.NewMitmHandlerWithContext(proxy.Ctx) + + // Middleware + proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { + log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) + return ctx.Next(req) + }) + + // Filter + reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) + reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { + log.Printf("[INFO] req -- %s %s", req.Method, req.URL) + return req, nil + }) + respGroup := proxy.OnResponse() + respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { + if err != nil { + log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) + return resp, err + } + log.Printf("[INFO] resp -- %d", resp.StatusCode) + return resp, err + }) + + // Started proxy server + srv := http.Server{ + Addr: "localhost:8080", + Handler: proxy, + } + go func() { + log.Printf("MitmProxy started listen: http://%s", srv.Addr) + err := srv.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + return + } + if err != nil { + quitSignChan <- syscall.SIGKILL + log.Fatalf("MitmProxy start fail: %v", err) + } + }() + + // quit signal + signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGQUIT) + + <-quitSignChan + _ = srv.Close() + log.Fatal("MitmProxy server stop!") +} diff --git a/examples/reverse-proxy/main.go b/examples/reverse-proxy/main.go index e599862..ff1a04b 100644 --- a/examples/reverse-proxy/main.go +++ b/examples/reverse-proxy/main.go @@ -14,23 +14,23 @@ import ( // A simple reverse proxy server func main() { - targetHost, _ := url.Parse("https://www.google.com") + targetURL, _ := url.Parse("https://www.google.com") quitSignChan := make(chan os.Signal) // reverse proxy server proxy := mps.NewReverseHandler() - proxy.UseFunc(middleware.SingleHostReverseProxy(targetHost)) + proxy.UseFunc(middleware.SingleHostReverseProxy(targetURL)) reqGroup := proxy.OnRequest() reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { - log.Printf("[INFO] req -- %s", req.Host) + log.Printf("[INFO] req -- %s %s", req.Method, req.Host) return req, nil }) respGroup := proxy.OnResponse() respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { if err != nil { - log.Printf("[ERRO] resp -- %v", err) + log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) return nil, err } log.Printf("[INFO] resp -- %d", resp.StatusCode) diff --git a/examples/simple-http-proxy/main.go b/examples/simple-http-proxy/main.go index 043b91f..296f812 100644 --- a/examples/simple-http-proxy/main.go +++ b/examples/simple-http-proxy/main.go @@ -18,30 +18,30 @@ func main() { // create a http proxy server proxy := mps.NewHttpProxy() proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { - log.Printf("[INFO] middleware -- %s\n", req.URL) + log.Printf("[INFO] middleware -- %s %s", req.Method, req.URL) return ctx.Next(req) }) reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { - log.Printf("[INFO] req -- %s\n", req.URL) + log.Printf("[INFO] req -- %s %s", req.Method, req.URL) return req, nil }) respGroup := proxy.OnResponse() respGroup.DoFunc(func(resp *http.Response, err error, ctx *mps.Context) (*http.Response, error) { if err != nil { - log.Printf("[ERRO] resp -- %v\n", err) + log.Printf("[ERRO] resp -- %s %v", ctx.Request.Method, err) return resp, err } - log.Printf("[INFO] resp -- %d\n", resp.StatusCode) + log.Printf("[INFO] resp -- %d", resp.StatusCode) return resp, err }) // Start server srv := &http.Server{ - Addr: "127.0.0.1:8081", + Addr: "localhost:8080", Handler: proxy, } go func() { diff --git a/filter_group.go b/filter_group.go new file mode 100644 index 0000000..edd203b --- /dev/null +++ b/filter_group.go @@ -0,0 +1,58 @@ +package mps + +import "net/http" + +type FilterGroup interface { + Handle() +} + +// ReqCondition is a request filter group +type ReqFilterGroup struct { + ctx *Context + filters []Filter +} + +func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) { + cond.Do(RequestHandleFunc(fn)) +} + +func (cond *ReqFilterGroup) Do(h RequestHandle) { + cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + total := len(cond.filters) + for i := 0; i < total; i++ { + if !cond.filters[i].Match(req) { + return ctx.Next(req) + } + } + + req, resp := h.HandleRequest(req, ctx) + if resp != nil { + return resp, nil + } + + return ctx.Next(req) + }) +} + +// ReqCondition is a response filter group +type RespFilterGroup struct { + ctx *Context + filters []Filter +} + +func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, err error, ctx *Context) (*http.Response, error)) { + cond.Do(ResponseHandleFunc(fn)) +} + +func (cond *RespFilterGroup) Do(h ResponseHandle) { + cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { + total := len(cond.filters) + for i := 0; i < total; i++ { + if !cond.filters[i].Match(req) { + return ctx.Next(req) + } + } + resp, err := ctx.Next(req) + return h.HandleResponse(resp, err, ctx) + }) +} diff --git a/handle.go b/handle.go index 8c229db..f055735 100644 --- a/handle.go +++ b/handle.go @@ -2,26 +2,31 @@ package mps import "net/http" +type Handle interface { + RequestHandle + ResponseHandle +} + type RequestHandle interface { - Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response) + HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response) } // A wrapper that would convert a function to a RequestHandle interface type type RequestHandleFunc func(req *http.Request, ctx *Context) (*http.Request, *http.Response) // RequestHandle.Handle(req, ctx) <=> RequestHandleFunc(req, ctx) -func (f RequestHandleFunc) Handle(req *http.Request, ctx *Context) (*http.Request, *http.Response) { +func (f RequestHandleFunc) HandleRequest(req *http.Request, ctx *Context) (*http.Request, *http.Response) { return f(req, ctx) } type ResponseHandle interface { - Handle(resp *http.Response, err error, ctx *Context) (*http.Response, error) + HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error) } // A wrapper that would convert a function to a ResponseHandle interface type type ResponseHandleFunc func(resp *http.Response, err error, ctx *Context) (*http.Response, error) // ResponseHandle.Handle(resp, ctx) <=> ResponseHandleFunc(resp, ctx) -func (f ResponseHandleFunc) Handle(resp *http.Response, err error, ctx *Context) (*http.Response, error) { +func (f ResponseHandleFunc) HandleResponse(resp *http.Response, err error, ctx *Context) (*http.Response, error) { return f(resp, err, ctx) } diff --git a/http_proxy.go b/http_proxy.go index c17edae..364dd31 100644 --- a/http_proxy.go +++ b/http_proxy.go @@ -9,13 +9,13 @@ import ( // The basic proxy type. Implements http.Handler. type HttpProxy struct { - // HTTPS requests use the TunnelHandler proxy by default - HttpsHandler http.Handler + // Handles Connect requests use the TunnelHandler by default + HandleConnect http.Handler - // HTTP requests use the ForwardHandler proxy by default + // HTTP requests use the ForwardHandler by default HttpHandler http.Handler - // HTTP requests use the ReverseHandler proxy by default + // HTTP requests use the ReverseHandler by default ReverseHandler http.Handler // Client request Context @@ -27,10 +27,10 @@ func NewHttpProxy() *HttpProxy { ctx := NewContext() return &HttpProxy{ Ctx: ctx, - // default HTTP proxy + // default handles Connect method + HandleConnect: &TunnelHandler{Ctx: ctx}, + // default handles HTTP request HttpHandler: &ForwardHandler{Ctx: ctx}, - // default HTTPS proxy - HttpsHandler: &TunnelHandler{Ctx: ctx}, // default Reverse proxy ReverseHandler: &ReverseHandler{Ctx: ctx}, } @@ -39,10 +39,21 @@ func NewHttpProxy() *HttpProxy { // Standard net/http function. func (proxy *HttpProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.Method == http.MethodConnect { - proxy.HttpsHandler.ServeHTTP(rw, req) + proxy.HandleConnect.ServeHTTP(rw, req) return } + // reverse proxy http request for example: + // GET / HTTP/1.1 + // Host: www.example.com + // Connection: keep-alive + // + // forward proxy http request for example : + // GET http://www.example.com/ HTTP/1.1 + // Host: www.example.com + // Proxy-Connection: keep-alive + // + // Determines whether the path is absolute if !req.URL.IsAbs() { proxy.ReverseHandler.ServeHTTP(rw, req) } else { diff --git a/mitm_handler.go b/mitm_handler.go index 066c529..62f56db 100644 --- a/mitm_handler.go +++ b/mitm_handler.go @@ -51,8 +51,32 @@ func NewMitmHandler() *MitmHandler { } } +// Create a MitmHandler, use default cert. +func NewMitmHandlerWithContext(ctx *Context) *MitmHandler { + return &MitmHandler{ + Ctx: ctx, + BufferPool: pool.DefaultBuffer, + Certificate: cert.DefaultCertificate, + CertContainer: cert.NewMemProvider(), + } +} + +// Create a MitmHandler with cert pem block +func NewMitmHandlerWithCert(certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) { + certificate, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) + if err != nil { + return nil, err + } + return &MitmHandler{ + Ctx: NewContext(), + BufferPool: pool.DefaultBuffer, + Certificate: certificate, + CertContainer: cert.NewMemProvider(), + }, nil +} + // Create a MitmHandler with cert file -func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) { +func NewMitmHandlerWithCertFile(certFile, keyFile string) (*MitmHandler, error) { certificate, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err @@ -67,7 +91,7 @@ func NewMitmHandlerWithCert(certFile, keyFile string) (*MitmHandler, error) { // Standard net/http function. You can use it alone func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - // Execution middleware + // execution middleware ctx := mitm.Ctx.WithRequest(r) resp, err := ctx.Next(r) if err != nil && err != MethodNotSupportErr { diff --git a/req_filter_group.go b/req_filter_group.go deleted file mode 100644 index 90b415a..0000000 --- a/req_filter_group.go +++ /dev/null @@ -1,33 +0,0 @@ -package mps - -import ( - "net/http" -) - -// ReqCondition is a request condition group -type ReqFilterGroup struct { - ctx *Context - filters []Filter -} - -func (cond *ReqFilterGroup) DoFunc(fn func(req *http.Request, ctx *Context) (*http.Request, *http.Response)) { - cond.Do(RequestHandleFunc(fn)) -} - -func (cond *ReqFilterGroup) Do(h RequestHandle) { - cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { - total := len(cond.filters) - for i := 0; i < total; i++ { - if !cond.filters[i].Match(req) { - return ctx.Next(req) - } - } - - req, resp := h.Handle(req, ctx) - if resp != nil { - return resp, nil - } - - return ctx.Next(req) - }) -} diff --git a/resp_filter_group.go b/resp_filter_group.go deleted file mode 100644 index 63d83f2..0000000 --- a/resp_filter_group.go +++ /dev/null @@ -1,27 +0,0 @@ -package mps - -import ( - "net/http" -) - -type RespFilterGroup struct { - ctx *Context - filters []Filter -} - -func (cond *RespFilterGroup) DoFunc(fn func(resp *http.Response, err error, ctx *Context) (*http.Response, error)) { - cond.Do(ResponseHandleFunc(fn)) -} - -func (cond *RespFilterGroup) Do(h ResponseHandle) { - cond.ctx.UseFunc(func(req *http.Request, ctx *Context) (*http.Response, error) { - total := len(cond.filters) - for i := 0; i < total; i++ { - if !cond.filters[i].Match(req) { - return ctx.Next(req) - } - } - resp, err := ctx.Next(req) - return h.Handle(resp, err, ctx) - }) -} diff --git a/tunnel_handler.go b/tunnel_handler.go index a8ce055..fbef6a9 100644 --- a/tunnel_handler.go +++ b/tunnel_handler.go @@ -43,7 +43,7 @@ func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler { // Standard net/http function. You can use it alone func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - // Execution middleware + // execution middleware ctx := tunnel.Ctx.WithRequest(req) resp, err := ctx.Next(req) if err != nil && err != MethodNotSupportErr {