From 81b142925fac8c717f8488d1c05dbc0116dac752 Mon Sep 17 00:00:00 2001 From: Brian Cunnie Date: Sun, 24 Sep 2023 10:07:12 +0200 Subject: [PATCH] Clarify UDP binding code In preparation for TCP binding, I re-worked the UDP binding process so that it could be more understandable and more easily replicated. I don't know that it's more understandable. I may have failed. --- src/sslip.io-dns-server/main.go | 87 ++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/src/sslip.io-dns-server/main.go b/src/sslip.io-dns-server/main.go index 2236d3f..21a4768 100644 --- a/src/sslip.io-dns-server/main.go +++ b/src/sslip.io-dns-server/main.go @@ -40,55 +40,42 @@ func main() { log.Println(logmessage) } - connUdp, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort}) - // common err hierarchy: net.OpError → os.SyscallError → syscall.Errno - switch { - case err == nil: - log.Printf("Successfully bound via UDP to all IPs, port %d.\n", *bindPort) - wg.Add(1) - go readFrom(connUdp, &wg, x, *quiet) - case isErrorPermissionsError(err): - log.Printf("Try invoking me with `sudo` because I don't have permission to bind to port %d.\n", *bindPort) + var udpConns []*net.UDPConn + var unboundIPs []string + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort}) + if err == nil { + udpConns = append(udpConns, udpConn) + } + if isErrorPermissionsError(err) { + log.Printf("Try invoking me with `sudo` because I don't have permission to bind to UDP port %d.\n", *bindPort) log.Fatal(err.Error()) - case isErrorAddressAlreadyInUse(err): - log.Printf("I couldn't bind via UDP to \"0.0.0.0:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort) - ipCIDRs := listLocalIPCIDRs() - var boundIPsPorts, unboundIPs []string - for _, ipCIDR := range ipCIDRs { - ip, _, err := net.ParseCIDR(ipCIDR) - if err != nil { - log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR) - continue - } - connUdp, err = net.ListenUDP("udp", &net.UDPAddr{ - IP: ip, - Port: *bindPort, - Zone: "", - }) - if err != nil { - unboundIPs = append(unboundIPs, ip.String()) - } else { - wg.Add(1) - boundIPsPorts = append(boundIPsPorts, connUdp.LocalAddr().String()) - go readFrom(connUdp, &wg, x, *quiet) - } - } - if len(boundIPsPorts) == 0 { - log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort) - } - log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPsPorts, `", "`)) + } + if isErrorAddressAlreadyInUse(err) { + // do some stuff + log.Printf("I couldn't bind via UDP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort) + udpConns, unboundIPs = bindUDPAddressesIndividually(*bindPort) if len(unboundIPs) > 0 { log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundIPs, `", "`)) } - default: + if len(udpConns) == 0 { + log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort) + } + } + if err != nil { log.Fatal(err.Error()) } + var boundIPs []string + for _, udpConn := range udpConns { + boundIPs = append(boundIPs, udpConn.LocalAddr().String()) + go readFrom(udpConn, x, *quiet) + wg.Add(1) + } + log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPs, `", "`)) log.Printf("Ready to answer queries") wg.Wait() } -func readFrom(conn *net.UDPConn, wg *sync.WaitGroup, x *xip.Xip, quiet bool) { - defer wg.Done() +func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) { for { query := make([]byte, 512) _, addr, err := conn.ReadFromUDP(query) @@ -110,6 +97,28 @@ func readFrom(conn *net.UDPConn, wg *sync.WaitGroup, x *xip.Xip, quiet bool) { } } +func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboundIPs []string) { + ipCIDRs := listLocalIPCIDRs() + for _, ipCIDR := range ipCIDRs { + ip, _, err := net.ParseCIDR(ipCIDR) + if err != nil { + log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR) + continue + } + udpConn, err := net.ListenUDP("udp", &net.UDPAddr{ + IP: ip, + Port: bindPort, + Zone: "", + }) + if err != nil { + unboundIPs = append(unboundIPs, ip.String()) + } else { + udpConns = append(udpConns, udpConn) + } + } + return udpConns, unboundIPs +} + func listLocalIPCIDRs() []string { var ifaces []net.Interface var cidrStrings []string