Files
goproxy/proxy.go
2025-03-13 15:56:33 +08:00

754 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright 2018 ouqiang authors
//
// Licensed under the Apache License, Version 2.0 (the "License"): you may
// not use this file except in compliance with the License. You may obtain
// a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations
// under the License.

// Package goproxy implements an HTTP(S) proxy that can optionally act as a
// man-in-the-middle and decrypt HTTPS traffic.
package goproxy
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/viki-org/dnscache"
"github.com/ouqiang/goproxy/cert"
"github.com/ouqiang/websocket"
)
const (
	// defaultTargetConnectTimeout is the timeout used when connecting to the
	// target server.
	defaultTargetConnectTimeout = 5 * time.Second
	// defaultTargetReadWriteTimeout is the read/write timeout for the target
	// server.
	// NOTE(review): not referenced anywhere in the visible part of this file.
	defaultTargetReadWriteTimeout = 10 * time.Second
)
// DialContext is the signature of a context-aware dial function. A value of
// this type is installed as the DialContext of the proxy's http.Transport.
type DialContext func(ctx context.Context, network, addr string) (net.Conn, error)

// tunnelEstablishedResponseLine is the raw response written to the client once
// a CONNECT tunnel has been accepted.
var tunnelEstablishedResponseLine = []byte("HTTP/1.1 200 Connection established\r\n\r\n")

// badGateway is a raw, header-less 502 response that can be written directly
// to a connection.
var badGateway = []byte(fmt.Sprintf("HTTP/1.1 %d %s\r\n\r\n", http.StatusBadGateway, http.StatusText(http.StatusBadGateway)))
var (
	// bufPool hands out 32 KiB scratch buffers used with io.CopyBuffer.
	bufPool = sync.Pool{
		New: func() interface{} {
			return make([]byte, 32*1024)
		},
	}
	// ctxPool recycles *Context values between ServeHTTP calls.
	ctxPool = sync.Pool{
		New: func() interface{} {
			return new(Context)
		},
	}
	// headerPool recycles the http.Header maps used for outgoing requests.
	headerPool = NewHeaderPool()
	// requestPool recycles the *http.Request values used for outgoing requests.
	requestPool = newRequestPool()
)
// RequestPool is a sync.Pool-backed free list of *http.Request values.
type RequestPool struct {
	pool sync.Pool
}

// newRequestPool returns a RequestPool that allocates a fresh http.Request
// whenever the pool is empty.
func newRequestPool() *RequestPool {
	return &RequestPool{
		pool: sync.Pool{
			New: func() interface{} {
				return new(http.Request)
			},
		},
	}
}

// Get returns a zeroed *http.Request taken from the pool.
//
// The whole struct is reset rather than individual exported fields, so that
// unexported state (notably the request context) and fields added by newer Go
// releases never survive from a previous use.
func (p *RequestPool) Get() *http.Request {
	req := p.pool.Get().(*http.Request)
	*req = http.Request{}
	return req
}

// Put returns req to the pool for later reuse. A nil req is ignored.
// The caller must not use req after calling Put.
func (p *RequestPool) Put(req *http.Request) {
	if req != nil {
		p.pool.Put(req)
	}
}
// HeaderPool is a sync.Pool-backed free list of http.Header maps.
type HeaderPool struct {
	pool sync.Pool
}

// NewHeaderPool returns a HeaderPool that allocates an empty http.Header
// whenever the pool has nothing to hand out.
func NewHeaderPool() *HeaderPool {
	hp := new(HeaderPool)
	hp.pool.New = func() interface{} {
		return http.Header{}
	}
	return hp
}

// Get takes a header map from the pool and empties it before returning it.
func (p *HeaderPool) Get() http.Header {
	h := p.pool.Get().(http.Header)
	for name := range h {
		delete(h, name)
	}
	return h
}

