diff --git a/mitm_handler.go b/mitm_handler.go index 6c08475..c2dd57d 100644 --- a/mitm_handler.go +++ b/mitm_handler.go @@ -28,7 +28,7 @@ import ( ) var ( - HttpMitmOk = []byte("HTTP/1.0 200 OK\r\n\r\n") + HttpMitmOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n") httpsRegexp = regexp.MustCompile("^https://") ) @@ -62,13 +62,13 @@ func NewMitmHandlerWithContext(ctx *Context) *MitmHandler { } // Create a MitmHandler with cert pem block -func NewMitmHandlerWithCert(certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) { +func NewMitmHandlerWithCert(ctx *Context, certPEMBlock, keyPEMBlock []byte) (*MitmHandler, error) { certificate, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) if err != nil { return nil, err } return &MitmHandler{ - Ctx: NewContext(), + Ctx: ctx, BufferPool: pool.DefaultBuffer, Certificate: certificate, CertContainer: cert.NewMemProvider(), @@ -76,13 +76,13 @@ func NewMitmHandlerWithCert(certPEMBlock, keyPEMBlock []byte) (*MitmHandler, err } // Create a MitmHandler with cert file -func NewMitmHandlerWithCertFile(certFile, keyFile string) (*MitmHandler, error) { +func NewMitmHandlerWithCertFile(ctx *Context, certFile, keyFile string) (*MitmHandler, error) { certificate, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } return &MitmHandler{ - Ctx: NewContext(), + Ctx: ctx, BufferPool: pool.DefaultBuffer, Certificate: certificate, CertContainer: cert.NewMemProvider(), @@ -90,10 +90,10 @@ func NewMitmHandlerWithCertFile(certFile, keyFile string) (*MitmHandler, error) } // Standard net/http function. You can use it alone -func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { +func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // execution middleware - ctx := mitm.Ctx.WithRequest(r) - resp, err := ctx.Next(r) + ctx := mitm.Ctx.WithRequest(req) + resp, err := ctx.Next(req) if err != nil && err != MethodNotSupportErr { if resp != nil { copyHeaders(rw.Header(), resp.Header, mitm.Ctx.KeepDestinationHeaders) @@ -106,7 +106,7 @@ func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { } // get hijacker connection - proxyClient, err := hijacker(rw) + clientConn, err := hijacker(rw) if err != nil { http.Error(rw, err.Error(), 502) return @@ -116,99 +116,102 @@ func (mitm *MitmHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) { // still handling the request even after hijacking the connection. Those HTTP CONNECT // request can take forever, and the server will be stuck when "closed". // TODO: Allow Server.Close() mechanism to shut down this connection as nicely as possible - tlsConfig, err := mitm.TLSConfigFromCA(r.URL.Host) + tlsConfig, err := mitm.TLSConfigFromCA(req.URL.Host) if err != nil { - ConnError(proxyClient) + ConnError(clientConn) return } - _, _ = proxyClient.Write(HttpMitmOk) + _, _ = clientConn.Write(HttpMitmOk) - go func() { - // TODO: cache connections to the remote website - rawClientTls := tls.Server(proxyClient, tlsConfig) - if err := rawClientTls.Handshake(); err != nil { - ConnError(proxyClient) - _ = rawClientTls.Close() + // data transmit + go mitm.transmit(clientConn, req, tlsConfig) +} + +func (mitm *MitmHandler) transmit(clientConn net.Conn, originalReq *http.Request, tlsConfig *tls.Config) { + // TODO: cache connections to the remote website + rawClientTls := tls.Server(clientConn, tlsConfig) + if err := rawClientTls.Handshake(); err != nil { + ConnError(clientConn) + _ = rawClientTls.Close() + return + } + defer rawClientTls.Close() + + clientTlsReader := bufio.NewReader(rawClientTls) + for !isEof(clientTlsReader) { + req, err := http.ReadRequest(clientTlsReader) + if err != nil { + break + } + + // since we're converting the request, need to carry over the original connecting IP as well + req.RemoteAddr = originalReq.RemoteAddr + + if !httpsRegexp.MatchString(req.URL.String()) { + req.URL, err = url.Parse("https://" + originalReq.Host + req.URL.String()) + } + if err != nil { return } - defer rawClientTls.Close() - clientTlsReader := bufio.NewReader(rawClientTls) - for !isEof(clientTlsReader) { - req, err := http.ReadRequest(clientTlsReader) - if err != nil { - break - } + var resp *http.Response - // since we're converting the request, need to carry over the original connecting IP as well - req.RemoteAddr = r.RemoteAddr - - if !httpsRegexp.MatchString(req.URL.String()) { - req.URL, err = url.Parse("https://" + r.Host + req.URL.String()) - } - if err != nil { - return - } - - var resp *http.Response - - // Copying a Context preserves the Transport, Middleware - ctx := mitm.Ctx.WithRequest(req) - resp, err = ctx.Next(req) - if err != nil { - return - } - - status := resp.Status - statusCode := strconv.Itoa(resp.StatusCode) + " " - if strings.HasPrefix(status, statusCode) { - status = status[len(statusCode):] - } - - // always use 1.1 to support chunked encoding - if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+status+"\r\n"); err != nil { - return - } - - // Since we don't know the length of resp, return chunked encoded response - resp.Header.Set("Transfer-Encoding", "chunked") - - // Force connection close otherwise chrome will keep CONNECT tunnel open forever - resp.Header.Set("Connection", "close") - - err = resp.Header.Write(rawClientTls) - if err != nil { - resp.Body.Close() - return - } - _, err = io.WriteString(rawClientTls, "\r\n") - if err != nil { - resp.Body.Close() - return - } - - chunked := newChunkedWriter(rawClientTls) - - buf := mitm.buffer().Get() - _, err = io.CopyBuffer(chunked, resp.Body, buf) - mitm.buffer().Put(buf) - if err != nil { - resp.Body.Close() - return - } - - // closed response body - resp.Body.Close() - - if err := chunked.Close(); err != nil { - return - } - if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { - return - } + // Copying a Context preserves the Transport, Middleware + ctx := mitm.Ctx.WithRequest(req) + resp, err = ctx.Next(req) + if err != nil { + return } - }() + + status := resp.Status + statusCode := strconv.Itoa(resp.StatusCode) + " " + if strings.HasPrefix(status, statusCode) { + status = status[len(statusCode):] + } + + // always use 1.1 to support chunked encoding + if _, err := io.WriteString(rawClientTls, "HTTP/1.1"+" "+statusCode+status+"\r\n"); err != nil { + return + } + + // Since we don't know the length of resp, return chunked encoded response + resp.Header.Set("Transfer-Encoding", "chunked") + + // Force connection close otherwise chrome will keep CONNECT tunnel open forever + resp.Header.Set("Connection", "close") + + err = resp.Header.Write(rawClientTls) + if err != nil { + resp.Body.Close() + return + } + _, err = io.WriteString(rawClientTls, "\r\n") + if err != nil { + resp.Body.Close() + return + } + + chunked := newChunkedWriter(rawClientTls) + + buf := mitm.buffer().Get() + _, err = io.CopyBuffer(chunked, resp.Body, buf) + mitm.buffer().Put(buf) + if err != nil { + resp.Body.Close() + return + } + + // closed response body + resp.Body.Close() + + if err := chunked.Close(); err != nil { + return + } + if _, err = io.WriteString(rawClientTls, "\r\n"); err != nil { + return + } + } } // Use registers an Middleware to proxy diff --git a/tunnel_handler.go b/tunnel_handler.go index fbef6a9..75d3b89 100644 --- a/tunnel_handler.go +++ b/tunnel_handler.go @@ -13,7 +13,7 @@ import ( ) var ( - HttpTunnelOk = []byte("HTTP/1.0 200 OK\r\n\r\n") + HttpTunnelOk = []byte("HTTP/1.0 200 Connection Established\r\n\r\n") HttpTunnelFail = []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n") hasPort = regexp.MustCompile(`:\d+$`) )