373 lines
8.7 KiB
Go
373 lines
8.7 KiB
Go
package dns
|
||
|
||
import (
|
||
"errors"
|
||
"net"
|
||
"sort"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// Resolver DNS解析器接口
|
||
type Resolver interface {
|
||
// Resolve 将域名解析为IP地址
|
||
Resolve(host string) (string, error)
|
||
|
||
// ResolveWithPort 将域名解析为IP地址和端口
|
||
ResolveWithPort(host string, defaultPort int) (*Endpoint, error)
|
||
|
||
// Add 添加域名解析规则
|
||
Add(host, ip string) error
|
||
|
||
// AddWithPort 添加带端口的域名解析规则
|
||
AddWithPort(host, ip string, port int) error
|
||
|
||
// AddWildcard 添加泛解析规则(通配符域名)
|
||
AddWildcard(wildcardDomain, ip string) error
|
||
|
||
// AddWildcardWithPort 添加带端口的泛解析规则
|
||
AddWildcardWithPort(wildcardDomain, ip string, port int) error
|
||
|
||
// Remove 删除域名解析规则
|
||
Remove(host string) error
|
||
|
||
// Clear 清除所有解析规则
|
||
Clear()
|
||
}
|
||
|
||
// CustomResolver 自定义DNS解析器
|
||
type CustomResolver struct {
|
||
mu sync.RWMutex
|
||
records map[string]*Endpoint // 精确域名到端点的映射
|
||
wildcardRules []wildcardRule // 通配符规则列表
|
||
cache map[string]cacheEntry // 外部域名解析缓存
|
||
fallback bool // 是否在本地记录找不到时回退到系统DNS
|
||
ttl time.Duration // 缓存TTL
|
||
}
|
||
|
||
// wildcardRule 通配符规则
|
||
type wildcardRule struct {
|
||
pattern string // 原始通配符模式,如 *.example.com
|
||
parts []string // 分解后的模式部分,如 ["*", "example", "com"]
|
||
endpoint *Endpoint // 对应的端点
|
||
}
|
||
|
||
// cacheEntry 缓存条目
|
||
type cacheEntry struct {
|
||
endpoint *Endpoint
|
||
expiresAt time.Time
|
||
}
|
||
|
||
// NewResolver 创建新的自定义DNS解析器
|
||
func NewResolver(options ...Option) *CustomResolver {
|
||
r := &CustomResolver{
|
||
records: make(map[string]*Endpoint),
|
||
wildcardRules: make([]wildcardRule, 0),
|
||
cache: make(map[string]cacheEntry),
|
||
fallback: true,
|
||
ttl: 5 * time.Minute,
|
||
}
|
||
|
||
// 应用选项
|
||
for _, option := range options {
|
||
option(r)
|
||
}
|
||
|
||
return r
|
||
}
|
||
|
||
// Resolve 将域名解析为IP地址
|
||
func (r *CustomResolver) Resolve(host string) (string, error) {
|
||
endpoint, err := r.ResolveWithPort(host, 0)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return endpoint.IP, nil
|
||
}
|
||
|
||
// ResolveWithPort 将域名解析为IP地址和端口
|
||
func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) {
|
||
// 首先检查自定义记录
|
||
r.mu.RLock()
|
||
|
||
// 精确匹配
|
||
if endpoint, ok := r.records[host]; ok {
|
||
r.mu.RUnlock()
|
||
return endpoint, nil
|
||
}
|
||
|
||
// 尝试通配符匹配
|
||
if endpoint := r.matchWildcard(host); endpoint != nil {
|
||
r.mu.RUnlock()
|
||
return endpoint, nil
|
||
}
|
||
|
||
// 检查缓存
|
||
if entry, ok := r.cache[host]; ok {
|
||
if time.Now().Before(entry.expiresAt) {
|
||
r.mu.RUnlock()
|
||
return entry.endpoint, nil
|
||
}
|
||
}
|
||
r.mu.RUnlock()
|
||
|
||
// 如果启用回退,则使用系统DNS
|
||
if r.fallback {
|
||
ips, err := net.LookupIP(host)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 使用第一个IPv4地址
|
||
var ip string
|
||
for _, addr := range ips {
|
||
if ipv4 := addr.To4(); ipv4 != nil {
|
||
ip = ipv4.String()
|
||
break
|
||
}
|
||
}
|
||
|
||
if ip == "" {
|
||
return nil, errors.New("未找到IPv4地址")
|
||
}
|
||
|
||
// 创建端点
|
||
endpoint := NewEndpoint(ip)
|
||
if defaultPort > 0 {
|
||
endpoint.Port = defaultPort
|
||
}
|
||
|
||
// 更新缓存
|
||
r.mu.Lock()
|
||
r.cache[host] = cacheEntry{
|
||
endpoint: endpoint,
|
||
expiresAt: time.Now().Add(r.ttl),
|
||
}
|
||
r.mu.Unlock()
|
||
|
||
return endpoint, nil
|
||
}
|
||
|
||
return nil, errors.New("未找到域名记录且系统DNS回退被禁用")
|
||
}
|
||
|
||
// matchWildcard 尝试匹配通配符规则
|
||
func (r *CustomResolver) matchWildcard(host string) *Endpoint {
|
||
hostParts := strings.Split(host, ".")
|
||
|
||
// 按照通配符规则列表的顺序尝试匹配
|
||
// 规则顺序应该保证更具体的规则先匹配
|
||
for _, rule := range r.wildcardRules {
|
||
if matchDomainPattern(hostParts, rule.parts) {
|
||
return rule.endpoint
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// matchDomainPattern 判断域名部分是否匹配通配符模式
|
||
func matchDomainPattern(hostParts, patternParts []string) bool {
|
||
// 如果长度不匹配,则不匹配
|
||
if len(hostParts) != len(patternParts) {
|
||
return false
|
||
}
|
||
|
||
// 逐部分匹配
|
||
for i := 0; i < len(hostParts); i++ {
|
||
// 如果模式部分是星号,则匹配任何内容
|
||
if patternParts[i] == "*" {
|
||
continue
|
||
}
|
||
|
||
// 否则必须精确匹配
|
||
if hostParts[i] != patternParts[i] {
|
||
return false
|
||
}
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// Add 添加域名解析规则
|
||
func (r *CustomResolver) Add(host, ip string) error {
|
||
return r.AddWithPort(host, ip, 0)
|
||
}
|
||
|
||
// AddWithPort 添加带端口的域名解析规则
|
||
func (r *CustomResolver) AddWithPort(host, ip string, port int) error {
|
||
if net.ParseIP(ip) == nil {
|
||
return errors.New("无效的IP地址")
|
||
}
|
||
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
r.records[host] = NewEndpointWithPort(ip, port)
|
||
return nil
|
||
}
|
||
|
||
// AddWildcard 添加泛解析规则
|
||
func (r *CustomResolver) AddWildcard(wildcardDomain, ip string) error {
|
||
return r.AddWildcardWithPort(wildcardDomain, ip, 0)
|
||
}
|
||
|
||
// AddWildcardWithPort 添加带端口的泛解析规则
|
||
func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error {
|
||
if net.ParseIP(ip) == nil {
|
||
return errors.New("无效的IP地址")
|
||
}
|
||
|
||
// 检查通配符格式
|
||
if !strings.Contains(wildcardDomain, "*") {
|
||
return errors.New("泛解析域名必须包含通配符'*'")
|
||
}
|
||
|
||
// 分解通配符域名
|
||
parts := strings.Split(wildcardDomain, ".")
|
||
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
// 创建新的通配符规则
|
||
rule := wildcardRule{
|
||
pattern: wildcardDomain,
|
||
parts: parts,
|
||
endpoint: NewEndpointWithPort(ip, port),
|
||
}
|
||
|
||
// 将新规则添加到规则列表头部,确保更新的规则优先匹配
|
||
r.wildcardRules = append([]wildcardRule{rule}, r.wildcardRules...)
|
||
|
||
return nil
|
||
}
|
||
|
||
// Remove 删除域名解析规则
|
||
func (r *CustomResolver) Remove(host string) error {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
// 先尝试删除精确匹配记录
|
||
if _, ok := r.records[host]; ok {
|
||
delete(r.records, host)
|
||
return nil
|
||
}
|
||
|
||
// 然后尝试删除通配符记录
|
||
for i, rule := range r.wildcardRules {
|
||
if rule.pattern == host {
|
||
// 删除这条规则
|
||
r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
return errors.New("域名记录不存在")
|
||
}
|
||
|
||
// Clear 清除所有解析规则
|
||
func (r *CustomResolver) Clear() {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
r.records = make(map[string]*Endpoint)
|
||
r.wildcardRules = make([]wildcardRule, 0)
|
||
r.cache = make(map[string]cacheEntry)
|
||
}
|
||
|
||
// Option 解析器选项函数类型
|
||
type Option func(*CustomResolver)
|
||
|
||
// WithFallback 设置是否回退到系统DNS
|
||
func WithFallback(fallback bool) Option {
|
||
return func(r *CustomResolver) {
|
||
r.fallback = fallback
|
||
}
|
||
}
|
||
|
||
// WithTTL 设置缓存TTL
|
||
func WithTTL(ttl time.Duration) Option {
|
||
return func(r *CustomResolver) {
|
||
r.ttl = ttl
|
||
}
|
||
}
|
||
|
||
// LoadFromMap 从映射加载DNS记录
|
||
func (r *CustomResolver) LoadFromMap(records map[string]string) error {
|
||
r.mu.Lock()
|
||
defer r.mu.Unlock()
|
||
|
||
for host, value := range records {
|
||
// 判断是否为通配符域名
|
||
if strings.Contains(host, "*") {
|
||
endpoint, err := ParseEndpoint(value)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if net.ParseIP(endpoint.IP) == nil {
|
||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||
}
|
||
|
||
// 添加通配符规则
|
||
rule := wildcardRule{
|
||
pattern: host,
|
||
parts: strings.Split(host, "."),
|
||
endpoint: endpoint,
|
||
}
|
||
r.wildcardRules = append(r.wildcardRules, rule)
|
||
|
||
} else {
|
||
// 常规记录
|
||
endpoint, err := ParseEndpoint(value)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if net.ParseIP(endpoint.IP) == nil {
|
||
return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")")
|
||
}
|
||
|
||
r.records[host] = endpoint
|
||
}
|
||
}
|
||
|
||
// 对通配符规则进行排序,确保更具体的规则先匹配
|
||
sortWildcardRules(r.wildcardRules)
|
||
|
||
return nil
|
||
}
|
||
|
||
// sortWildcardRules 对通配符规则进行排序,使更具体的规则优先匹配
|
||
func sortWildcardRules(rules []wildcardRule) {
|
||
// 使用稳定排序,保证相同优先级的规则保持原有顺序(后添加的规则在前面)
|
||
sort.SliceStable(rules, func(i, j int) bool {
|
||
ruleI := rules[i]
|
||
ruleJ := rules[j]
|
||
|
||
// 计算每个规则中通配符的数量
|
||
wildcardCountI := countWildcards(ruleI.parts)
|
||
wildcardCountJ := countWildcards(ruleJ.parts)
|
||
|
||
// 通配符数量少的规则更具体,优先级更高
|
||
if wildcardCountI != wildcardCountJ {
|
||
return wildcardCountI < wildcardCountJ
|
||
}
|
||
|
||
// 如果通配符数量相同,域名部分数量多的更具体
|
||
return len(ruleI.parts) > len(ruleJ.parts)
|
||
})
|
||
}
|
||
|
||
// countWildcards 计算域名部分中通配符的数量
|
||
func countWildcards(parts []string) int {
|
||
count := 0
|
||
for _, part := range parts {
|
||
if part == "*" {
|
||
count++
|
||
}
|
||
}
|
||
return count
|
||
}
|