Clarify UDP binding code

In preparation for TCP binding, I re-worked the UDP binding process so
that it could be more understandable and more easily replicated.

I don't know that it's more understandable. I may have failed.
This commit is contained in:
Brian Cunnie
2023-09-24 10:07:12 +02:00
parent 4a095ca2b6
commit 81b142925f

View File

@@ -40,55 +40,42 @@ func main() {
log.Println(logmessage) log.Println(logmessage)
} }
connUdp, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort}) var udpConns []*net.UDPConn
// common err hierarchy: net.OpError → os.SyscallError → syscall.Errno var unboundIPs []string
switch { udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort})
case err == nil: if err == nil {
log.Printf("Successfully bound via UDP to all IPs, port %d.\n", *bindPort) udpConns = append(udpConns, udpConn)
wg.Add(1) }
go readFrom(connUdp, &wg, x, *quiet) if isErrorPermissionsError(err) {
case isErrorPermissionsError(err): log.Printf("Try invoking me with `sudo` because I don't have permission to bind to UDP port %d.\n", *bindPort)
log.Printf("Try invoking me with `sudo` because I don't have permission to bind to port %d.\n", *bindPort)
log.Fatal(err.Error()) log.Fatal(err.Error())
case isErrorAddressAlreadyInUse(err): }
log.Printf("I couldn't bind via UDP to \"0.0.0.0:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort) if isErrorAddressAlreadyInUse(err) {
ipCIDRs := listLocalIPCIDRs() // do some stuff
var boundIPsPorts, unboundIPs []string log.Printf("I couldn't bind via UDP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort)
for _, ipCIDR := range ipCIDRs { udpConns, unboundIPs = bindUDPAddressesIndividually(*bindPort)
ip, _, err := net.ParseCIDR(ipCIDR)
if err != nil {
log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR)
continue
}
connUdp, err = net.ListenUDP("udp", &net.UDPAddr{
IP: ip,
Port: *bindPort,
Zone: "",
})
if err != nil {
unboundIPs = append(unboundIPs, ip.String())
} else {
wg.Add(1)
boundIPsPorts = append(boundIPsPorts, connUdp.LocalAddr().String())
go readFrom(connUdp, &wg, x, *quiet)
}
}
if len(boundIPsPorts) == 0 {
log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort)
}
log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPsPorts, `", "`))
if len(unboundIPs) > 0 { if len(unboundIPs) > 0 {
log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundIPs, `", "`)) log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundIPs, `", "`))
} }
default: if len(udpConns) == 0 {
log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort)
}
}
if err != nil {
log.Fatal(err.Error()) log.Fatal(err.Error())
} }
var boundIPs []string
for _, udpConn := range udpConns {
boundIPs = append(boundIPs, udpConn.LocalAddr().String())
go readFrom(udpConn, x, *quiet)
wg.Add(1)
}
log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPs, `", "`))
log.Printf("Ready to answer queries") log.Printf("Ready to answer queries")
wg.Wait() wg.Wait()
} }
func readFrom(conn *net.UDPConn, wg *sync.WaitGroup, x *xip.Xip, quiet bool) { func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) {
defer wg.Done()
for { for {
query := make([]byte, 512) query := make([]byte, 512)
_, addr, err := conn.ReadFromUDP(query) _, addr, err := conn.ReadFromUDP(query)
@@ -110,6 +97,28 @@ func readFrom(conn *net.UDPConn, wg *sync.WaitGroup, x *xip.Xip, quiet bool) {
} }
} }
func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboundIPs []string) {
ipCIDRs := listLocalIPCIDRs()
for _, ipCIDR := range ipCIDRs {
ip, _, err := net.ParseCIDR(ipCIDR)
if err != nil {
log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR)
continue
}
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{
IP: ip,
Port: bindPort,
Zone: "",
})
if err != nil {
unboundIPs = append(unboundIPs, ip.String())
} else {
udpConns = append(udpConns, udpConn)
}
}
return udpConns, unboundIPs
}
func listLocalIPCIDRs() []string { func listLocalIPCIDRs() []string {
var ifaces []net.Interface var ifaces []net.Interface
var cidrStrings []string var cidrStrings []string