diff --git a/examples/simple-http-proxy/main.go b/examples/simple-http-proxy/main.go new file mode 100644 index 0000000..f383308 --- /dev/null +++ b/examples/simple-http-proxy/main.go @@ -0,0 +1,60 @@ +package main + +import ( + "errors" + "github.com/telanflow/mps" + "log" + "net/http" + "os" + "os/signal" + "regexp" + "syscall" +) + +// A simple http proxy server +func main() { + quitSignChan := make(chan os.Signal) + + // create a http proxy server + proxy := mps.NewHttpProxy() + proxy.UseFunc(func(req *http.Request, ctx *mps.Context) (*http.Response, error) { + log.Printf("[INFO] middleware -- %s", req.URL) + return ctx.Next(req) + }) + + reqGroup := proxy.OnRequest(mps.FilterHostMatches(regexp.MustCompile("^.*$"))) + reqGroup.DoFunc(func(req *http.Request, ctx *mps.Context) (*http.Request, *http.Response) { + log.Printf("[INFO] req -- %s\n", req.URL) + return req, nil + }) + + respGroup := proxy.OnResponse() + respGroup.DoFunc(func(resp *http.Response, ctx *mps.Context) (*http.Response, error) { + log.Printf("[INFO] resp -- %d\n", resp.StatusCode) + return resp, nil + }) + + // Start server + srv := &http.Server{ + Addr: "127.0.0.1:8081", + Handler: proxy, + } + go func() { + log.Printf("HttpProxy started listen: http://%s", srv.Addr) + err := srv.ListenAndServe() + if errors.Is(err, http.ErrServerClosed) { + return + } + if err != nil { + quitSignChan <- syscall.SIGKILL + log.Fatalf("HttpProxy start fail: %v", err) + } + }() + + // quit signal + signal.Notify(quitSignChan, syscall.SIGINT, syscall.SIGKILL, syscall.SIGTERM, syscall.SIGQUIT) + + <-quitSignChan + _ = srv.Close() + log.Fatal("HttpProxy server stop!") +} diff --git a/mitm_handler.go b/mitm_handler.go index 88335de..066c529 100644 --- a/mitm_handler.go +++ b/mitm_handler.go @@ -124,7 +124,6 @@ func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) } if err != nil { - //ctx.Warnf("Illegal URL %s", "https://"+r.Host+req.URL.Path) return } @@ -145,7 +144,6 @@ func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // always use 1.1 to support chunked encoding if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+status+"\r\n"); err != nil { - //ctx.Warnf("Cannot write TLS response HTTP status from mitm'd client: %v", err) return } diff --git a/pool/conn_container.go b/pool/conn_container.go index 3a88dc4..fbe17ac 100644 --- a/pool/conn_container.go +++ b/pool/conn_container.go @@ -2,7 +2,14 @@ package pool import "net" +// ConnContainer connection pool interface type ConnContainer interface { + // Get returned a idle net.Conn Get(addr string) (net.Conn, error) + + // Put place a idle net.Conn into the pool Put(conn net.Conn) error + + // Release connection pool + Release() error } diff --git a/pool/conn_provider.go b/pool/conn_provider.go index d840df3..732925f 100644 --- a/pool/conn_provider.go +++ b/pool/conn_provider.go @@ -63,8 +63,6 @@ RETRY: func (p *ConnProvider) Put(conn net.Conn) error { closed := atomic.LoadInt32(&p.closed) if closed == 1 { - // pool is closed, this conn must be closed - conn.Close() return errors.New("pool is closed") } @@ -80,7 +78,6 @@ func (p *ConnProvider) Put(conn net.Conn) error { // The timeout will be verified at the next `Get()` err := conn.SetDeadline(time.Now().Add(p.options.Timeout)) if err != nil { - _ = conn.Close() return err } @@ -88,8 +85,7 @@ func (p *ConnProvider) Put(conn net.Conn) error { case p.idleConnMap[addr] <- conn: return nil default: - err := conn.Close() - return fmt.Errorf("beyond max capacity. conn closed: %v", err) + return fmt.Errorf("beyond max capacity") } } diff --git a/tunnel_handler.go b/tunnel_handler.go index f90bcc9..a8ce055 100644 --- a/tunnel_handler.go +++ b/tunnel_handler.go @@ -96,7 +96,13 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request // If the ConnContainer is exists, // When io.CopyBuffer is complete, // put the idle connection into the ConnContainer so can reuse it next time - defer tunnel.connContainer().Put(targetConn) + defer func() { + err := tunnel.connContainer().Put(targetConn) + if err != nil { + // put conn fail, conn must be closed + _ = targetConn.Close() + } + }() // The cascade proxy needs to forward the request if isCascadeProxy { @@ -118,6 +124,26 @@ func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request tunnel.buffer().Put(buf) } +// Use registers an Middleware to proxy +func (tunnel *TunnelHandler) Use(middleware ...Middleware) { + tunnel.Ctx.Use(middleware...) +} + +// UseFunc registers an MiddlewareFunc to proxy +func (tunnel *TunnelHandler) UseFunc(fus ...MiddlewareFunc) { + tunnel.Ctx.UseFunc(fus...) +} + +// OnRequest filter requests through Filters +func (tunnel *TunnelHandler) OnRequest(filters ...Filter) *ReqFilterGroup { + return &ReqFilterGroup{ctx: tunnel.Ctx, filters: filters} +} + +// OnResponse filter response through Filters +func (tunnel *TunnelHandler) OnResponse(filters ...Filter) *RespFilterGroup { + return &RespFilterGroup{ctx: tunnel.Ctx, filters: filters} +} + func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) { if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil { return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr) @@ -125,7 +151,7 @@ func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) return net.DialTimeout(network, addr, 30*time.Second) } -// Transport +// Transport get http.Transport instance func (tunnel *TunnelHandler) Transport() *http.Transport { return tunnel.Ctx.Transport }