278 lines
6.7 KiB
Go
278 lines
6.7 KiB
Go
package proxy
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"fmt"
|
||
"net"
|
||
"net/http"
|
||
"net/http/httputil"
|
||
"net/url"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/darkit/goproxy/internal/rewriter"
|
||
"github.com/darkit/goproxy/internal/router"
|
||
)
|
||
|
||
// ReverseProxy 反向代理
|
||
type ReverseProxy struct {
|
||
// 代理对象
|
||
proxy *Proxy
|
||
// 路由器
|
||
router *router.Router
|
||
// URL重写器
|
||
rewriter *rewriter.Rewriter
|
||
// HTTP传输对象
|
||
transport http.RoundTripper
|
||
}
|
||
|
||
// NewReverseProxy 创建反向代理
|
||
func (p *Proxy) NewReverseProxy() *ReverseProxy {
|
||
rp := &ReverseProxy{
|
||
proxy: p,
|
||
router: router.NewRouter(),
|
||
rewriter: rewriter.NewRewriter(),
|
||
}
|
||
|
||
// 创建自定义的传输对象
|
||
transport := &http.Transport{
|
||
Proxy: func(req *http.Request) (*url.URL, error) {
|
||
// 使用代理委托中的方法获取代理
|
||
return p.delegate.ParentProxy(req)
|
||
},
|
||
DialContext: p.dialContextWithCache(),
|
||
TLSClientConfig: &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
},
|
||
MaxIdleConns: p.config.ConnectionPoolSize,
|
||
IdleConnTimeout: p.config.IdleTimeout,
|
||
TLSHandshakeTimeout: 10 * time.Second,
|
||
ExpectContinueTimeout: 1 * time.Second,
|
||
MaxIdleConnsPerHost: p.config.ConnectionPoolSize,
|
||
DisableCompression: !p.config.EnableCompression,
|
||
}
|
||
|
||
rp.transport = transport
|
||
|
||
// 如果配置了规则文件,加载规则
|
||
if p.config.ReverseProxyRulesFile != "" {
|
||
// 省略加载规则文件的实现
|
||
}
|
||
|
||
return rp
|
||
}
|
||
|
||
// ServeHTTP 处理反向代理请求
|
||
func (rp *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||
// 获取请求上下文
|
||
ctx := ctxPool.Get().(*Context)
|
||
ctx.Reset(req)
|
||
defer ctxPool.Put(ctx)
|
||
|
||
// 调用连接事件
|
||
rp.proxy.delegate.Connect(ctx, rw)
|
||
|
||
// 认证检查
|
||
rp.proxy.delegate.Auth(ctx, rw)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 请求前处理
|
||
rp.proxy.delegate.BeforeRequest(ctx)
|
||
if ctx.IsAborted() {
|
||
return
|
||
}
|
||
|
||
// 解析后端地址
|
||
backend, err := rp.proxy.delegate.ResolveBackend(req)
|
||
if err != nil {
|
||
rp.proxy.delegate.HandleError(rw, req, err)
|
||
return
|
||
}
|
||
|
||
// 创建请求代理对象
|
||
proxy := httputil.NewSingleHostReverseProxy(&url.URL{
|
||
Scheme: "http",
|
||
Host: backend,
|
||
})
|
||
|
||
// 使用自定义传输对象
|
||
proxy.Transport = rp.transport
|
||
|
||
// 设置自定义错误处理函数
|
||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||
rp.proxy.delegate.HandleError(rw, req, err)
|
||
}
|
||
|
||
// 设置请求修改函数
|
||
originalDirector := proxy.Director
|
||
proxy.Director = func(req *http.Request) {
|
||
// 调用原始Director函数
|
||
originalDirector(req)
|
||
|
||
// 处理URL重写
|
||
if rp.proxy.config.EnableURLRewrite {
|
||
rp.rewriter.Rewrite(req)
|
||
}
|
||
|
||
// 修改请求头
|
||
if rp.proxy.config.RewriteHostHeader {
|
||
req.Host = backend
|
||
}
|
||
|
||
// 添加X-Forwarded-For头
|
||
if rp.proxy.config.AddXForwardedFor {
|
||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||
if err == nil {
|
||
// 如果已经有X-Forwarded-For,添加到末尾
|
||
if prior, ok := req.Header["X-Forwarded-For"]; ok {
|
||
clientIP = strings.Join(prior, ", ") + ", " + clientIP
|
||
}
|
||
req.Header.Set("X-Forwarded-For", clientIP)
|
||
}
|
||
}
|
||
|
||
// 添加X-Real-IP头
|
||
if rp.proxy.config.AddXRealIP {
|
||
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
|
||
if err == nil {
|
||
req.Header.Set("X-Real-IP", clientIP)
|
||
}
|
||
}
|
||
|
||
// 设置协议头
|
||
req.Header.Set("X-Forwarded-Proto", "http")
|
||
if req.TLS != nil {
|
||
req.Header.Set("X-Forwarded-Proto", "https")
|
||
}
|
||
|
||
// 调用委托的ModifyRequest方法
|
||
rp.proxy.delegate.ModifyRequest(req)
|
||
}
|
||
|
||
// 设置响应修改函数
|
||
proxy.ModifyResponse = func(resp *http.Response) error {
|
||
// 处理响应URL重写
|
||
if rp.proxy.config.EnableURLRewrite && resp != nil {
|
||
rp.rewriter.RewriteResponse(resp, req.Host)
|
||
}
|
||
|
||
// 添加CORS头
|
||
if rp.proxy.config.EnableCORS && resp != nil {
|
||
resp.Header.Set("Access-Control-Allow-Origin", "*")
|
||
resp.Header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||
resp.Header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||
}
|
||
|
||
// 调用委托的ModifyResponse方法
|
||
return rp.proxy.delegate.ModifyResponse(resp)
|
||
}
|
||
|
||
// 更新监控指标
|
||
if rp.proxy.metrics != nil {
|
||
rp.proxy.metrics.IncActiveConnections()
|
||
defer rp.proxy.metrics.DecActiveConnections()
|
||
|
||
startTime := time.Now()
|
||
defer func() {
|
||
duration := time.Since(startTime)
|
||
rp.proxy.metrics.ObserveRequestDuration(duration.Seconds())
|
||
rp.proxy.metrics.IncRequestCount()
|
||
}()
|
||
}
|
||
|
||
// 处理WebSocket升级
|
||
if rp.proxy.config.SupportWebSocketUpgrade && isWebSocketRequest(req) {
|
||
rp.handleWebSocketUpgrade(rw, req, backend)
|
||
return
|
||
}
|
||
|
||
// 处理普通请求
|
||
proxy.ServeHTTP(rw, req)
|
||
|
||
// 完成事件
|
||
rp.proxy.delegate.Finish(ctx)
|
||
}
|
||
|
||
// 处理WebSocket升级
|
||
func (rp *ReverseProxy) handleWebSocketUpgrade(rw http.ResponseWriter, req *http.Request, backend string) {
|
||
// 创建WebSocket代理
|
||
target := &url.URL{
|
||
Scheme: "ws",
|
||
Host: backend,
|
||
}
|
||
|
||
if req.TLS != nil {
|
||
target.Scheme = "wss"
|
||
}
|
||
|
||
// 创建连接到后端的WebSocket连接
|
||
backendConn, err := rp.dialBackend(target.String(), req)
|
||
if err != nil {
|
||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("无法连接到后端WebSocket服务: %v", err))
|
||
return
|
||
}
|
||
defer backendConn.Close()
|
||
|
||
// 将请求转发给后端
|
||
err = req.Write(backendConn)
|
||
if err != nil {
|
||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("写入WebSocket请求错误: %v", err))
|
||
return
|
||
}
|
||
|
||
// 升级客户端连接
|
||
clientConn, err := hijacker(rw)
|
||
if err != nil {
|
||
rp.proxy.delegate.HandleError(rw, req, fmt.Errorf("升级WebSocket连接错误: %v", err))
|
||
return
|
||
}
|
||
|
||
// 双向转发数据
|
||
rp.proxy.transfer(clientConn, backendConn)
|
||
}
|
||
|
||
// 连接到后端
|
||
func (rp *ReverseProxy) dialBackend(url string, req *http.Request) (net.Conn, error) {
|
||
ctx, cancel := context.WithTimeout(req.Context(), 15*time.Second)
|
||
defer cancel()
|
||
|
||
backend := strings.TrimPrefix(url, "ws://")
|
||
backend = strings.TrimPrefix(backend, "wss://")
|
||
|
||
if strings.Contains(backend, "/") {
|
||
backend = backend[:strings.Index(backend, "/")]
|
||
}
|
||
|
||
// 根据协议选择连接方式
|
||
if strings.HasPrefix(url, "wss://") {
|
||
// 使用 tls.Dialer 替代不存在的 tls.DialWithContext
|
||
dialer := &tls.Dialer{
|
||
Config: &tls.Config{
|
||
InsecureSkipVerify: true,
|
||
},
|
||
}
|
||
return dialer.DialContext(ctx, "tcp", backend)
|
||
}
|
||
|
||
var d net.Dialer
|
||
return d.DialContext(ctx, "tcp", backend)
|
||
}
|
||
|
||
// 添加路由规则
|
||
func (rp *ReverseProxy) AddRoute(pattern string, routeType router.RouteType, target string) {
|
||
route := &router.Route{
|
||
Pattern: pattern,
|
||
Type: routeType,
|
||
Target: target,
|
||
}
|
||
rp.router.AddRoute(route)
|
||
}
|
||
|
||
// 添加重写规则
|
||
func (rp *ReverseProxy) AddRewriteRule(pattern, replacement string, useRegex bool) error {
|
||
return rp.rewriter.AddRule(pattern, replacement, useRegex)
|
||
}
|