Files
mps/tunnel_handler.go
2020-08-12 16:09:46 +08:00

194 lines
4.7 KiB
Go

package mps
import (
"context"
"github.com/telanflow/mps/pool"
"io"
"net"
"net/http"
"net/http/httputil"
"net/url"
"regexp"
"time"
)
// HttpTunnelOk is sent to the client once a direct (non-cascaded) tunnel to
// the target is ready for relaying.
var HttpTunnelOk = []byte("HTTP/1.0 200 OK\r\n\r\n")

// HttpTunnelFail is the bad-gateway reply ConnError writes before closing a
// client connection whose tunnel could not be set up.
var HttpTunnelFail = []byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")

// hasPort matches addresses that already end in ":<digits>".
var hasPort = regexp.MustCompile(`:\d+$`)
// TunnelHandler is the tunnel proxy type. It implements http.Handler: each
// request is run through the middleware chain, then the client connection is
// hijacked and relayed to the target (or to an upstream proxy returned by
// Ctx.Transport.Proxy).
type TunnelHandler struct {
	// Ctx supplies the middleware chain (Use/UseFunc/Next), the optional
	// Transport used for upstream-proxy lookup and dialing, and the
	// KeepDestinationHeaders setting used when copying error responses.
	Ctx *Context
	// BufferPool provides the copy buffers; when nil, pool.DefaultBuffer is used.
	BufferPool httputil.BufferPool
	// ConnContainer is consulted for an idle connection before dialing, and
	// connections are offered back to it when a tunnel ends; when nil,
	// pool.DefaultConnProvider is used.
	ConnContainer pool.ConnContainer
}
// NewTunnelHandler builds a TunnelHandler with a fresh Context from
// NewContext and pool.DefaultBuffer as its buffer pool. ConnContainer is left
// nil, so the handler falls back to pool.DefaultConnProvider.
func NewTunnelHandler() *TunnelHandler {
	h := new(TunnelHandler)
	h.Ctx = NewContext()
	h.BufferPool = pool.DefaultBuffer
	return h
}
// NewTunnelHandlerWithContext builds a TunnelHandler around the caller's
// Context (shared, not copied) with pool.DefaultBuffer as its buffer pool.
func NewTunnelHandlerWithContext(ctx *Context) *TunnelHandler {
	h := new(TunnelHandler)
	h.Ctx = ctx
	h.BufferPool = pool.DefaultBuffer
	return h
}
// ServeHTTP implements http.Handler, so the handler can also be mounted on its
// own. It first runs the middleware chain. If a middleware fails with any
// error other than MethodNotSupportErr, that failure is reported to the client
// as a plain HTTP response. Otherwise it:
//   - hijacks the client connection,
//   - connects to the target (or to the cascade proxy from
//     Ctx.Transport.Proxy),
//   - relays bytes in both directions until either side finishes.
func (tunnel *TunnelHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	// Execution middleware
	ctx := tunnel.Ctx.WithRequest(req)
	resp, err := ctx.Next(req)
	if err != nil && err != MethodNotSupportErr {
		if resp == nil {
			// Writing nothing would make net/http send an implicit 200, which a
			// client could take for an established tunnel.
			http.Error(rw, err.Error(), http.StatusBadGateway)
			return
		}
		if resp.Body != nil {
			defer func() { _ = resp.Body.Close() }()
		}
		copyHeaders(rw.Header(), resp.Header, tunnel.Ctx.KeepDestinationHeaders)
		rw.WriteHeader(resp.StatusCode)
		if resp.Body != nil {
			buf := tunnel.buffer().Get()
			// The status line is already sent; a copy error can only truncate the body.
			_, _ = io.CopyBuffer(rw, resp.Body, buf)
			tunnel.buffer().Put(buf)
		}
		return
	}

	// hijacker connection
	proxyClient, err := hijacker(rw)
	if err != nil {
		http.Error(rw, err.Error(), http.StatusBadGateway)
		return
	}

	targetAddr := hostAndPort(req.URL.Host)
	isCascadeProxy := false
	if tr := tunnel.Ctx.Transport; tr != nil && tr.Proxy != nil {
		var proxyURL *url.URL
		proxyURL, err = tr.Proxy(req)
		if err != nil {
			ConnError(proxyClient)
			return
		}
		if proxyURL != nil {
			// connect addr eg. "localhost:80"
			targetAddr = hostAndPort(proxyURL.Host)
			isCascadeProxy = true
		}
	}

	// Prefer an idle connection from the ConnContainer; dial a new one otherwise.
	targetConn, err := tunnel.connContainer().Get(targetAddr)
	if err != nil {
		targetConn, err = tunnel.ConnectDial("tcp", targetAddr)
		if err != nil {
			ConnError(proxyClient)
			return
		}
	}

	if isCascadeProxy {
		// The cascade proxy must receive the request as-is. Its reply reaches
		// the client through the relay below.
		if err = req.Write(targetConn); err != nil {
			// The connection is in an unknown state; never offer it for reuse.
			_ = targetConn.Close()
			ConnError(proxyClient)
			return
		}
	} else {
		// Tell client that the tunnel is ready
		_, _ = proxyClient.Write(HttpTunnelOk)
	}

	// Once both relay directions have stopped, offer targetConn back to the
	// ConnContainer for reuse, and close it if the container refuses.
	// NOTE(review): a connection that carried an opaque tunnel may not be safe
	// to reuse; confirm the ConnContainer semantics.
	defer func() {
		if err := tunnel.connContainer().Put(targetConn); err != nil {
			_ = targetConn.Close()
		}
	}()

	// client -> target. When the client side ends, expire targetConn's read
	// deadline so the target -> client copy below returns instead of waiting
	// for the target to close.
	clientDone := make(chan struct{})
	go func() {
		defer close(clientDone)
		buf := tunnel.buffer().Get()
		_, _ = io.CopyBuffer(targetConn, proxyClient, buf)
		tunnel.buffer().Put(buf)
		_ = proxyClient.Close()
		_ = targetConn.SetReadDeadline(time.Now())
	}()

	// target -> client
	buf := tunnel.buffer().Get()
	_, _ = io.CopyBuffer(proxyClient, targetConn, buf)
	tunnel.buffer().Put(buf)

	// Closing the hijacked connection tells the client the tunnel is over and
	// unblocks the goroutine's read, so it can be waited for before targetConn
	// is handed to the deferred Put (no concurrent use after return).
	_ = proxyClient.Close()
	<-clientDone
	// Clear the deadline the goroutine may have set, so a reused connection
	// does not start out already expired.
	_ = targetConn.SetReadDeadline(time.Time{})
}
// Use registers one or more Middleware on the handler by delegating to
// Ctx.Use.
func (tunnel *TunnelHandler) Use(middleware ...Middleware) {
	c := tunnel.Ctx
	c.Use(middleware...)
}
// UseFunc registers one or more MiddlewareFunc on the handler by delegating
// to Ctx.UseFunc.
func (tunnel *TunnelHandler) UseFunc(fus ...MiddlewareFunc) {
	c := tunnel.Ctx
	c.UseFunc(fus...)
}
// OnRequest returns a new ReqFilterGroup holding the handler's Ctx and the
// given filters, used to attach request-side handling for matching requests.
func (tunnel *TunnelHandler) OnRequest(filters ...Filter) *ReqFilterGroup {
	group := new(ReqFilterGroup)
	group.ctx = tunnel.Ctx
	group.filters = filters
	return group
}
// OnResponse returns a new RespFilterGroup holding the handler's Ctx and the
// given filters, used to attach response-side handling for matching requests.
func (tunnel *TunnelHandler) OnResponse(filters ...Filter) *RespFilterGroup {
	group := new(RespFilterGroup)
	group.ctx = tunnel.Ctx
	group.filters = filters
	return group
}
// ConnectDial opens a connection to addr on the given network. It uses
// Ctx.Transport.DialContext when one is configured; otherwise it dials directly
// with a 30-second timeout. Both paths use the handler's context (Ctx.Context,
// or context.Background when unset), so cancelling that context aborts an
// in-progress dial.
func (tunnel *TunnelHandler) ConnectDial(network, addr string) (net.Conn, error) {
	if tunnel.Ctx.Transport != nil && tunnel.Ctx.Transport.DialContext != nil {
		return tunnel.Ctx.Transport.DialContext(tunnel.context(), network, addr)
	}
	// Same timeout as net.DialTimeout, but context-aware like the branch above.
	d := net.Dialer{Timeout: 30 * time.Second}
	return d.DialContext(tunnel.context(), network, addr)
}
// Transport returns the *http.Transport configured on the handler's Ctx;
// it may be nil.
func (tunnel *TunnelHandler) Transport() *http.Transport {
	c := tunnel.Ctx
	return c.Transport
}
// context returns Ctx.Context when it is set, and context.Background()
// otherwise, so callers always get a non-nil context.
func (tunnel *TunnelHandler) context() context.Context {
	var c context.Context = context.Background()
	if tunnel.Ctx.Context != nil {
		c = tunnel.Ctx.Context
	}
	return c
}
// buffer returns the handler's BufferPool, or pool.DefaultBuffer when none
// has been configured.
func (tunnel *TunnelHandler) buffer() httputil.BufferPool {
	bp := tunnel.BufferPool
	if bp == nil {
		bp = pool.DefaultBuffer
	}
	return bp
}
// connContainer returns the handler's ConnContainer, or
// pool.DefaultConnProvider when none has been configured.
func (tunnel *TunnelHandler) connContainer() pool.ConnContainer {
	cc := tunnel.ConnContainer
	if cc == nil {
		cc = pool.DefaultConnProvider
	}
	return cc
}
// hostAndPort returns addr unchanged when it already ends with ":<digits>"
// (per hasPort), and otherwise appends the default port ":80".
func hostAndPort(addr string) string {
	if hasPort.MatchString(addr) {
		return addr
	}
	return addr + ":80"
}
// ConnError reports a failed tunnel by writing HttpTunnelFail to w and then
// closing w. Both steps are best-effort; their errors are discarded because
// the connection is being abandoned anyway.
func ConnError(w net.Conn) {
	defer func() { _ = w.Close() }()
	_, _ = w.Write(HttpTunnelFail)
}