Files
goproxy/internal/dns/resolver.go
2025-03-13 17:53:08 +08:00

373 lines
8.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package dns
import (
"errors"
"net"
"sort"
"strings"
"sync"
"time"
)
// Resolver DNS解析器接口
type Resolver interface {
// Resolve 将域名解析为IP地址
Resolve(host string) (string, error)
// ResolveWithPort 将域名解析为IP地址和端口
ResolveWithPort(host string, defaultPort int) (*Endpoint, error)
// Add 添加域名解析规则
Add(host, ip string) error
// AddWithPort 添加带端口的域名解析规则
AddWithPort(host, ip string, port int) error
// AddWildcard 添加泛解析规则(通配符域名)
AddWildcard(wildcardDomain, ip string) error
// AddWildcardWithPort 添加带端口的泛解析规则
AddWildcardWithPort(wildcardDomain, ip string, port int) error
// Remove 删除域名解析规则
Remove(host string) error
// Clear 清除所有解析规则
Clear()
}
// CustomResolver 自定义DNS解析器
type CustomResolver struct {
mu sync.RWMutex
records map[string]*Endpoint // 精确域名到端点的映射
wildcardRules []wildcardRule // 通配符规则列表
cache map[string]cacheEntry // 外部域名解析缓存
fallback bool // 是否在本地记录找不到时回退到系统DNS
ttl time.Duration // 缓存TTL
}
// wildcardRule 通配符规则
type wildcardRule struct {
pattern string // 原始通配符模式,如 *.example.com
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
endpoint *Endpoint // 对应的端点
}
// cacheEntry 缓存条目
type cacheEntry struct {
endpoint *Endpoint
expiresAt time.Time
}
// NewResolver 创建新的自定义DNS解析器
func NewResolver(options ...Option) *CustomResolver {
r := &CustomResolver{
records: make(map[string]*Endpoint),
wildcardRules: make([]wildcardRule, 0),
cache: make(map[string]cacheEntry),
fallback: true,
ttl: 5 * time.Minute,
}
// 应用选项
for _, option := range options {
option(r)
}
return r
}
// Resolve 将域名解析为IP地址
func (r *CustomResolver) Resolve(host string) (string, error) {
endpoint, err := r.ResolveWithPort(host, 0)
if err != nil {
return "", err
}
return endpoint.IP, nil
}
// ResolveWithPort 将域名解析为IP地址和端口
func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) {
// 首先检查自定义记录
r.mu.RLock()
// 精确匹配
if endpoint, ok := r.records[host]; ok {
r.mu.RUnlock()
return endpoint, nil
}
// 尝试通配符匹配
if endpoint := r.matchWildcard(host); endpoint != nil {
r.mu.RUnlock()
return endpoint, nil
}
// 检查缓存
if entry, ok := r.cache[host]; ok {
if time.Now().Before(entry.expiresAt) {
r.mu.RUnlock()
return entry.endpoint, nil
}
}
r.mu.RUnlock()
// 如果启用回退则使用系统DNS
if r.fallback {
ips, err := net.LookupIP(host)
if err != nil {
return nil, err
}
// 使用第一个IPv4地址
var ip string
for _, addr := range ips {
if ipv4 := addr.To4(); ipv4 != nil {
ip = ipv4.String()
break
}
}
if ip == "" {
return nil, errors.New("未找到IPv4地址")
}
// 创建端点
endpoint := NewEndpoint(ip)
if defaultPort > 0 {
endpoint.Port = defaultPort
}
// 更新缓存
r.mu.Lock()
r.cache[host] = cacheEntry{
endpoint: endpoint,
expiresAt: time.Now().Add(r.ttl),
}
r.mu.Unlock()
return endpoint, nil
}
return nil, errors.New("未找到域名记录且系统DNS回退被禁用")
}
// matchWildcard 尝试匹配通配符规则
func (r *CustomResolver) matchWildcard(host string) *Endpoint {
hostParts := strings.Split(host, ".")
// 按照通配符规则列表的顺序尝试匹配
// 规则顺序应该保证更具体的规则先匹配
for _, rule := range r.wildcardRules {
if matchDomainPattern(hostParts, rule.parts) {
return rule.endpoint
}
}
return nil
}
// matchDomainPattern 判断域名部分是否匹配通配符模式
func matchDomainPattern(hostParts, patternParts []string) bool {
// 如果长度不匹配,则不匹配
if len(hostParts) != len(patternParts) {
return false
}
// 逐部分匹配
for i := 0; i < len(hostParts); i++ {
// 如果模式部分是星号,则匹配任何内容
if patternParts[i] == "*" {
continue
}
// 否则必须精确匹配
if hostParts[i] != patternParts[i] {
return false
}
}
return true
}
// Add 添加域名解析规则
func (r *CustomResolver) Add(host, ip string) error {
return r.AddWithPort(host, ip, 0)
}
// AddWithPort 添加带端口的域名解析规则
func (r *CustomResolver) AddWithPort(host, ip string, port int) error {
if net.ParseIP(ip) == nil {
return errors.New("无效的IP地址")
}
r.mu.Lock()
defer r.mu.Unlock()
r.records[host] = NewEndpointWithPort(ip, port)
return nil
}
// AddWildcard 添加泛解析规则
func (r *CustomResolver) AddWildcard(wildcardDomain, ip string) error {
return r.AddWildcardWithPort(wildcardDomain, ip, 0)
}
// AddWildcardWithPort 添加带端口的泛解析规则
func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
if net.ParseIP(ip) == nil {
return errors.New("无效的IP地址")
}
// 检查通配符格式
if !strings.Contains(wildcardDomain, "*") {
return errors.New("泛解析域名必须包含通配符'*'")
}
// 分解通配符域名
parts := strings.Split(wildcardDomain, ".")
r.mu.Lock()
defer r.mu.Unlock()
// 创建新的通配符规则
rule := wildcardRule{
pattern: wildcardDomain,
parts: parts,
endpoint: NewEndpointWithPort(ip, port),
}
// 将新规则添加到规则列表头部,确保更新的规则优先匹配
r.wildcardRules = append([]wildcardRule{rule}, r.wildcardRules...)
return nil
}
// Remove 删除域名解析规则
func (r *CustomResolver) Remove(host string) error {
r.mu.Lock()
defer r.mu.Unlock()
// 先尝试删除精确匹配记录
if _, ok := r.records[host]; ok {
delete(r.records, host)
return nil
}
// 然后尝试删除通配符记录
for i, rule := range r.wildcardRules {
if rule.pattern == host {
// 删除这条规则
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
return nil
}
}
return errors.New("域名记录不存在")
}
// Clear 清除所有解析规则
func (r *CustomResolver) Clear() {
r.mu.Lock()
defer r.mu.Unlock()
r.records = make(map[string]*Endpoint)
r.wildcardRules = make([]wildcardRule, 0)
r.cache = make(map[string]cacheEntry)
}
// Option 解析器选项函数类型
type Option func(*CustomResolver)
// WithFallback 设置是否回退到系统DNS
func WithFallback(fallback bool) Option {
return func(r *CustomResolver) {
r.fallback = fallback
}
}
// WithTTL 设置缓存TTL
func WithTTL(ttl time.Duration) Option {
return func(r *CustomResolver) {
r.ttl = ttl
}
}
// LoadFromMap 从映射加载DNS记录
func (r *CustomResolver) LoadFromMap(records map[string]string) error {
r.mu.Lock()
defer r.mu.Unlock()
for host, value := range records {
// 判断是否为通配符域名
if strings.Contains(host, "*") {
endpoint, err := ParseEndpoint(value)
if err != nil {
return err
}
if net.ParseIP(endpoint.IP) == nil {
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
}
// 添加通配符规则
rule := wildcardRule{
pattern: host,
parts: strings.Split(host, "."),
endpoint: endpoint,
}
r.wildcardRules = append(r.wildcardRules, rule)
} else {
// 常规记录
endpoint, err := ParseEndpoint(value)
if err != nil {
return err
}
if net.ParseIP(endpoint.IP) == nil {
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
}
r.records[host] = endpoint
}
}
// 对通配符规则进行排序,确保更具体的规则先匹配
sortWildcardRules(r.wildcardRules)
return nil
}
// sortWildcardRules 对通配符规则进行排序,使更具体的规则优先匹配
func sortWildcardRules(rules []wildcardRule) {
// 使用稳定排序,保证相同优先级的规则保持原有顺序(后添加的规则在前面)
sort.SliceStable(rules, func(i, j int) bool {
ruleI := rules[i]
ruleJ := rules[j]
// 计算每个规则中通配符的数量
wildcardCountI := countWildcards(ruleI.parts)
wildcardCountJ := countWildcards(ruleJ.parts)
// 通配符数量少的规则更具体,优先级更高
if wildcardCountI != wildcardCountJ {
return wildcardCountI < wildcardCountJ
}
// 如果通配符数量相同,域名部分数量多的更具体
return len(ruleI.parts) > len(ruleJ.parts)
})
}
// countWildcards 计算域名部分中通配符的数量
func countWildcards(parts []string) int {
count := 0
for _, part := range parts {
if part == "*" {
count++
}
}
return count
}