Split getting DNS configuration.

For UNIX systems read and parse `/etc/resolv.conf` file.
 On Windows use 'GetAdaptersAddresses' syscall to get DNS configuration

Signed-off-by: Yevhen Vydolob <yvydolob@redhat.com>
This commit is contained in:
Yevhen Vydolob
2024-03-26 15:49:23 +02:00
parent 51df06de85
commit 5579f34daa
4 changed files with 355 additions and 13 deletions

View File

@@ -22,7 +22,7 @@ type dnsHandler struct {
nameserver string
}
func newDnsHandler(zones []types.Zone) *dnsHandler {
func newDNSHandler(zones []types.Zone) *dnsHandler {
dnsClient, nameserver := readAndCreateClient()
@@ -35,23 +35,16 @@ func newDnsHandler(zones []types.Zone) *dnsHandler {
}
func readAndCreateClient() (*dns.Client, string) {
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
nameserver, port, err := GetDNSHostAndPort()
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(2)
}
nameserver := conf.Servers[0]
// if the nameserver is from /etc/resolv.conf the [ and ] are already
// added, thereby breaking net.ParseIP. Check for this and don't
// fully qualify such a name
if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' {
nameserver = nameserver[1 : len(nameserver)-1]
}
if i := net.ParseIP(nameserver); i != nil {
nameserver = net.JoinHostPort(nameserver, conf.Port)
nameserver = net.JoinHostPort(nameserver, port)
} else {
nameserver = dns.Fqdn(nameserver) + ":" + conf.Port
nameserver = dns.Fqdn(nameserver) + ":" + port
}
client := new(dns.Client)
client.Net = "udp"
@@ -163,7 +156,7 @@ type Server struct {
}
func New(udpConn net.PacketConn, tcpLn net.Listener, zones []types.Zone) (*Server, error) {
handler := newDnsHandler(zones)
handler := newDNSHandler(zones)
return &Server{udpConn: udpConn, tcpLn: tcpLn, handler: handler}, nil
}

View File

@@ -0,0 +1,28 @@
//go:build !windows
package dns
import (
"fmt"
"os"
"github.com/miekg/dns"
)
func GetDNSHostAndPort() (string, string, error) {
conf, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
fmt.Fprintln(os.Stderr, err)
return "", "", err
}
nameserver := conf.Servers[0]
// if the nameserver is from /etc/resolv.conf the [ and ] are already
// added, thereby breaking net.ParseIP. Check for this and don't
// fully qualify such a name
if nameserver[0] == '[' && nameserver[len(nameserver)-1] == ']' {
nameserver = nameserver[1 : len(nameserver)-1]
}
return nameserver, conf.Port, nil
}

View File