// Put hands header back to the pool; a nil map is silently dropped.
func (p *HeaderPool) Put(header http.Header) {
	if header == nil {
		return
	}
	p.pool.Put(header)
}
// makeTunnelRequestLine builds the complete CONNECT request that asks an
// upstream proxy to open a tunnel to addr (host:port).
//
// HTTP/1.1 requires every request to carry a Host header, and some proxies
// reject a CONNECT that lacks one, so the authority is sent both in the
// request line and in the Host header.
func makeTunnelRequestLine(addr string) string {
	return fmt.Sprintf("CONNECT %[1]s HTTP/1.1\r\nHost: %[1]s\r\n\r\n", addr)
}
// options collects the settings gathered from the Option values passed to New.
type options struct {
	// disableKeepAlive is copied to the transport's DisableKeepAlives field.
	disableKeepAlive bool
	// delegate receives the proxy's lifecycle callbacks.
	delegate Delegate
	// decryptHTTPS turns on man-in-the-middle decryption of HTTPS tunnels.
	decryptHTTPS bool
	// websocketIntercept turns on message-level WebSocket interception.
	websocketIntercept bool
	// certCache is handed to cert.NewCertificate when decryptHTTPS is set.
	certCache cert.Cache
	// transport is the caller-supplied transport; New builds one when nil.
	transport *http.Transport
	// clientTrace, when non-nil, is attached to every outgoing request.
	clientTrace *httptrace.ClientTrace
}

