Listen on TCP, not solely UDP

I've wanted sslip.io to bind to both UDP & TCP, mostly because TCP is
more secure (at least with regards to DNS cache poisoning).

In general, the process to receive a packet, whether TCP or UDP, is
similar.

- UDP uses `net.UDPConn`, TCP uses `net.TCPListener`
- Once bound, UDP uses `ReadFromUDP()` to get the data; TCP first
  requires an `AcceptTCP()` followed by a `Read()`
- Technically you can ask several queries over a single TCP socket, but
  I close the connection after the first query.
- DNS TCP packet has a two-byte length field that has no counterpart in
  the DNS UDP packet.
- The TCP integration tests are lacking.
This commit is contained in:
Brian Cunnie
2023-09-30 09:14:20 -07:00
parent b09bccdd86
commit 81f9f6c9a3
2 changed files with 114 additions and 15 deletions

View File

@@ -168,6 +168,10 @@ var _ = Describe("sslip.io-dns-server", func() {
"@127.0.0.1 -x 2600:: +short",
`\A2600--.sslip.io.\n\z`,
`TypePTR 0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.6.2.ip6.arpa. \? 2600--.sslip.io.\n`),
Entry(`over TCP, A (customized) for sslip.io`,
"@localhost sslip.io +short +vc",
`\A78.46.204.247\n\z`,
`TypeA sslip.io. \? 78.46.204.247\n`),
)
})
Describe("for more complex assertions", func() {

View File

@@ -1,6 +1,7 @@
package main
import (
"encoding/binary"
"errors"
"flag"
"log"
@@ -39,7 +40,9 @@ func main() {
}
var udpConns []*net.UDPConn
var unboundIPs []string
var tcpListeners []*net.TCPListener
var unboundUDPIPs []string
var unboundTCPIPs []string
udpConn, err := net.ListenUDP("udp", &net.UDPAddr{Port: *bindPort})
switch {
case err == nil: // success! We've bound to all interfaces
@@ -49,32 +52,67 @@ func main() {
log.Fatal(err.Error())
case isErrorAddressAlreadyInUse(err):
log.Printf("I couldn't bind via UDP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort)
udpConns, unboundIPs = bindUDPAddressesIndividually(*bindPort)
if len(unboundIPs) > 0 {
log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundIPs, `", "`))
udpConns, unboundUDPIPs = bindUDPAddressesIndividually(*bindPort)
if len(unboundUDPIPs) > 0 {
log.Printf(`I couldn't bind via UDP to the following IPs: "%s"`, strings.Join(unboundUDPIPs, `", "`))
}
default:
log.Fatal(err.Error())
}
if len(udpConns) == 0 {
tcpListener, err := net.ListenTCP("tcp", &net.TCPAddr{Port: *bindPort})
switch {
case err == nil: // success! We've bound to all interfaces
tcpListeners = append(tcpListeners, tcpListener)
case isErrorPermissionsError(err): // unnecessary because it should've bombed out earlier when attempting to bind UDP
log.Printf("Try invoking me with `sudo` because I don't have permission to bind to TCP port %d.\n", *bindPort)
log.Println(err.Error())
case isErrorAddressAlreadyInUse(err):
log.Printf("I couldn't bind via TCP to \"[::]:%d\" (INADDR_ANY, all interfaces), so I'll try to bind to each address individually.\n", *bindPort)
tcpListeners, unboundTCPIPs = bindTCPAddressesIndividually(*bindPort)
if len(unboundTCPIPs) > 0 {
log.Printf(`I couldn't bind via TCP to the following IPs: "%s"`, strings.Join(unboundTCPIPs, `", "`))
}
default:
log.Println(err.Error()) // Unlike UDP, we don't exit on TCP errors, we merely log
}
if len(tcpListeners) == 0 {
// unlike UDP failure to bind, we don't exit because TCP is optional, UDP, mandatory
log.Printf("I couldn't bind via TCP to any IPs on port %d", *bindPort)
}
// Log the list of IPs that we've bound to because it helps troubleshooting
var boundUDPIPs []string
for _, udpConn := range udpConns {
boundUDPIPs = append(boundUDPIPs, udpConn.LocalAddr().String())
}
log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundUDPIPs, `", "`))
var boundTCPIPs []string
for _, tcpListener := range tcpListeners {
boundTCPIPs = append(boundTCPIPs, tcpListener.Addr().String())
}
log.Printf(`I bound via TCP to the following IPs: "%s"`, strings.Join(boundTCPIPs, `", "`))
if len(udpConns) == 0 { // couldn't bind to UDP anywhere? exit
log.Fatalf("I couldn't bind via UDP to any IPs on port %d, so I'm exiting", *bindPort)
}
// Log the list of IPs that we've bound to because it helps troubleshooting
var boundIPs []string
for _, udpConn := range udpConns {
boundIPs = append(boundIPs, udpConn.LocalAddr().String())
if len(tcpListeners) == 0 { // couldn't bind to TCP anywhere? don't exit; TCP is optional
log.Printf("I couldn't bind via TCP to any IPs on port %d", *bindPort)
}
log.Printf(`I bound via UDP to the following IPs: "%s"`, strings.Join(boundIPs, `", "`))
// Read from the UDP connections
for _, udpConn := range udpConns[1:] { // use goroutines to read from all the UDP connections EXCEPT the first
go readFrom(udpConn, x, *quiet)
// Read from the UDP connections & TCP Listeners
// use goroutines to read from all the UDP connections EXCEPT the first; we don't use a goroutine for that
// one because we use the first one to keep this program from exiting
for _, udpConn := range udpConns[1:] {
go readFromUDP(udpConn, x, *quiet)
}
for _, tcpListener := range tcpListeners {
go readFromTCP(tcpListener, x, *quiet)
}
log.Printf("Ready to answer queries")
readFrom(udpConns[0], x, *quiet) // refrain from exiting; There should always be a udpConns[0], and readFrom() _never_ returns
readFromUDP(udpConns[0], x, *quiet) // refrain from exiting; There should always be a udpConns[0], and readFromUDP() _never_ returns
}
func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) {
func readFromUDP(conn *net.UDPConn, x *xip.Xip, quiet bool) {
for {
query := make([]byte, 512)
_, addr, err := conn.ReadFromUDP(query)
@@ -96,6 +134,43 @@ func readFrom(conn *net.UDPConn, x *xip.Xip, quiet bool) {
}
}
func readFromTCP(tcpListener *net.TCPListener, x *xip.Xip, quiet bool) {
for {
query := make([]byte, 65535) // 2-byte length field means largest size is 65535
tcpConn, err := tcpListener.AcceptTCP()
if err != nil {
log.Println(err.Error())
continue
}
_, err = tcpConn.Read(query)
query = query[2:] // remove the 2-byte length at the beginning of the query
if err != nil {
log.Println(err.Error())
continue
}
remoteAddrPort := tcpConn.RemoteAddr().String()
addr, port, err := net.SplitHostPort(remoteAddrPort)
go func() {
defer tcpConn.Close()
response, logMessage, err := x.QueryResponse(query, net.ParseIP(addr))
if err != nil {
log.Println(err.Error())
return
}
// insert the 2-byte length to the beginning of the response
responseSize := uint16(len(response))
responseSizeBigEndianBytes := make([]byte, 2)
binary.BigEndian.PutUint16(responseSizeBigEndianBytes, responseSize)
response = append(responseSizeBigEndianBytes, response...)
_, err = tcpConn.Write(response)
if !quiet {
log.Printf("%s.%s %s", addr, port, logMessage)
}
}()
}
}
func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboundIPs []string) {
ipCIDRs := listLocalIPCIDRs()
for _, ipCIDR := range ipCIDRs {
@@ -118,6 +193,26 @@ func bindUDPAddressesIndividually(bindPort int) (udpConns []*net.UDPConn, unboun
return udpConns, unboundIPs
}
func bindTCPAddressesIndividually(bindPort int) (tcpListeners []*net.TCPListener, unboundIPs []string) {
ipCIDRs := listLocalIPCIDRs()
for _, ipCIDR := range ipCIDRs {
ip, _, err := net.ParseCIDR(ipCIDR)
if err != nil {
log.Printf(`I couldn't parse the local interface "%s".`, ipCIDR)
continue
}
listener, err := net.ListenTCP("tcp", &net.TCPAddr{IP: ip, Port: bindPort})
if err != nil {
unboundIPs = append(unboundIPs, ip.String())
} else {
tcpListeners = append(tcpListeners, listener)
}
}
return tcpListeners, unboundIPs
}
// TODO: replace this function with net.InterfaceAddrs() ([]Addr, error)
// typical addr "10.9.9.161/24"
func listLocalIPCIDRs() []string {
var ifaces []net.Interface
var cidrStrings []string