286 lines
6.5 KiB
Go
286 lines
6.5 KiB
Go
package rewriter
|
|
|
|
import (
|
|
"bufio"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"regexp"
|
|
"strings"
|
|
)
|
|
|
|
// Rewriter URL重写器
|
|
// 用于在反向代理中重写请求URL
|
|
type Rewriter struct {
|
|
// 重写规则列表
|
|
rules []*RewriteRule
|
|
}
|
|
|
|
// RewriteRule 重写规则
|
|
type RewriteRule struct {
|
|
// 匹配模式
|
|
Pattern string `json:"pattern"`
|
|
// 替换模式
|
|
Replacement string `json:"replacement"`
|
|
// 是否使用正则表达式
|
|
UseRegex bool `json:"use_regex"`
|
|
// 编译后的正则表达式
|
|
regex *regexp.Regexp `json:"-"`
|
|
// 规则描述
|
|
Description string `json:"description,omitempty"`
|
|
// 规则启用状态
|
|
Enabled bool `json:"enabled,omitempty"`
|
|
}
|
|
|
|
// NewRewriter 创建URL重写器
|
|
func NewRewriter() *Rewriter {
|
|
return &Rewriter{
|
|
rules: make([]*RewriteRule, 0),
|
|
}
|
|
}
|
|
|
|
// AddRule 添加重写规则
|
|
func (r *Rewriter) AddRule(pattern, replacement string, useRegex bool) error {
|
|
rule := &RewriteRule{
|
|
Pattern: pattern,
|
|
Replacement: replacement,
|
|
UseRegex: useRegex,
|
|
Enabled: true,
|
|
}
|
|
|
|
if useRegex {
|
|
regex, err := regexp.Compile(pattern)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rule.regex = regex
|
|
}
|
|
|
|
r.rules = append(r.rules, rule)
|
|
return nil
|
|
}
|
|
|
|
// AddRuleWithDescription appends an enabled rule with the given description.
//
// When useRegex is true, pattern is compiled with regexp.Compile first. If it
// is invalid, an error wrapping the compile error (inspectable with
// errors.As) is returned and the rule list is left unchanged.
func (r *Rewriter) AddRuleWithDescription(pattern, replacement string, useRegex bool, description string) error {
	rule := &RewriteRule{
		Pattern:     pattern,
		Replacement: replacement,
		UseRegex:    useRegex,
		Description: description,
		Enabled:     true,
	}

	if useRegex {
		re, err := regexp.Compile(pattern)
		if err != nil {
			// Same wording as the JSON loader so both paths report alike.
			return fmt.Errorf("编译正则表达式'%s'失败: %w", pattern, err)
		}
		rule.regex = re
	}

	r.rules = append(r.rules, rule)
	return nil
}
|
|
|
|
// Rewrite rewrites req.URL.Path in place using the first enabled rule that
// matches the incoming path. Later rules are not consulted.
//
//   - Regex rules: every match of the pattern in the path is replaced via
//     regexp.Regexp.ReplaceAllString, so Replacement may use $1 / ${name}.
//   - Prefix rules: if the path starts with Pattern, that prefix is replaced
//     by Replacement.
//
// Only URL.Path is modified. A nil request or URL, nil rule entries, and
// regex rules without a compiled expression are skipped instead of panicking.
func (r *Rewriter) Rewrite(req *http.Request) {
	if req == nil || req.URL == nil {
		return
	}
	path := req.URL.Path

	for _, rule := range r.rules {
		if rule == nil || !rule.Enabled {
			continue
		}

		if rule.UseRegex {
			// rule.regex can be nil when UseRegex was switched on after the
			// rule was added (e.g. through a pointer returned by GetRules);
			// such a rule cannot match.
			if rule.regex != nil && rule.regex.MatchString(path) {
				req.URL.Path = rule.regex.ReplaceAllString(path, rule.Replacement)
				return
			}
			continue
		}

		if strings.HasPrefix(path, rule.Pattern) {
			// The prefix check guarantees the match is at index 0, so this is
			// equivalent to replacing the first occurrence of Pattern.
			req.URL.Path = rule.Replacement + path[len(rule.Pattern):]
			return
		}
	}
}
|
|
|
|
// RewriteResponse rewrites the Location header of resp using the first
// enabled regex rule whose compiled pattern matches it. Prefix (non-regex)
// rules are not applied to Location, and the response body is not modified.
//
// NOTE(review): originHost is currently unused. The previous comment said the
// backend host should be replaced with the proxy host and that URLs in the
// body are handled; neither is implemented here. Confirm the intended
// contract with callers before relying on it.
func (r *Rewriter) RewriteResponse(resp *http.Response, originHost string) {
	if resp == nil {
		return
	}

	location := resp.Header.Get("Location")
	if location == "" {
		return
	}

	for _, rule := range r.rules {
		if rule == nil || !rule.Enabled || !rule.UseRegex || rule.regex == nil {
			continue
		}
		if rule.regex.MatchString(location) {
			resp.Header.Set("Location", rule.regex.ReplaceAllString(location, rule.Replacement))
			return
		}
	}
}
|
|
|
|
// LoadRulesFromFile appends the rules stored in filename to the rewriter.
//
// A name ending in ".json" (case-sensitive, the same test SaveRulesToFile
// uses) is parsed as a JSON array of RewriteRule. Any other name is parsed
// with the line-based text format.
//
// The os.Open error is wrapped with %w, so callers can use
// errors.Is(err, fs.ErrNotExist) and similar checks.
func (r *Rewriter) LoadRulesFromFile(filename string) error {
	file, err := os.Open(filename)
	if err != nil {
		return fmt.Errorf("打开文件失败: %w", err)
	}
	// The file is only read, so a Close error cannot lose data.
	defer file.Close()

	if strings.HasSuffix(filename, ".json") {
		return r.loadRulesFromJSON(file)
	}
	return r.loadRulesFromText(file)
}
|
|
|
|
// loadRulesFromJSON decodes a JSON array of RewriteRule from file, compiles
// the patterns of regex rules, and appends the rules to r.rules.
//
// The load is all-or-nothing. If decoding fails, an element is null, or a
// regex does not compile, an error is returned and r.rules is left
// unchanged. A rule object without an "enabled" key decodes as disabled.
func (r *Rewriter) loadRulesFromJSON(file *os.File) error {
	var rules []*RewriteRule
	if err := json.NewDecoder(file).Decode(&rules); err != nil {
		return fmt.Errorf("解析JSON失败: %w", err)
	}

	// Validate and compile everything before touching r.rules, so a bad
	// entry part-way through does not leave a partially loaded rule set.
	for i, rule := range rules {
		if rule == nil {
			// A JSON null element decodes to a nil pointer, which would panic
			// here and later in Rewrite.
			return fmt.Errorf("第%d条规则为空", i+1)
		}
		if !rule.UseRegex {
			continue
		}
		re, err := regexp.Compile(rule.Pattern)
		if err != nil {
			return fmt.Errorf("编译正则表达式'%s'失败: %w", rule.Pattern, err)
		}
		rule.regex = re
	}

	r.rules = append(r.rules, rules...)
	return nil
}
|
|
|
|
// loadRulesFromText parses the line-based rule format from file and appends
// the rules to r.rules. Every loaded rule is enabled.
//
// Blank lines and lines starting with '#' are skipped. This includes the
// "# ... (disabled)" lines that SaveRulesToFile writes for disabled rules.
// Every other line has the form
//
//	pattern replacement [regex] [#description]
//
// Fields are whitespace-separated, so pattern and replacement cannot contain
// spaces. The literal token "regex" marks a regular-expression rule. The first
// token starting with '#' begins the description, which runs to the end of
// the line, with internal whitespace collapsed to single spaces. Any other
// extra tokens are ignored.
//
// The load is all-or-nothing. On a malformed line, an invalid regex, or a
// read error, an error is returned and r.rules is left unchanged.
func (r *Rewriter) loadRulesFromText(file *os.File) error {
	// Stage rules in a scratch Rewriter. This reuses AddRuleWithDescription's
	// validation while r.rules is only extended once every line has parsed.
	staged := &Rewriter{}

	scanner := bufio.NewScanner(file)
	for lineNum := 1; scanner.Scan(); lineNum++ {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}

		parts := strings.Fields(line)
		if len(parts) < 2 {
			return fmt.Errorf("第%d行格式错误: %s", lineNum, line)
		}

		pattern, replacement := parts[0], parts[1]
		useRegex := false
		description := ""
		for i := 2; i < len(parts); i++ {
			if parts[i] == "regex" {
				useRegex = true
				continue
			}
			if strings.HasPrefix(parts[i], "#") {
				description = strings.TrimSpace(strings.TrimPrefix(strings.Join(parts[i:], " "), "#"))
				break
			}
		}

		if err := staged.AddRuleWithDescription(pattern, replacement, useRegex, description); err != nil {
			return fmt.Errorf("第%d行添加规则失败: %w", lineNum, err)
		}
	}
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("读取文件失败: %w", err)
	}

	r.rules = append(r.rules, staged.rules...)
	return nil
}
|
|
|
|
// GetRules 获取所有规则
|
|
func (r *Rewriter) GetRules() []*RewriteRule {
|
|
return r.rules
|
|
}
|
|
|
|
// EnableRule switches on the rule at position index (0-based, in evaluation
// order). It returns an error if index is out of range.
func (r *Rewriter) EnableRule(index int) error {
	if inRange := index >= 0 && index < len(r.rules); !inRange {
		return fmt.Errorf("规则索引越界: %d", index)
	}
	target := r.rules[index]
	target.Enabled = true
	return nil
}
|
|
|
|
// DisableRule switches off the rule at position index (0-based, in
// evaluation order) without removing it. It returns an error if index is out
// of range.
func (r *Rewriter) DisableRule(index int) error {
	if index >= 0 && index < len(r.rules) {
		r.rules[index].Enabled = false
		return nil
	}
	return fmt.Errorf("规则索引越界: %d", index)
}
|
|
|
|
// RemoveRule deletes the rule at position index (0-based), shifting every
// later rule down by one position. It returns an error if index is out of
// range. The shift happens in the existing backing array.
func (r *Rewriter) RemoveRule(index int) error {
	count := len(r.rules)
	if index < 0 || index >= count {
		return fmt.Errorf("规则索引越界: %d", index)
	}
	copy(r.rules[index:], r.rules[index+1:])
	r.rules = r.rules[:count-1]
	return nil
}
|
|
|
|
// SaveRulesToFile writes all rules to filename, creating or truncating it.
//
// A name ending in ".json" produces an indented JSON array of RewriteRule.
// Any other name produces the line-based text format:
//
//	pattern replacement [regex] [# description]
//
// In the text format a disabled rule is written as a comment line
// ("# ... (disabled)"). The text loader skips comment lines, so disabled
// rules are dropped when such a file is loaded again.
//
// The error from closing the file is reported as well, because a write can
// fail to reach the disk only at Close time.
func (r *Rewriter) SaveRulesToFile(filename string) (err error) {
	file, err := os.Create(filename)
	if err != nil {
		return fmt.Errorf("创建文件失败: %w", err)
	}
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("关闭文件失败: %w", cerr)
		}
	}()

	if strings.HasSuffix(filename, ".json") {
		encoder := json.NewEncoder(file)
		encoder.SetIndent("", " ")
		return encoder.Encode(r.rules)
	}

	writer := bufio.NewWriter(file)
	for _, rule := range r.rules {
		line := rule.Pattern + " " + rule.Replacement
		if rule.UseRegex {
			line += " regex"
		}
		if rule.Description != "" {
			line += " # " + rule.Description
		}
		if !rule.Enabled {
			line = "# " + line + " (disabled)"
		}
		if _, werr := writer.WriteString(line + "\n"); werr != nil {
			return fmt.Errorf("写入文件失败: %w", werr)
		}
	}
	if ferr := writer.Flush(); ferr != nil {
		return fmt.Errorf("写入文件失败: %w", ferr)
	}
	return nil
}
|