package dns import ( "errors" "net" "sort" "strings" "sync" "time" ) // Resolver DNS解析器接口 type Resolver interface { // Resolve 将域名解析为IP地址 Resolve(host string) (string, error) // ResolveWithPort 将域名解析为IP地址和端口 ResolveWithPort(host string, defaultPort int) (*Endpoint, error) // Add 添加域名解析规则 Add(host, ip string) error // AddWithPort 添加带端口的域名解析规则 AddWithPort(host, ip string, port int) error // AddWildcard 添加泛解析规则(通配符域名) AddWildcard(wildcardDomain, ip string) error // AddWildcardWithPort 添加带端口的泛解析规则 AddWildcardWithPort(wildcardDomain, ip string, port int) error // Remove 删除域名解析规则 Remove(host string) error // Clear 清除所有解析规则 Clear() } // CustomResolver 自定义DNS解析器 type CustomResolver struct { mu sync.RWMutex records map[string]*Endpoint // 精确域名到端点的映射 wildcardRules []wildcardRule // 通配符规则列表 cache map[string]cacheEntry // 外部域名解析缓存 fallback bool // 是否在本地记录找不到时回退到系统DNS ttl time.Duration // 缓存TTL } // wildcardRule 通配符规则 type wildcardRule struct { pattern string // 原始通配符模式,如 *.example.com parts []string // 分解后的模式部分,如 ["*", "example", "com"] endpoint *Endpoint // 对应的端点 } // cacheEntry 缓存条目 type cacheEntry struct { endpoint *Endpoint expiresAt time.Time } // NewResolver 创建新的自定义DNS解析器 func NewResolver(options ...Option) *CustomResolver { r := &CustomResolver{ records: make(map[string]*Endpoint), wildcardRules: make([]wildcardRule, 0), cache: make(map[string]cacheEntry), fallback: true, ttl: 5 * time.Minute, } // 应用选项 for _, option := range options { option(r) } return r } // Resolve 将域名解析为IP地址 func (r *CustomResolver) Resolve(host string) (string, error) { endpoint, err := r.ResolveWithPort(host, 0) if err != nil { return "", err } return endpoint.IP, nil } // ResolveWithPort 将域名解析为IP地址和端口 func (r *CustomResolver) ResolveWithPort(host string, defaultPort int) (*Endpoint, error) { // 首先检查自定义记录 r.mu.RLock() // 精确匹配 if endpoint, ok := r.records[host]; ok { r.mu.RUnlock() return endpoint, nil } // 尝试通配符匹配 if endpoint := r.matchWildcard(host); endpoint != nil { r.mu.RUnlock() return endpoint, nil } // 检查缓存 if entry, ok := r.cache[host]; ok { if time.Now().Before(entry.expiresAt) { r.mu.RUnlock() return entry.endpoint, nil } } r.mu.RUnlock() // 如果启用回退,则使用系统DNS if r.fallback { ips, err := net.LookupIP(host) if err != nil { return nil, err } // 使用第一个IPv4地址 var ip string for _, addr := range ips { if ipv4 := addr.To4(); ipv4 != nil { ip = ipv4.String() break } } if ip == "" { return nil, errors.New("未找到IPv4地址") } // 创建端点 endpoint := NewEndpoint(ip) if defaultPort > 0 { endpoint.Port = defaultPort } // 更新缓存 r.mu.Lock() r.cache[host] = cacheEntry{ endpoint: endpoint, expiresAt: time.Now().Add(r.ttl), } r.mu.Unlock() return endpoint, nil } return nil, errors.New("未找到域名记录且系统DNS回退被禁用") } // matchWildcard 尝试匹配通配符规则 func (r *CustomResolver) matchWildcard(host string) *Endpoint { hostParts := strings.Split(host, ".") // 按照通配符规则列表的顺序尝试匹配 // 规则顺序应该保证更具体的规则先匹配 for _, rule := range r.wildcardRules { if matchDomainPattern(hostParts, rule.parts) { return rule.endpoint } } return nil } // matchDomainPattern 判断域名部分是否匹配通配符模式 func matchDomainPattern(hostParts, patternParts []string) bool { // 如果长度不匹配,则不匹配 if len(hostParts) != len(patternParts) { return false } // 逐部分匹配 for i := 0; i < len(hostParts); i++ { // 如果模式部分是星号,则匹配任何内容 if patternParts[i] == "*" { continue } // 否则必须精确匹配 if hostParts[i] != patternParts[i] { return false } } return true } // Add 添加域名解析规则 func (r *CustomResolver) Add(host, ip string) error { return r.AddWithPort(host, ip, 0) } // AddWithPort 添加带端口的域名解析规则 func (r *CustomResolver) AddWithPort(host, ip string, port int) error { if net.ParseIP(ip) == nil { return errors.New("无效的IP地址") } r.mu.Lock() defer r.mu.Unlock() r.records[host] = NewEndpointWithPort(ip, port) return nil } // AddWildcard 添加泛解析规则 func (r *CustomResolver) AddWildcard(wildcardDomain, ip string) error { return r.AddWildcardWithPort(wildcardDomain, ip, 0) } // AddWildcardWithPort 添加带端口的泛解析规则 func (r *CustomResolver) AddWildcardWithPort(wildcardDomain, ip string, port int) error { if net.ParseIP(ip) == nil { return errors.New("无效的IP地址") } // 检查通配符格式 if !strings.Contains(wildcardDomain, "*") { return errors.New("泛解析域名必须包含通配符'*'") } // 分解通配符域名 parts := strings.Split(wildcardDomain, ".") r.mu.Lock() defer r.mu.Unlock() // 创建新的通配符规则 rule := wildcardRule{ pattern: wildcardDomain, parts: parts, endpoint: NewEndpointWithPort(ip, port), } // 将新规则添加到规则列表头部,确保更新的规则优先匹配 r.wildcardRules = append([]wildcardRule{rule}, r.wildcardRules...) return nil } // Remove 删除域名解析规则 func (r *CustomResolver) Remove(host string) error { r.mu.Lock() defer r.mu.Unlock() // 先尝试删除精确匹配记录 if _, ok := r.records[host]; ok { delete(r.records, host) return nil } // 然后尝试删除通配符记录 for i, rule := range r.wildcardRules { if rule.pattern == host { // 删除这条规则 r.wildcardRules = append(r.wildcardRules[:i], r.wildcardRules[i+1:]...) return nil } } return errors.New("域名记录不存在") } // Clear 清除所有解析规则 func (r *CustomResolver) Clear() { r.mu.Lock() defer r.mu.Unlock() r.records = make(map[string]*Endpoint) r.wildcardRules = make([]wildcardRule, 0) r.cache = make(map[string]cacheEntry) } // Option 解析器选项函数类型 type Option func(*CustomResolver) // WithFallback 设置是否回退到系统DNS func WithFallback(fallback bool) Option { return func(r *CustomResolver) { r.fallback = fallback } } // WithTTL 设置缓存TTL func WithTTL(ttl time.Duration) Option { return func(r *CustomResolver) { r.ttl = ttl } } // LoadFromMap 从映射加载DNS记录 func (r *CustomResolver) LoadFromMap(records map[string]string) error { r.mu.Lock() defer r.mu.Unlock() for host, value := range records { // 判断是否为通配符域名 if strings.Contains(host, "*") { endpoint, err := ParseEndpoint(value) if err != nil { return err } if net.ParseIP(endpoint.IP) == nil { return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") } // 添加通配符规则 rule := wildcardRule{ pattern: host, parts: strings.Split(host, "."), endpoint: endpoint, } r.wildcardRules = append(r.wildcardRules, rule) } else { // 常规记录 endpoint, err := ParseEndpoint(value) if err != nil { return err } if net.ParseIP(endpoint.IP) == nil { return errors.New("无效的IP地址: " + endpoint.IP + " (域名: " + host + ")") } r.records[host] = endpoint } } // 对通配符规则进行排序,确保更具体的规则先匹配 sortWildcardRules(r.wildcardRules) return nil } // sortWildcardRules 对通配符规则进行排序,使更具体的规则优先匹配 func sortWildcardRules(rules []wildcardRule) { // 使用稳定排序,保证相同优先级的规则保持原有顺序(后添加的规则在前面) sort.SliceStable(rules, func(i, j int) bool { ruleI := rules[i] ruleJ := rules[j] // 计算每个规则中通配符的数量 wildcardCountI := countWildcards(ruleI.parts) wildcardCountJ := countWildcards(ruleJ.parts) // 通配符数量少的规则更具体,优先级更高 if wildcardCountI != wildcardCountJ { return wildcardCountI < wildcardCountJ } // 如果通配符数量相同,域名部分数量多的更具体 return len(ruleI.parts) > len(ruleJ.parts) }) } // countWildcards 计算域名部分中通配符的数量 func countWildcards(parts []string) int { count := 0 for _, part := range parts { if part == "*" { count++ } } return count }