From 81f9f6c9a330cab5fd7db176ee0833e8d5a98d1b Mon Sep 17 00:00:00 2001 From: Brian Cunnie Date: Sat, 30 Sep 2023 09:14:20 -0700 Subject: [PATCH] Listen on TCP, not solely UDP I've wanted sslip.io to bind to both UDP & TCP, mostly because TCP is more secure (at least with regards to DNS cache poisoning). In general, the process to receive a packet, whether TCP or UDP, is similar. - UDP uses `net.UDPConn`, TCP uses `net.TCPListener` - Once bound, UDP uses `ReadFromUDP()` to get the data; TCP first requires an `AcceptTCP()` followed by a `Read()` - Technically you can ask several queries over a single TCP socket, but I close the connection after the first query. - DNS TCP packet has a two-byte length field that has no counterpart in the DNS UDP packet. - The TCP integration tests are lacking. --- src/sslip.io-dns-server/integration_test.go | 4 + src/sslip.io-dns-server/main.go | 125 +++++++++++++++++--- 2 files changed, 114 insertions(+), 15 deletions(-) diff --git a/src/sslip.io-dns-server/integration_test.go b/src/sslip.io-dns-server/integration_test.go index 793f682..23f0705 100644 --- a/src/sslip.io-dns-server/integration_test.go +++ b/src/sslip.io-dns-server/integration_test.go @@ -168,6 +168,10 @@ var _ = Describe("sslip.io-dns-server", func() { "@127.0.0.1 -x 2600:: +short", `\A2600--.sslip.io.\n\z`, `TypePTR 0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.6.2.ip6.arpa. \? 2600--.sslip.io.\n`), + Entry(`over TCP, A (customized) for sslip.io`, + "@localhost sslip.io +short +vc", + `\A78.46.204.247\n\z`, + `TypeA sslip.io. \? 78.46.204.247\n`), ) }) Describe("for more complex assertions", func() { diff --git a/src/sslip.io-dns-server/main.go b/src/sslip.io-dns-server/main.go index 7b098d3..763daf4 100644 --- a/src/sslip.io-dns-server/main.go +++ b/src/sslip.io-dns-server/main.go @@ -1,6 +1,7 @@ package main import ( + "encoding/binary" "errors" "flag" "log" @@ -39,7 +40,9 @@ func main() { } var udpConns []*net.UDPConn - var unboundIPs []string + var tcpListeners []*net.TCPListener + var unboundUDPIPs []string + var unboundTCPIPs []string udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort}) switch { case err == nil: // success! We've bound to all interfaces @@ -49,32 +52,67 @@ func main() { log.Fatal(err.Error()) case isErrorAddressAlreadyInUse(err): log.Printf("I couldn't bind via UDP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort) - udpConns, unboundIPs = bindUDPAddressesIndividually(*bindPort) - if len(unboundIPs) > 0 { - log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundIPs, `", "`)) + udpConns, unboundUDPIPs = bindUDPAddressesIndividually(*bindPort) + if len(unboundUDPIPs) > 0 { + log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundUDPIPs, `", "`)) } default: log.Fatal(err.Error()) } - if len(udpConns) == 0 { + tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: *bindPort}) + switch { + case err == nil: // success! We've bound to all interfaces + tcpListeners = append(tcpListeners, tcpListener) + case isErrorPermissionsError(err): // unnecessary because it should've bombed out earlier when attempting to bind UDP + log.Printf("Try invoking me with `sudo` because I don't have permission to bind to TCP port %d.\n", *bindPort) + log.Println(err.Error()) + case isErrorAddressAlreadyInUse(err): + log.Printf("I couldn't bind via TCP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort) + tcpListeners, unboundTCPIPs = bindTCPAddressesIndividually(*bindPort) + if len(unboundTCPIPs) > 0 { + log.Printf(`I couldn't bind via TCP to the following IPs: "%s"`, strings.Join(unboundTCPIPs, `", "`)) + } + default: + log.Println(err.Error()) // Unlike UDP, we don't exit on TCP errors, we merely log + } + if len(tcpListeners) == 0 { + // unlike UDP failure to bind, we don't exit because TCP is optional, UDP, mandatory + log.Printf("I couldn't bind via TCP to any IPs on port %d", *bindPort) + } + + // Log the list of IPs that we've bound to because it helps troubleshooting + var boundUDPIPs []string + for _, udpConn := range udpConns { + boundUDPIPs = append(boundUDPIPs, udpConn.LocalAddr().String()) + } + log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundUDPIPs, `", "`)) + var boundTCPIPs []string + for _, tcpListener := range tcpListeners { + boundTCPIPs = append(boundTCPIPs, tcpListener.Addr().String()) + } + log.Printf(`I bound via TCP to the following IPs: "%s"`, strings.Join(boundTCPIPs, `", "`)) + + if len(udpConns) == 0 { // couldn't bind to UDP anywhere? exit log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort) } - // Log the list of IPs that we've bound to because it helps troubleshooting - var boundIPs []string - for _, udpConn := range udpConns { - boundIPs = append(boundIPs, udpConn.LocalAddr().String()) + if len(tcpListeners) == 0 { // couldn't bind to TCP anywhere? don't exit; TCP is optional + log.Printf("I couldn't bind via TCP to any IPs on port %d", *bindPort) } - log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPs, `", "`)) - // Read from the UDP connections - for _, udpConn := range udpConns[1:] { // use goroutines to read from all the UDP connections EXCEPT the first - go readFrom(udpConn, x, *quiet) + // Read from the UDP connections & TCP Listeners + // use goroutines to read from all the UDP connections EXCEPT the first; we don't use a goroutine for that + // one because we use the first one to keep this program from exiting + for _, udpConn := range udpConns[1:] { + go readFromUDP(udpConn, x, *quiet) + } + for _, tcpListener := range tcpListeners { + go readFromTCP(tcpListener, x, *quiet) } log.Printf("Ready to answer queries") - readFrom(udpConns[0], x, *quiet) // refrain from exiting; There should always be a udpConns[0], and readFrom() _never_ returns + readFromUDP(udpConns[0], x, *quiet) // refrain from exiting; There should always be a udpConns[0], and readFromUDP() _never_ returns } -func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) { +func readFromUDP(conn *net.UDPConn, x *xip.Xip, quiet bool) { for { query := make([]byte, 512) _, addr, err := conn.ReadFromUDP(query) @@ -96,6 +134,43 @@ func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) { } } +func readFromTCP(tcpListener *net.TCPListener, x *xip.Xip, quiet bool) { + for { + query := make([]byte, 65535) // 2-byte length field means largest size is 65535 + tcpConn, err := tcpListener.AcceptTCP() + if err != nil { + log.Println(err.Error()) + continue + } + _, err = tcpConn.Read(query) + query = query[2:] // remove the 2-byte length at the beginning of the query + if err != nil { + log.Println(err.Error()) + continue + } + remoteAddrPort := tcpConn.RemoteAddr().String() + addr, port, err := net.SplitHostPort(remoteAddrPort) + + go func() { + defer tcpConn.Close() + response, logMessage, err := x.QueryResponse(query, net.ParseIP(addr)) + if err != nil { + log.Println(err.Error()) + return + } + // insert the 2-byte length to the beginning of the response + responseSize := uint16(len(response)) + responseSizeBigEndianBytes := make([]byte, 2) + binary.BigEndian.PutUint16(responseSizeBigEndianBytes, responseSize) + response = append(responseSizeBigEndianBytes, response...) + _, err = tcpConn.Write(response) + if !quiet { + log.Printf("%s.%s %s", addr, port, logMessage) + } + }() + } +} + func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboundIPs []string) { ipCIDRs := listLocalIPCIDRs() for _, ipCIDR := range ipCIDRs { @@ -118,6 +193,26 @@ func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboun return udpConns, unboundIPs } +func bindTCPAddressesIndividually(bindPort int) (tcpListeners []*net.TCPListener, unboundIPs []string) { + ipCIDRs := listLocalIPCIDRs() + for _, ipCIDR := range ipCIDRs { + ip, _, err := net.ParseCIDR(ipCIDR) + if err != nil { + log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR) + continue + } + listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip, Port: bindPort}) + if err != nil { + unboundIPs = append(unboundIPs, ip.String()) + } else { + tcpListeners = append(tcpListeners, listener) + } + } + return tcpListeners, unboundIPs +} + +// TODO: replace this function with net.InterfaceAddrs() ([]Addr, error) +// typical addr "10.9.9.161/24" func listLocalIPCIDRs() []string { var ifaces []net.Interface var cidrStrings []string