@@ -0,0 +1,296 @@
//go:build windows
package dns
import (
"errors"
"fmt"
"net/netip"
"strconv"
"syscall"
"unsafe"
)
func GetDNSHostAndPort() (string, string, error) {
nameservers := getDNSServers()
var nameserver netip.AddrPort
for _, n := range nameservers {
// return first non ipv6 nameserver
if n.Addr().Is4() {
nameserver = n
break
}
}
return nameserver.Addr().String(), strconv.Itoa(int(nameserver.Port())), nil
}
// copied from https://github.com/qdm12/dns/blob/v2.0.0-beta/pkg/nameserver/getlocal_windows.go
// this function will use windows syscall to get DNS configuration
func getDNSServers() (nameservers []netip.AddrPort) {
const defaultDNSPort = 53
defaultLocalNameservers := []netip.AddrPort{
netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), defaultDNSPort),
netip.AddrPortFrom(netip.AddrFrom16([16]byte{0, 0, 0, 0, 0, 0, 0, 1}), defaultDNSPort),
}
adapterAddresses, err := getAdapterAddresses()
if err != nil {
return defaultLocalNameservers
}
for _, adapterAddress := range adapterAddresses {
const statusUp = 0x01
if adapterAddress.operStatus != statusUp {
continue
}
if adapterAddress.firstGatewayAddress == nil {
// Only search DNS servers for adapters having a gateway
continue
}
dnsServerAddress := adapterAddress.firstDnsServerAddress
for dnsServerAddress != nil {
ip, ok := sockAddressToIP(dnsServerAddress.address.rawSockAddrAny)
if !ok || ipIsSiteLocalAnycast(ip) {
// fec0/10 IPv6 addresses are site local anycast DNS
// addresses Microsoft sets by default if no other
// IPv6 DNS address is set. Site local anycast is
// deprecated since 2004, see
// https://datatracker.ietf.org/doc/html/rfc3879
dnsServerAddress = dnsServerAddress.next
continue
}
nameserver := netip.AddrPortFrom(ip, defaultDNSPort)
nameservers = append(nameservers, nameserver)
dnsServerAddress = dnsServerAddress.next
}
}
if len(nameservers) == 0 {
return defaultLocalNameservers
}
return nameservers
}
var (
errBufferOverflowUnexpected = errors.New("unexpected buffer overflowed because buffer was large enough")
)
func getAdapterAddresses() (
adapterAddresses []*ipAdapterAddresses, err error) {
var buffer []byte
const initialBufferLength uint32 = 15000
sizeVar := initialBufferLength
for {
buffer = make([]byte, sizeVar)
err := runProcGetAdaptersAddresses(
(*ipAdapterAddresses)(unsafe.Pointer(&buffer[0])),
&sizeVar)
if err != nil {
if err.(syscall.Errno) == syscall.ERROR_BUFFER_OVERFLOW {
if sizeVar <= uint32(len(buffer)) {
return nil, fmt.Errorf("%w: buffer size variable %d is "+
"equal or lower to the buffer current length %d",
errBufferOverflowUnexpected, sizeVar, len(buffer))
}
continue
}
return nil, fmt.Errorf("getting adapters addresses: %w", err)
}
noDataFound := sizeVar == 0
if noDataFound {
return nil, nil
}
break
}
adapterAddress := (*ipAdapterAddresses)(unsafe.Pointer(&buffer[0]))
for adapterAddress != nil {
adapterAddresses = append(adapterAddresses, adapterAddress)
adapterAddress = adapterAddress.next
}
return adapterAddresses, nil
}
var (
procGetAdaptersAddresses = syscall.NewLazyDLL("iphlpapi.dll").
NewProc("GetAdaptersAddresses")
)
func runProcGetAdaptersAddresses(adapterAddresses *ipAdapterAddresses,
sizePointer *uint32) (errcode error) {
const family = syscall.AF_UNSPEC
const GAA_FLAG_SKIP_UNICAST = 0x0001
const GAA_FLAG_SKIP_ANYCAST = 0x0002
const GAA_FLAG_SKIP_MULTICAST = 0x0004
const GAA_FLAG_SKIP_FRIENDLY_NAME = 0x0020
const GAA_FLAG_INCLUDE_GATEWAYS = 0x0080
const flags = GAA_FLAG_SKIP_UNICAST | GAA_FLAG_SKIP_ANYCAST |
GAA_FLAG_SKIP_MULTICAST | GAA_FLAG_SKIP_FRIENDLY_NAME |
GAA_FLAG_INCLUDE_GATEWAYS
const reserved = 0
// See https://learn.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses
r1, _, err := syscall.SyscallN(procGetAdaptersAddresses.Addr(),
uintptr(family), uintptr(flags), uintptr(reserved),
uintptr(unsafe.Pointer(adapterAddresses)),
uintptr(unsafe.Pointer(sizePointer)))
switch {
case err != 0:
return err
case r1 != 0:
return syscall.Errno(r1)
default:
return nil
}
}
func sockAddressToIP(rawSockAddress *syscall.RawSockaddrAny) (ip netip.Addr, ok bool) {
if rawSockAddress == nil {
return netip.Addr{}, false
}
sockAddress, err := rawSockAddress.Sockaddr()
if err != nil {
return netip.Addr{}, false
}
switch sockAddress := sockAddress.(type) {
case *syscall.SockaddrInet4:
return netip.AddrFrom4([4]byte{
sockAddress.Addr[0], sockAddress.Addr[1], sockAddress.Addr[2], sockAddress.Addr[3]}),
true
case *syscall.SockaddrInet6:
return netip.AddrFrom16([16]byte{
sockAddress.Addr[0], sockAddress.Addr[1], sockAddress.Addr[2], sockAddress.Addr[3],
sockAddress.Addr[4], sockAddress.Addr[5], sockAddress.Addr[6], sockAddress.Addr[7],
sockAddress.Addr[8], sockAddress.Addr[9], sockAddress.Addr[10], sockAddress.Addr[11],
sockAddress.Addr[12], sockAddress.Addr[13], sockAddress.Addr[14], sockAddress.Addr[15]}),
true
default:
return netip.Addr{}, false
}
}
func ipIsSiteLocalAnycast(ip netip.Addr) bool {
if !ip.Is6() {
return false
}
array := ip.As16()
return array[0] == 0xfe && array[1] == 0xc0
}
// See https://learn.microsoft.com/en-us/windows/win32/api/iptypes/ns-iptypes-ip_adapter_addresses_lh
type ipAdapterAddresses struct {
// The order of fields DOES matter since they are read
// raw from a bytes buffer. However, we are only interested
// in a few select fields, so unneeded fields are either
// named as "_" or removed if they are after the fields
// we are interested in.
_ uint32
_ uint32
next *ipAdapterAddresses
_ *byte
_ *ipAdapterUnicastAddress
_ *ipAdapterAnycastAddress
_ *ipAdapterMulticastAddress
firstDnsServerAddress *ipAdapterDnsServerAdapter
_ *uint16
_ *uint16
_ *uint16
_ [syscall.MAX_ADAPTER_ADDRESS_LENGTH]byte
_ uint32
_ uint32
_ uint32
_ uint32
operStatus uint32
_ uint32
_ [16]uint32
_ *ipAdapterPrefix
_ uint64
_ uint64
_ *ipAdapterWinsServerAddress
firstGatewayAddress *ipAdapterGatewayAddress
// Additional fields not needed here
}
type ipAdapterUnicastAddress struct {
// The order of fields DOES matter since they are read raw
// from a bytes buffer. However, we are not interested in
// the value of any field, so they are all named as "_".
_ uint32
_ uint32
_ *ipAdapterUnicastAddress
_ ipAdapterSocketAddress
_ int32
_ int32
_ int32
_ uint32
_ uint32
_ uint32
_ uint8
}
type ipAdapterAnycastAddress struct {
// The order of fields DOES matter since they are read raw
// from a bytes buffer. However, we are not interested in
// the value of any field, so they are all named as "_".
_ uint32
_ uint32
_ *ipAdapterAnycastAddress
_ ipAdapterSocketAddress
}
type ipAdapterMulticastAddress struct {
// The order of fields DOES matter since they are read raw
// from a bytes buffer. However, we are only interested in
// a few select fields, so unneeded fields are named as "_".
_ uint32
_ uint32
_ *ipAdapterMulticastAddress
_ ipAdapterSocketAddress
}
type ipAdapterDnsServerAdapter struct {
// The order of fields DOES matter since they are read raw
// from a bytes buffer. However, we are only interested in
// a few select fields, so unneeded fields are named as "_".
_ uint32
_ uint32
next *ipAdapterDnsServerAdapter
address ipAdapterSocketAddress
}
type ipAdapterPrefix struct {
_ uint32
_ uint32
_ *ipAdapterPrefix
_ ipAdapterSocketAddress
_ uint32
}
type ipAdapterWinsServerAddress struct {
_ uint32
_ uint32
_ *ipAdapterWinsServerAddress
_ ipAdapterSocketAddress
}
type ipAdapterGatewayAddress struct {
_ uint32
_ uint32
_ *ipAdapterGatewayAddress
_ ipAdapterSocketAddress
}
type ipAdapterSocketAddress struct {
rawSockAddrAny *syscall.RawSockaddrAny
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/miekg/dns"
"github.com/onsi/ginkgo"
"github.com/onsi/gomega"
)
@@ -191,4 +192,28 @@ var _ = ginkgo.Describe("dns add test", func() {
},
}))
})
ginkgo.It("Should pass DNS requests to default system DNS server", func() {
m := &dns.Msg{
MsgHdr: dns.MsgHdr{
Authoritative: false,
AuthenticatedData: false,
CheckingDisabled: false,
RecursionDesired: true,
Opcode: 0,
},
Question: make([]dns.Question, 1),
}
m.Question[0] = dns.Question{
Name: "redhat.com.",
Qtype: 1,
Qclass: 1,
}
server.handler.addAnswers(m)
gomega.Expect(m.Answer[0].Header().Name).To(gomega.Equal("redhat.com."))
gomega.Expect(m.Answer[0].String()).To(gomega.SatisfyAny(gomega.ContainSubstring("34.235.198.240"), gomega.ContainSubstring("52.200.142.250")))
})
})