// Option mutates the options used by New.
type Option func(*options)
// WithDisableKeepAlive controls connection reuse: the given value is copied
// to the transport's DisableKeepAlives field by New.
func WithDisableKeepAlive(disableKeepAlive bool) Option {
	return func(o *options) { o.disableKeepAlive = disableKeepAlive }
}
// WithClientTrace registers an httptrace.ClientTrace that is attached to the
// context of every request sent through DoRequest.
func WithClientTrace(t *httptrace.ClientTrace) Option {
	return func(o *options) { o.clientTrace = t }
}
// WithDelegate sets the Delegate that receives the proxy's callbacks. When no
// delegate is supplied, New falls back to DefaultDelegate.
func WithDelegate(delegate Delegate) Option {
	return func(o *options) { o.delegate = delegate }
}
// WithTransport supplies a custom http.Transport. New overwrites the
// transport's DialContext, DisableKeepAlives and Proxy fields.
func WithTransport(t *http.Transport) Option {
	return func(o *options) { o.transport = t }
}
// WithDecryptHTTPS enables man-in-the-middle decryption of HTTPS traffic.
// The caller provides the certificate cache implementation c.
func WithDecryptHTTPS(c cert.Cache) Option {
	return func(o *options) {
		o.certCache = c
		o.decryptHTTPS = true
	}
}
// WithEnableWebsocketIntercept turns on WebSocket interception, so messages
// are relayed one by one through the delegate instead of as raw bytes.
func WithEnableWebsocketIntercept() Option {
	return func(o *options) { o.websocketIntercept = true }
}
// New creates a Proxy configured by the given options.
//
// Without WithDelegate a DefaultDelegate is used. Without WithTransport a
// transport is created that skips upstream certificate verification. In both
// cases the transport's DialContext, DisableKeepAlives and Proxy fields are
// overwritten by the proxy.
func New(opt ...Option) *Proxy {
	cfg := new(options)
	for _, apply := range opt {
		apply(cfg)
	}
	if cfg.delegate == nil {
		cfg.delegate = &DefaultDelegate{}
	}
	if cfg.transport == nil {
		cfg.transport = &http.Transport{
			TLSClientConfig: &tls.Config{
				InsecureSkipVerify: true,
			},
			MaxIdleConns:          100,
			MaxConnsPerHost:       10,
			IdleConnTimeout:       10 * time.Second,
			TLSHandshakeTimeout:   5 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
	}

	proxy := &Proxy{
		delegate:           cfg.delegate,
		websocketIntercept: cfg.websocketIntercept,
		decryptHTTPS:       cfg.decryptHTTPS,
		transport:          cfg.transport,
		clientTrace:        cfg.clientTrace,
	}
	if proxy.decryptHTTPS {
		proxy.cert = cert.NewCertificate(cfg.certCache, true)
	}
	proxy.dnsCache = dnscache.New(5 * time.Minute)

	// Route all transport dials and upstream-proxy decisions through the proxy.
	tr := proxy.transport
	tr.DialContext = proxy.dialContext()
	tr.DisableKeepAlives = cfg.disableKeepAlive
	tr.Proxy = proxy.delegate.ParentProxy
	return proxy
}
// Proxy is an HTTP(S) proxy; it implements http.Handler.
type Proxy struct {
	// delegate receives lifecycle callbacks (Connect, Auth, BeforeRequest, ...).
	delegate Delegate
	// clientConnNum counts ServeHTTP calls currently in progress; it is
	// accessed atomically.
	clientConnNum int32
	// decryptHTTPS enables man-in-the-middle decryption of CONNECT tunnels.
	decryptHTTPS bool
	// websocketIntercept enables message-level WebSocket relaying.
	websocketIntercept bool
	// cert generates TLS configurations for intercepted hosts; it is only set
	// when decryptHTTPS is true.
	cert *cert.Certificate
	// transport performs the outgoing HTTP requests.
	transport *http.Transport
	// clientTrace, when non-nil, is attached to outgoing requests.
	clientTrace *httptrace.ClientTrace
	// dnsCache resolves host names for connections dialed by transport.
	dnsCache *dnscache.Resolver
}

// Compile-time check that *Proxy satisfies http.Handler.
var _ http.Handler = &Proxy{}
// ServeHTTP implements http.Handler. It runs the delegate's Connect and Auth
// hooks and then dispatches the request: CONNECT requests and WebSocket
// upgrades go to the tunnel handler, everything else is proxied as plain HTTP.
func (p *Proxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	if req.URL.Host == "" {
		req.URL.Host = req.Host
	}

	atomic.AddInt32(&p.clientConnNum, 1)
	ctx := ctxPool.Get().(*Context)
	ctx.Reset(req)
	defer func() {
		p.delegate.Finish(ctx)
		ctxPool.Put(ctx)
		atomic.AddInt32(&p.clientConnNum, -1)
	}()

	// Either hook may stop processing by marking the context as aborted.
	if p.delegate.Connect(ctx, rw); ctx.abort {
		return
	}
	if p.delegate.Auth(ctx, rw); ctx.abort {
		return
	}

	if ctx.Req.Method == http.MethodConnect || websocket.IsWebSocketUpgrade(ctx.Req) {
		p.tunnelProxy(ctx, rw)
		return
	}
	p.httpProxy(ctx, rw)
}
// ClientConnNum reports the current client connection count, i.e. the number
// of ServeHTTP calls that are in progress.
func (p *Proxy) ClientConnNum() int32 {
	active := atomic.LoadInt32(&p.clientConnNum)
	return active
}
// DoRequest sends ctx.Req to the upstream server and passes the outcome to
// responseFunc.
//
// The delegate's BeforeRequest hook runs first and BeforeResponse runs after
// the round trip; if either marks the context as aborted, DoRequest returns
// without calling responseFunc. Otherwise responseFunc receives the response
// (hop headers removed) or the round-trip error.
func (p *Proxy) DoRequest(ctx *Context, responseFunc func(*http.Response, error)) {
	// Make sure the delegate always has a usable Data map.
	if ctx.Data == nil {
		ctx.Data = make(map[interface{}]interface{})
	}
	p.delegate.BeforeRequest(ctx)
	if ctx.abort {
		return
	}
	// Work on a shallow copy of the client request with its own header map, so
	// stripping hop headers does not modify ctx.Req.Header.
	newReq := requestPool.Get()
	*newReq = *ctx.Req
	newHeader := headerPool.Get()
	CloneHeader(newReq.Header, newHeader)
	newReq.Header = newHeader
	for _, item := range hopHeaders {
		if newReq.Header.Get(item) != "" {
			newReq.Header.Del(item)
		}
	}
	if p.clientTrace != nil {
		// WithContext returns a new request value; from here on newReq is no
		// longer the object obtained from requestPool.
		newReq = newReq.WithContext(httptrace.WithClientTrace(newReq.Context(), p.clientTrace))
	}
	resp, err := p.transport.RoundTrip(newReq)
	p.delegate.BeforeResponse(ctx, resp, err)
	// NOTE(review): on abort the response body is not closed here and the
	// pooled header/request are not returned; presumably the delegate takes
	// ownership of resp in that case - confirm.
	if ctx.abort {
		return
	}
	if err == nil {
		for _, h := range hopHeaders {
			resp.Header.Del(h)
		}
	}
	responseFunc(resp, err)
	// The pooled objects are recycled only after responseFunc has returned.
	headerPool.Put(newHeader)
	requestPool.Put(newReq)
}
// httpProxy forwards a plain HTTP request upstream and streams the response
// back to the client. A failed round trip is logged and answered with 502.
func (p *Proxy) httpProxy(ctx *Context, rw http.ResponseWriter) {
	ctx.Req.URL.Scheme = "http"

	relay := func(resp *http.Response, err error) {
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTP请求错误: %s", ctx.Req.URL, err))
			rw.WriteHeader(http.StatusBadGateway)
			return
		}
		defer func() {
			_ = resp.Body.Close()
		}()

		CopyHeader(rw.Header(), resp.Header)
		rw.WriteHeader(resp.StatusCode)

		// Stream the body through a pooled scratch buffer.
		scratch := bufPool.Get().([]byte)
		_, _ = io.CopyBuffer(rw, resp.Body, scratch)
		bufPool.Put(scratch)
	}
	p.DoRequest(ctx, relay)
}
// httpsProxy serves one decrypted HTTPS request read from tlsClientConn.
// WebSocket upgrades are handed to websocketProxy; any other request is sent
// upstream and the response is written back over the TLS connection.
func (p *Proxy) httpsProxy(ctx *Context, tlsClientConn *tls.Conn) {
	if websocket.IsWebSocketUpgrade(ctx.Req) {
		p.websocketProxy(ctx, NewConnBuffer(tlsClientConn, nil))
		return
	}

	reply := func(resp *http.Response, err error) {
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 请求错误: %s", ctx.Req.URL, err))
			_, _ = tlsClientConn.Write(badGateway)
			return
		}
		if writeErr := resp.Write(tlsClientConn); writeErr != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, response写入客户端失败, %s", ctx.Req.URL, writeErr))
		}
		_ = resp.Body.Close()
	}
	p.DoRequest(ctx, reply)
}
// tunnelProxy handles CONNECT requests and WebSocket upgrades by taking over
// the client connection.
//
// Flow:
//  1. Hijack the connection; WebSocket upgrades go straight to websocketProxy.
//  2. Ask the delegate for an upstream (parent) proxy. Without one, the client
//     is told immediately that the tunnel is established; with one, the parent
//     proxy's own reply to the forwarded CONNECT reaches the client through
//     the relay in the non-decrypting case.
//  3. If the first bytes inside the tunnel look like a GET request, treat the
//     stream as a WebSocket handshake.
//  4. With HTTPS decryption enabled, terminate TLS using a generated
//     certificate, read the inner request and serve it via httpsProxy;
//     otherwise relay raw bytes between client and target.
func (p *Proxy) tunnelProxy(ctx *Context, rw http.ResponseWriter) {
	clientConn, err := hijacker(rw)
	if err != nil {
		p.delegate.ErrorLog(err)
		// The connection was not hijacked, so rw is still usable.
		rw.WriteHeader(http.StatusBadGateway)
		return
	}
	defer func() {
		_ = clientConn.Close()
	}()

	if websocket.IsWebSocketUpgrade(ctx.Req) {
		p.websocketProxy(ctx, clientConn)
		return
	}

	parentProxyURL, err := p.delegate.ParentProxy(ctx.Req)
	if err != nil {
		p.delegate.ErrorLog(fmt.Errorf("%s - 解析代理地址错误: %s", ctx.Req.URL.Host, err))
		// rw must not be used after a successful hijack; send the 502 over the
		// raw connection instead so the client actually receives it.
		_, _ = clientConn.Write(badGateway)
		return
	}
	if parentProxyURL == nil {
		_, err = clientConn.Write(tunnelEstablishedResponseLine)
		if err != nil {
			p.delegate.ErrorLog(fmt.Errorf("%s - 隧道连接成功,通知客户端错误: %s", ctx.Req.URL.Host, err))
			return
		}
	}

	// A plaintext "GET" at the start of the tunnel is taken to be a WebSocket
	// handshake.
	isWebsocket := p.detectConnProtocol(clientConn)
	if isWebsocket {
		req, err := http.ReadRequest(clientConn.BufferReader())
		if err != nil {
			if err != io.EOF {
				p.delegate.ErrorLog(fmt.Errorf("%s - websocket读取客户端升级请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}
		req.RemoteAddr = ctx.Req.RemoteAddr
		req.URL.Scheme = "http"
		req.URL.Host = req.Host
		ctx.Req = req
		p.websocketProxy(ctx, clientConn)
		return
	}

	var tlsClientConn *tls.Conn
	if p.decryptHTTPS {
		tlsConfig, err := p.cert.GenerateTlsConfig(ctx.Req.URL.Host)
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 生成证书失败: %s", ctx.Req.URL.Host, err))
			return
		}
		tlsClientConn = tls.Server(clientConn, tlsConfig)
		defer func() {
			_ = tlsClientConn.Close()
		}()
		if err := tlsClientConn.Handshake(); err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 握手失败: %s", ctx.Req.URL.Host, err))
			return
		}

		buf := bufio.NewReader(tlsClientConn)
		tlsReq, err := http.ReadRequest(buf)
		if err != nil {
			if err != io.EOF {
				p.tunnelConnected(ctx, err)
				p.delegate.ErrorLog(fmt.Errorf("%s - HTTPS解密, 读取客户端请求失败: %s", ctx.Req.URL.Host, err))
			}
			return
		}
		// From here on the decrypted inner request replaces the CONNECT request.
		tlsReq.RemoteAddr = ctx.Req.RemoteAddr
		tlsReq.URL.Scheme = "https"
		tlsReq.URL.Host = tlsReq.Host
		ctx.Req = tlsReq
	}

	targetAddr := ctx.Req.URL.Host
	if parentProxyURL != nil {
		targetAddr = parentProxyURL.Host
	}
	// NOTE(review): this port check treats any colon as a port separator, so a
	// bare IPv6 literal is assumed to already carry a port.
	if !strings.Contains(targetAddr, ":") {
		targetAddr += ":443"
	}

	targetConn, err := net.DialTimeout("tcp", targetAddr, defaultTargetConnectTimeout)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 隧道转发连接目标服务器失败: %s", ctx.Req.URL.Host, err))
		return
	}
	defer func() {
		_ = targetConn.Close()
	}()

	if parentProxyURL != nil {
		// Ask the parent proxy to open the tunnel. The write is best effort: a
		// failure shows up on the relay that follows.
		tunnelRequestLine := makeTunnelRequestLine(ctx.Req.URL.Host)
		_, _ = targetConn.Write([]byte(tunnelRequestLine))
	}

	if p.decryptHTTPS {
		p.httpsProxy(ctx, tlsClientConn)
	} else {
		p.tunnelConnected(ctx, nil)
		p.transfer(clientConn, targetConn)
	}
}
// websocketProxy proxies a WebSocket session between srcConn (the client) and
// the target named by ctx.
//
// Without interception the upgrade request is replayed to the target over a
// fresh TCP or TLS connection and the two connections are relayed as raw
// bytes. With interception both sides are upgraded to WebSocket connections
// and messages are relayed one at a time through the delegate.
func (p *Proxy) websocketProxy(ctx *Context, srcConn *ConnBuffer) {
	if !p.websocketIntercept {
		remoteAddr := ctx.Addr()
		var err error
		var targetConn net.Conn
		// NOTE(review): these dials have no timeout, unlike the tunnel dial
		// which uses defaultTargetConnectTimeout.
		if ctx.IsHTTPS() {
			targetConn, err = tls.Dial("tcp", remoteAddr, &tls.Config{InsecureSkipVerify: true})
		} else {
			targetConn, err = net.Dial("tcp", remoteAddr)
		}
		if err != nil {
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - websocket连接目标服务器错误: %s", ctx.Req.URL.Host, err))
			return
		}
		err = ctx.Req.Write(targetConn)
		if err != nil {
			// The relay below is what normally closes targetConn; on this
			// early exit it has to be closed here or it leaks.
			_ = targetConn.Close()
			p.tunnelConnected(ctx, err)
			p.delegate.ErrorLog(fmt.Errorf("%s - websocket协议转换请求写入目标服务器错误: %s", ctx.Req.URL.Host, err))
			return
		}
		p.tunnelConnected(ctx, nil)
		p.transfer(srcConn, targetConn)
		return
	}

	up := &websocket.Upgrader{
		HandshakeTimeout: defaultTargetConnectTimeout,
		ReadBufferSize:   4096,
		WriteBufferSize:  4096,
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}
	srcWSConn, err := up.Upgrade(srcConn, ctx.Req, http.Header{})
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 源连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
		return
	}

	u := ctx.WebsocketUrl()
	d := websocket.Dialer{
		ReadBufferSize:  4096,
		WriteBufferSize: 4096,
	}
	dialTimeoutCtx, cancel := context.WithTimeout(context.Background(), defaultTargetConnectTimeout)
	defer cancel()
	targetWSConn, _, err := d.DialContext(dialTimeoutCtx, u.String(), ctx.Req.Header)
	if err != nil {
		p.tunnelConnected(ctx, err)
		p.delegate.ErrorLog(fmt.Errorf("%s - 目标连接升级到websocket协议错误: %s", ctx.Req.URL.Host, err))
		return
	}
	// Nothing else owns the upstream connection, so release it when the
	// session ends. The client side is closed by the caller.
	defer func() {
		_ = targetWSConn.Close()
	}()

	p.tunnelConnected(ctx, nil)
	p.transferWebsocket(ctx, srcWSConn, targetWSConn)
}
// detectConnProtocol peeks at the first bytes of connBuf without consuming
// them and reports whether they spell the HTTP GET method, which the caller
// treats as the start of a WebSocket handshake. A failed peek reports false.
func (p *Proxy) detectConnProtocol(connBuf *ConnBuffer) (isWebsocket bool) {
	prefix, err := connBuf.Peek(len(http.MethodGet))
	return err == nil && string(prefix) == http.MethodGet
}
// transferWebsocket relays WebSocket messages in both directions between
// srcConn (client) and targetConn (upstream) until either direction fails.
//
// Every message passes through the delegate (WebSocketSendMessage for
// client-to-target, WebSocketReceiveMessage for target-to-client), which may
// rewrite the type or payload in place.
//
// When one direction stops, both connections are closed so the opposite
// direction's blocking read returns instead of hanging. The function returns
// only after the background goroutine has finished, so ctx is never touched
// after the caller regains control. Only the first failure is logged; the
// error caused by the deliberate close is suppressed.
func (p *Proxy) transferWebsocket(ctx *Context, srcConn *websocket.Conn, targetConn *websocket.Conn) {
	doneCtx, cancel := context.WithCancel(context.Background())

	// shutdown marks the session finished and unblocks any pending read or
	// write on either connection.
	shutdown := func() {
		cancel()
		_ = srcConn.Close()
		_ = targetConn.Close()
	}

	finished := make(chan struct{})
	defer func() {
		shutdown()
		<-finished
	}()

	// client -> target
	go func() {
		defer close(finished)
		defer shutdown()
		for {
			if doneCtx.Err() != nil {
				return
			}
			msgType, msg, err := srcConn.ReadMessage()
			if err != nil {
				if doneCtx.Err() == nil {
					p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
				}
				return
			}
			p.delegate.WebSocketSendMessage(ctx, &msgType, &msg)
			err = targetConn.WriteMessage(msgType, msg)
			if err != nil {
				if doneCtx.Err() == nil {
					p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", srcConn.RemoteAddr(), targetConn.RemoteAddr(), err))
				}
				return
			}
		}
	}()

	// target -> client
	for {
		if doneCtx.Err() != nil {
			return
		}
		msgType, msg, err := targetConn.ReadMessage()
		if err != nil {
			if doneCtx.Err() == nil {
				p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
			}
			return
		}
		p.delegate.WebSocketReceiveMessage(ctx, &msgType, &msg)
		err = srcConn.WriteMessage(msgType, msg)
		if err != nil {
			if doneCtx.Err() == nil {
				p.delegate.ErrorLog(fmt.Errorf("websocket消息转发错误: [%s -> %s] %s", targetConn.RemoteAddr(), srcConn.RemoteAddr(), err))
			}
			return
		}
	}
}
// transfer relays raw bytes in both directions between src and dst. The
// dst-to-src direction runs in a background goroutine and the src-to-dst
// direction runs on the calling goroutine. As soon as one direction stops,
// both connections are closed, which ends the other direction too.
func (p *Proxy) transfer(src net.Conn, dst net.Conn) {
	// pump copies from -> to through a pooled buffer, logs a copy error and
	// then closes both ends.
	pump := func(to, from net.Conn) {
		scratch := bufPool.Get().([]byte)
		if _, copyErr := io.CopyBuffer(to, from, scratch); copyErr != nil {
			p.delegate.ErrorLog(fmt.Errorf("隧道双向转发错误: [%s -> %s] %s", from.RemoteAddr().String(), to.RemoteAddr().String(), copyErr))
		}
		bufPool.Put(scratch)
		_ = to.Close()
		_ = from.Close()
	}

	go pump(src, dst)
	pump(dst, src)
}
// tunnelConnected reports the outcome of a tunnel setup to the delegate. It
// marks ctx as a tunnel, fires BeforeRequest, and then fires BeforeResponse
// with either the given error or a synthetic empty "200 OK" response.
func (p *Proxy) tunnelConnected(ctx *Context, err error) {
	ctx.TunnelProxy = true
	p.delegate.BeforeRequest(ctx)

	var resp *http.Response
	if err == nil {
		resp = &http.Response{
			Status:     "200 OK",
			StatusCode: http.StatusOK,
			Proto:      "1.1",
			ProtoMajor: 1,
			ProtoMinor: 1,
			Header:     http.Header{},
			Body:       http.NoBody,
		}
	}
	p.delegate.BeforeResponse(ctx, resp, err)
}
// dialContext returns the dial function installed on the proxy's transport.
//
// The returned function resolves the host part of addr through the proxy's
// DNS cache and dials the resulting IP with defaultTargetConnectTimeout. An
// IPv4 address is preferred when one is available; otherwise the first
// resolved address is used.
//
// The address is split and rebuilt with net.SplitHostPort and
// net.JoinHostPort, so a malformed addr yields an error instead of a panic and
// IPv6 literals are bracketed correctly. An empty lookup result is reported as
// an error rather than being dialed as ":port", which would target the local
// host.
func (p *Proxy) dialContext() DialContext {
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		dialer := &net.Dialer{
			Timeout: defaultTargetConnectTimeout,
		}
		host, port, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, err
		}
		ips, err := p.dnsCache.Fetch(host)
		if err != nil {
			return nil, err
		}
		if len(ips) == 0 {
			return nil, fmt.Errorf("dns lookup for %q returned no addresses", host)
		}

		// Default to the first answer, but prefer the first IPv4 one (its
		// textual form contains no colon).
		ip := ips[0].String()
		for _, item := range ips {
			if candidate := item.String(); !strings.Contains(candidate, ":") {
				ip = candidate
				break
			}
		}
		return dialer.DialContext(ctx, network, net.JoinHostPort(ip, port))
	}
}
// hijacker takes over the connection underlying rw and wraps it, together
// with the server's buffered reader/writer, in a ConnBuffer. It fails when rw
// does not implement http.Hijacker or when the hijack itself fails.
func hijacker(rw http.ResponseWriter) (*ConnBuffer, error) {
	hj, supported := rw.(http.Hijacker)
	if !supported {
		return nil, fmt.Errorf("http server不支持Hijacker")
	}

	conn, rwBuf, err := hj.Hijack()
	if err != nil {
		return nil, fmt.Errorf("hijacker错误: %s", err)
	}
	return NewConnBuffer(conn, rwBuf), nil
}
// CopyHeader appends every value of src to dst under the canonical form of
// its key. Values already present in dst are kept, and keys of src that have
// no values are skipped.
func CopyHeader(dst, src http.Header) {
	for name, values := range src {
		if len(values) == 0 {
			continue
		}
		key := http.CanonicalHeaderKey(name)
		dst[key] = append(dst[key], values...)
	}
}
// CloneHeader deep-copies h into h2: each key of h is stored in h2 with a
// freshly allocated copy of its value slice. Keys are used verbatim (not
// canonicalized) and entries of h2 that are absent from h are left untouched.
// h2 must be non-nil when h has entries.
func CloneHeader(h http.Header, h2 http.Header) {
	for name, values := range h {
		h2[name] = append(make([]string, 0, len(values)), values...)
	}
}
// hopHeaders lists the headers that DoRequest deletes from the outgoing
// request and from the upstream response before they are forwarded.
var hopHeaders = []string{
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
}
// ConnBuffer couples a net.Conn with a buffered reader/writer. Reads and
// writes go through the buffer, so bytes that were already buffered (for
// example by Peek) are not lost. All other net.Conn methods are promoted
// from the embedded connection.
type ConnBuffer struct {
	net.Conn
	buf *bufio.ReadWriter
}

