152 lines
3.3 KiB
Go
152 lines
3.3 KiB
Go
package dns
|
||
|
||
import (
|
||
"bufio"
|
||
"encoding/json"
|
||
"fmt"
|
||
"os"
|
||
"regexp"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// DNSConfig DNS配置文件结构
|
||
type DNSConfig struct {
|
||
Records map[string]string `json:"records"` // 普通记录和泛解析记录
|
||
Fallback bool `json:"fallback"` // 是否回退到系统DNS
|
||
TTL int `json:"ttl"` // 缓存TTL,单位为秒
|
||
}
|
||
|
||
// LoadFromJSON 从JSON文件加载DNS配置
|
||
func LoadFromJSON(filePath string) (*DNSConfig, error) {
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开DNS配置文件失败: %w", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
config := &DNSConfig{
|
||
Records: make(map[string]string),
|
||
Fallback: true,
|
||
TTL: 300, // 默认5分钟
|
||
}
|
||
|
||
decoder := json.NewDecoder(file)
|
||
if err := decoder.Decode(config); err != nil {
|
||
return nil, fmt.Errorf("解析DNS配置文件失败: %w", err)
|
||
}
|
||
|
||
return config, nil
|
||
}
|
||
|
||
// SaveToJSON 将DNS配置保存为JSON文件
|
||
func (c *DNSConfig) SaveToJSON(filePath string) error {
|
||
file, err := os.Create(filePath)
|
||
if err != nil {
|
||
return fmt.Errorf("创建DNS配置文件失败: %w", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
encoder := json.NewEncoder(file)
|
||
encoder.SetIndent("", " ")
|
||
if err := encoder.Encode(c); err != nil {
|
||
return fmt.Errorf("保存DNS配置文件失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 用于解析hosts文件中的IP:端口格式
|
||
var ipPortRegex = regexp.MustCompile(`^([0-9.]+)(?::(\d+))?$`)
|
||
|
||
// 检查是否为通配符域名
|
||
func isWildcardDomain(domain string) bool {
|
||
return strings.Contains(domain, "*")
|
||
}
|
||
|
||
// LoadFromHostsFile 从hosts文件格式加载DNS配置
|
||
func LoadFromHostsFile(filePath string) (*DNSConfig, error) {
|
||
file, err := os.Open(filePath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("打开hosts文件失败: %w", err)
|
||
}
|
||
defer file.Close()
|
||
|
||
config := &DNSConfig{
|
||
Records: make(map[string]string),
|
||
Fallback: true,
|
||
TTL: 300, // 默认5分钟
|
||
}
|
||
|
||
scanner := bufio.NewScanner(file)
|
||
lineNum := 0
|
||
for scanner.Scan() {
|
||
lineNum++
|
||
line := strings.TrimSpace(scanner.Text())
|
||
|
||
// 跳过空行和注释
|
||
if line == "" || strings.HasPrefix(line, "#") {
|
||
continue
|
||
}
|
||
|
||
fields := strings.Fields(line)
|
||
if len(fields) < 2 {
|
||
continue // 行格式不正确,跳过
|
||
}
|
||
|
||
ipPortStr := fields[0]
|
||
domains := fields[1:]
|
||
|
||
// 解析IP和可能的端口
|
||
matches := ipPortRegex.FindStringSubmatch(ipPortStr)
|
||
if matches == nil {
|
||
continue // IP格式不正确,跳过
|
||
}
|
||
|
||
ip := matches[1]
|
||
portStr := matches[2]
|
||
|
||
// 构造记录值
|
||
value := ip
|
||
if portStr != "" {
|
||
value = ip + ":" + portStr
|
||
}
|
||
|
||
for _, domain := range domains {
|
||
// 跳过注释
|
||
if strings.HasPrefix(domain, "#") {
|
||
break
|
||
}
|
||
|
||
// 支持通配符和普通域名
|
||
config.Records[domain] = value
|
||
}
|
||
}
|
||
|
||
if err := scanner.Err(); err != nil {
|
||
return nil, fmt.Errorf("读取hosts文件失败: %w", err)
|
||
}
|
||
|
||
return config, nil
|
||
}
|
||
|
||
// NewResolverFromConfig 从配置创建解析器
|
||
func NewResolverFromConfig(config *DNSConfig) *CustomResolver {
|
||
var ttl time.Duration
|
||
if config.TTL > 0 {
|
||
ttl = time.Duration(config.TTL) * time.Second
|
||
} else {
|
||
ttl = 5 * time.Minute // 默认5分钟
|
||
}
|
||
|
||
resolver := NewResolver(
|
||
WithFallback(config.Fallback),
|
||
WithTTL(ttl),
|
||
)
|
||
|
||
// 加载记录
|
||
resolver.LoadFromMap(config.Records)
|
||
|
||
return resolver
|
||
}
|