// Files
// demo/pkg/rewriter/rewriter.go
// 2025-03-14 18:50:49 +00:00
//
// 286 lines
// 6.5 KiB
// Go

package rewriter
import (
"bufio"
"encoding/json"
"fmt"
"net/http"
"os"
"regexp"
"strings"
)
// Rewriter is a URL rewriter.
// It is used to rewrite request URLs inside a reverse proxy.
type Rewriter struct {
// rules is the list of rewrite rules, kept in the order they were added.
rules []*RewriteRule
}
// RewriteRule describes a single rewrite rule.
type RewriteRule struct {
// Pattern is the text to match: a regular expression when UseRegex is
// true, otherwise a literal prefix of the request path.
Pattern string `json:"pattern"`
// Replacement is the text substituted for the matched part.
Replacement string `json:"replacement"`
// UseRegex reports whether Pattern is a regular expression.
UseRegex bool `json:"use_regex"`
// regex is the compiled form of Pattern; it is never serialized to JSON.
regex *regexp.Regexp `json:"-"`
// Description is an optional human-readable description of the rule.
Description string `json:"description,omitempty"`
// Enabled is the rule's enabled state. Because of omitempty, a false
// value is left out of the encoded JSON, and a JSON object without an
// "enabled" key decodes to a disabled rule.
Enabled bool `json:"enabled,omitempty"`
}
// NewRewriter returns a Rewriter whose rule list is empty but non-nil.
func NewRewriter() *Rewriter {
	rw := new(Rewriter)
	rw.rules = make([]*RewriteRule, 0)
	return rw
}
// AddRule appends an enabled rewrite rule that has no description.
//
// When useRegex is true, pattern is compiled as a regular expression before
// the rule is stored; if compilation fails the error is returned and the rule
// list is left untouched.
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
	var compiled *regexp.Regexp
	if useRegex {
		var err error
		if compiled, err = regexp.Compile(pattern); err != nil {
			return err
		}
	}

	r.rules = append(r.rules, &RewriteRule{
		Pattern:     pattern,
		Replacement: replacement,
		UseRegex:    useRegex,
		regex:       compiled,
		Enabled:     true,
	})
	return nil
}
// AddRuleWithDescription appends an enabled rewrite rule carrying the given
// description.
//
// When useRegex is true, pattern is compiled as a regular expression before
// the rule is stored; if compilation fails the error is returned and the rule
// list is left untouched.
func (r *Rewriter) AddRuleWithDescription(pattern, replacement string, useRegex bool, description string) error {
	var compiled *regexp.Regexp
	if useRegex {
		var err error
		if compiled, err = regexp.Compile(pattern); err != nil {
			return err
		}
	}

	r.rules = append(r.rules, &RewriteRule{
		Pattern:     pattern,
		Replacement: replacement,
		UseRegex:    useRegex,
		regex:       compiled,
		Description: description,
		Enabled:     true,
	})
	return nil
}
// Rewrite rewrites req.URL.Path in place using the first enabled rule that
// matches the current path. Rules are tried in list order and at most one
// rule is applied.
//
// A regex rule matches when its compiled expression matches the path; every
// match in the path is replaced with Replacement. A non-regex rule matches
// when the path starts with Pattern; only that leading Pattern is replaced.
//
// A nil request, a request without a URL, nil rule entries, and regex rules
// that have no compiled expression are skipped instead of causing a panic.
func (r *Rewriter) Rewrite(req *http.Request) {
	if req == nil || req.URL == nil {
		return
	}

	path := req.URL.Path
	for _, rule := range r.rules {
		if rule == nil || !rule.Enabled {
			continue
		}

		if rule.UseRegex {
			// A regex rule without a compiled expression cannot match anything.
			if rule.regex == nil || !rule.regex.MatchString(path) {
				continue
			}
			req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
			return
		}

		if strings.HasPrefix(path, rule.Pattern) {
			// Pattern is known to be the prefix, so swapping it out is a plain
			// trim-and-prepend.
			req.URL.Path = rule.Replacement + strings.TrimPrefix(path, rule.Pattern)
			return
		}
	}
}
// RewriteResponse rewrites the Location header of resp, if it has one.
//
// Only enabled regex rules with a compiled expression are considered. The
// first one whose expression matches the header value replaces every match
// with its Replacement, and no further rules are tried. Non-regex rules are
// ignored here, and originHost is currently unused.
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
	location := resp.Header.Get("Location")
	if location == "" {
		return
	}

	for _, rule := range r.rules {
		usable := rule.Enabled && rule.UseRegex && rule.regex != nil
		if !usable || !rule.regex.MatchString(location) {
			continue
		}
		rewritten := rule.regex.ReplaceAllString(location, rule.Replacement)
		resp.Header.Set("Location", rewritten)
		return
	}
}
// LoadRulesFromFile loads rewrite rules from the named file and appends them
// to the rule list.
//
// A file name ending in ".json" is parsed as JSON; any other name is parsed
// as the line-oriented text format. The returned error wraps the underlying
// cause, so callers can inspect it with errors.Is / errors.As.
func (r *Rewriter) LoadRulesFromFile(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return fmt.Errorf("打开文件失败: %w", err)
	}
	defer file.Close()

	// The file extension decides which parser is used.
	if strings.HasSuffix(filename, ".json") {
		return r.loadRulesFromJSON(file)
	}
	return r.loadRulesFromText(file)
}
// loadRulesFromJSON decodes a JSON array of rules from file and appends them
// to the rule list.
//
// Every rule with UseRegex set has its Pattern compiled. The load is
// all-or-nothing: if decoding fails, an array element is null, or a pattern
// does not compile, an error is returned and the rule list is left unchanged.
func (r *Rewriter) loadRulesFromJSON(file *os.File) error {
	var rules []*RewriteRule
	if err := json.NewDecoder(file).Decode(&rules); err != nil {
		return fmt.Errorf("解析JSON失败: %w", err)
	}

	// Validate and compile everything before touching r.rules so that a bad
	// entry cannot leave a partially loaded rule set behind.
	for i, rule := range rules {
		if rule == nil {
			// A JSON null element decodes to a nil pointer.
			return fmt.Errorf("第%d条规则为空", i+1)
		}
		if !rule.UseRegex {
			continue
		}
		regex, err := regexp.Compile(rule.Pattern)
		if err != nil {
			return fmt.Errorf("编译正则表达式'%s'失败: %w", rule.Pattern, err)
		}
		rule.regex = regex
	}

	r.rules = append(r.rules, rules...)
	return nil
}
// loadRulesFromText reads rules from a line-oriented text file and appends
// them to the rule list.
//
// Line format: pattern replacement [regex] [#description]
//
// Blank lines and lines starting with "#" are skipped. After the two
// mandatory fields, the token "regex" marks the pattern as a regular
// expression, and the first token starting with "#" begins the description,
// which runs to the end of the line. Other extra tokens are ignored.
//
// The load is all-or-nothing: if any line is malformed, a pattern does not
// compile, or reading fails, an error is returned and the rule list is left
// unchanged.
func (r *Rewriter) loadRulesFromText(file *os.File) error {
	// Rules are collected in a scratch Rewriter and only committed once the
	// whole file has been read successfully.
	staged := &Rewriter{}

	scanner := bufio.NewScanner(file)
	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		parts := strings.Fields(line)
		if len(parts) < 2 {
			return fmt.Errorf("第%d行格式错误: %s", lineNum, line)
		}
		pattern, replacement := parts[0], parts[1]

		useRegex := false
		description := ""
		for i := 2; i < len(parts); i++ {
			if parts[i] == "regex" {
				useRegex = true
				continue
			}
			if strings.HasPrefix(parts[i], "#") {
				// Everything from the "#" token onward is the description.
				description = strings.Join(parts[i:], " ")
				description = strings.TrimSpace(strings.TrimPrefix(description, "#"))
				break
			}
		}

		if err := staged.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil {
			return fmt.Errorf("第%d行添加规则失败: %w", lineNum, err)
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取文件失败: %w", err)
	}

	r.rules = append(r.rules, staged.rules...)
	return nil
}
// GetRules returns the rewriter's rule list.
//
// The returned slice is the rewriter's own slice, not a copy, and its
// elements are the stored rule pointers.
func (r *Rewriter) GetRules() []*RewriteRule {
	current := r.rules
	return current
}
// EnableRule marks the rule at index as enabled.
// It returns an error when index is outside the rule list.
func (r *Rewriter) EnableRule(index int) error {
	if inRange := index >= 0 && index < len(r.rules); !inRange {
		return fmt.Errorf("规则索引越界: %d", index)
	}
	target := r.rules[index]
	target.Enabled = true
	return nil
}
// DisableRule marks the rule at index as disabled.
// It returns an error when index is outside the rule list.
func (r *Rewriter) DisableRule(index int) error {
	switch {
	case index < 0, index >= len(r.rules):
		return fmt.Errorf("规则索引越界: %d", index)
	default:
		r.rules[index].Enabled = false
		return nil
	}
}
// RemoveRule deletes the rule at index, preserving the order of the
// remaining rules. It returns an error when index is outside the rule list.
func (r *Rewriter) RemoveRule(index int) error {
	if index < 0 || index >= len(r.rules) {
		return fmt.Errorf("规则索引越界: %d", index)
	}

	last := len(r.rules) - 1
	// Shift the tail left by one, then clear the vacated slot so the backing
	// array does not keep a stale pointer alive after the slice is shortened.
	copy(r.rules[index:], r.rules[index+1:])
	r.rules[last] = nil
	r.rules = r.rules[:last]
	return nil
}
// SaveRulesToFile writes the rule list to the named file, creating or
// truncating it.
//
// A file name ending in ".json" gets the rules as an indented JSON array.
// Any other name gets one line per rule in the text format
// "pattern replacement [regex] [# description]". In the text format a
// disabled rule is written as a "#" comment line with a " (disabled)" suffix,
// and nil rule entries are skipped.
//
// Errors from creating, writing, flushing, and closing the file are all
// reported; a close failure is returned only when no earlier error occurred.
func (r *Rewriter) SaveRulesToFile(filename string) (err error) {
	file, err := os.Create(filename)
	if err != nil {
		return fmt.Errorf("创建文件失败: %w", err)
	}
	// Closing a written file can surface a deferred write error, so the
	// result of Close must not be discarded.
	defer func() {
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("关闭文件失败: %w", closeErr)
		}
	}()

	// The file extension decides the output format.
	if strings.HasSuffix(filename, ".json") {
		encoder := json.NewEncoder(file)
		encoder.SetIndent("", " ")
		return encoder.Encode(r.rules)
	}

	writer := bufio.NewWriter(file)
	for _, rule := range r.rules {
		if rule == nil {
			continue
		}
		line := rule.Pattern + " " + rule.Replacement
		if rule.UseRegex {
			line += " regex"
		}
		if rule.Description != "" {
			line += " # " + rule.Description
		}
		if !rule.Enabled {
			line = "# " + line + " (disabled)"
		}
		if _, err := writer.WriteString(line + "\n"); err != nil {
			return fmt.Errorf("写入文件失败: %w", err)
		}
	}
	if err := writer.Flush(); err != nil {
		return fmt.Errorf("写入文件失败: %w", err)
	}
	return nil
}