// NewConnBuffer wraps conn. When buf is nil a new buffered reader/writer pair
// is created on top of conn; otherwise the supplied one is used.
func NewConnBuffer(conn net.Conn, buf *bufio.ReadWriter) *ConnBuffer {
	cb := &ConnBuffer{Conn: conn, buf: buf}
	if cb.buf == nil {
		cb.buf = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
	}
	return cb
}

// BufferReader exposes the underlying buffered reader.
func (cb *ConnBuffer) BufferReader() *bufio.Reader {
	reader := cb.buf.Reader
	return reader
}

// Read reads from the buffered reader rather than directly from the
// connection.
func (cb *ConnBuffer) Read(b []byte) (n int, err error) {
	n, err = cb.buf.Read(b)
	return n, err
}

// Peek returns the next n bytes without consuming them.
func (cb *ConnBuffer) Peek(n int) ([]byte, error) {
	head, err := cb.buf.Peek(n)
	return head, err
}

// Write buffers p and flushes it to the connection straight away. If
// buffering fails it reports zero bytes written; a flush error is returned
// together with the buffered byte count.
func (cb *ConnBuffer) Write(p []byte) (n int, err error) {
	written, writeErr := cb.buf.Write(p)
	if writeErr != nil {
		return 0, writeErr
	}
	if flushErr := cb.buf.Flush(); flushErr != nil {
		return written, flushErr
	}
	return written, nil
}

// Hijack returns the wrapped connection and its buffer; it never fails.
func (cb *ConnBuffer) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	conn, rw := cb.Conn, cb.buf
	return conn, rw, nil
}

// WriteHeader is a no-op. Together with Header and Hijack it lets a
// ConnBuffer be passed where an http.ResponseWriter is expected, as done for
// the WebSocket upgrader.
func (cb *ConnBuffer) WriteHeader(_ int) {}

// Header always returns nil.
func (cb *ConnBuffer) Header() http.Header {
	return nil
}