package dns import ( "bufio" "encoding/json" "fmt" "os" "regexp" "strings" "time" ) // DNSConfig DNS配置文件结构 type DNSConfig struct { Records map[string]string `json:"records"` // 普通记录和泛解析记录 Fallback bool `json:"fallback"` // 是否回退到系统DNS TTL int `json:"ttl"` // 缓存TTL,单位为秒 } // LoadFromJSON 从JSON文件加载DNS配置 func LoadFromJSON(filePath string) (*DNSConfig, error) { file, err := os.Open(filePath) if err != nil { return nil, fmt.Errorf("打开DNS配置文件失败: %w", err) } defer file.Close() config := &DNSConfig{ Records: make(map[string]string), Fallback: true, TTL: 300, // 默认5分钟 } decoder := json.NewDecoder(file) if err := decoder.Decode(config); err != nil { return nil, fmt.Errorf("解析DNS配置文件失败: %w", err) } return config, nil } // SaveToJSON 将DNS配置保存为JSON文件 func (c *DNSConfig) SaveToJSON(filePath string) error { file, err := os.Create(filePath) if err != nil { return fmt.Errorf("创建DNS配置文件失败: %w", err) } defer file.Close() encoder := json.NewEncoder(file) encoder.SetIndent("", " ") if err := encoder.Encode(c); err != nil { return fmt.Errorf("保存DNS配置文件失败: %w", err) } return nil } // 用于解析hosts文件中的IP:端口格式 var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`) // 检查是否为通配符域名 func isWildcardDomain(domain string) bool { return strings.Contains(domain, "*") } // LoadFromHostsFile 从hosts文件格式加载DNS配置 func LoadFromHostsFile(filePath string) (*DNSConfig, error) { file, err := os.Open(filePath) if err != nil { return nil, fmt.Errorf("打开hosts文件失败: %w", err) } defer file.Close() config := &DNSConfig{ Records: make(map[string]string), Fallback: true, TTL: 300, // 默认5分钟 } scanner := bufio.NewScanner(file) lineNum := 0 for scanner.Scan() { lineNum++ line := strings.TrimSpace(scanner.Text()) // 跳过空行和注释 if line == "" || strings.HasPrefix(line, "#") { continue } fields := strings.Fields(line) if len(fields) < 2 { continue // 行格式不正确,跳过 } ipPortStr := fields[0] domains := fields[1:] // 解析IP和可能的端口 matches := ipPortRegex.FindStringSubmatch(ipPortStr) if matches == nil { continue // IP格式不正确,跳过 } ip := matches[1] portStr := matches[2] // 构造记录值 value := ip if portStr != "" { value = ip + ":" + portStr } for _, domain := range domains { // 跳过注释 if strings.HasPrefix(domain, "#") { break } // 支持通配符和普通域名 config.Records[domain] = value } } if err := scanner.Err(); err != nil { return nil, fmt.Errorf("读取hosts文件失败: %w", err) } return config, nil } // NewResolverFromConfig 从配置创建解析器 func NewResolverFromConfig(config *DNSConfig) *CustomResolver { var ttl time.Duration if config.TTL > 0 { ttl = time.Duration(config.TTL) * time.Second } else { ttl = 5 * time.Minute // 默认5分钟 } resolver := NewResolver( WithFallback(config.Fallback), WithTTL(ttl), ) // 加载记录 resolver.LoadFromMap(config.Records) return resolver }