diff --git a/integration_flags_test.go b/integration_flags_test.go index c38290d..9633aaf 100644 --- a/integration_flags_test.go +++ b/integration_flags_test.go @@ -14,8 +14,14 @@ import ( var _ = Describe("flags", func() { var serverCmd *exec.Cmd var serverSession *Session - var port = getFreePort() + var port int var flags []string + var serverReadyOrDeadOutput string + + BeforeEach(func() { + port = getFreePort() + serverReadyOrDeadOutput = "Ready to answer queries" + }) JustBeforeEach(func() { flags = append(flags, "-port", strconv.Itoa(port), "-blocklistURL", "file://etc/blocklist-test.txt") @@ -25,7 +31,7 @@ var _ = Describe("flags", func() { // takes 0.455s to start up on macOS Big Sur 3.7 GHz Quad Core 22-nm Xeon E5-1620v2 processor (2013 Mac Pro) // takes 1.312s to start up on macOS Big Sur 2.0GHz quad-core 10th-generation Intel Core i5 processor (2020 13" MacBook Pro) // 10 seconds should be long enough for slow container-on-a-VM-with-shared-core - Eventually(serverSession.Err, 10).Should(Say("Ready to answer queries")) + Eventually(serverSession.Err, 10).Should(Say(serverReadyOrDeadOutput)) }) AfterEach(func() { serverSession.Terminate() @@ -234,4 +240,50 @@ var _ = Describe("flags", func() { }) }) }) + + When("-max_queries_per_sec is set", func() { + When("the arguments are missing", func() { + BeforeEach(func() { + flags = []string{"-max_queries_per_sec="} + serverReadyOrDeadOutput = "-max_queries_per_sec: parse error" + }) + It("should give an informative message", func() { + portFail := getFreePort() + flags = append(flags, "-port", strconv.Itoa(portFail), "-blocklistURL", "file://etc/blocklist-test.txt") + serverCmd = exec.Command(serverPath, flags...) + serverSessionFail, err := Start(serverCmd, GinkgoWriter, GinkgoWriter) + Expect(err).ToNot(HaveOccurred()) + // takes 0.455s to start up on macOS Big Sur 3.7 GHz Quad Core 22-nm Xeon E5-1620v2 processor (2013 Mac Pro) + // takes 1.312s to start up on macOS Big Sur 2.0GHz quad-core 10th-generation Intel Core i5 processor (2020 13" MacBook Pro) + // 10 seconds should be long enough for slow container-on-a-VM-with-shared-core + Eventually(serverSessionFail.Err, 10).Should(Say(serverReadyOrDeadOutput)) + Eventually(string(serverSessionFail.Err.Contents())).Should(MatchRegexp(`-max_queries_per_sec`)) + }) + }) + When("the queries exceed the limit", func() { + BeforeEach(func() { + flags = []string{"-max_queries_per_sec=1"} + }) + It("should answer the first query but not the second", func() { + digArgs := "@localhost 169-254-169-254.sslip.io +tries=1 +timeout=1 -p " + strconv.Itoa(port) + digCmd := exec.Command("dig", strings.Split(digArgs, " ")...) + digSession, err := Start(digCmd, GinkgoWriter, GinkgoWriter) + Expect(err).ToNot(HaveOccurred()) + Eventually(digSession).Should(Say(`flags: qr aa rd; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 0`)) + Eventually(digSession).Should(Say(`;; ANSWER SECTION:`)) + Eventually(digSession).Should(Say(`169-254-169-254.sslip.io. 3600 IN A 169.254.169.254\n`)) + Eventually(digSession, 1).Should(Exit(0)) + Eventually(string(serverSession.Err.Contents())).Should(MatchRegexp(`TypeA 169-254-169-254\.sslip\.io\. \? 169\.254\.169\.254`)) + // second command, same as the first, but is throttled and doesn't get a DNS reply + digCmdThrottled := exec.Command("dig", strings.Split(digArgs, " ")...) + digSessionThrottled, err := Start(digCmdThrottled, GinkgoWriter, GinkgoWriter) + Expect(err).ToNot(HaveOccurred()) + Eventually(digSessionThrottled, 2).Should(Exit(0)) + Eventually(string(serverSession.Err.Contents())).Should(MatchRegexp(`429 Too Many Requests: .* queries per second exceeds 1 queries per second limit`)) + }) + }) + }) + + When("-max_queries_per_sec is set", func() { + }) }) diff --git a/main.go b/main.go index 6f67249..5751b36 100644 --- a/main.go +++ b/main.go @@ -5,6 +5,7 @@ import ( "errors" "flag" "log" + "math" "net" "os" "runtime" @@ -42,12 +43,13 @@ func main() { var bindPort = flag.Int("port", 53, "port the DNS server should bind to") var quiet = flag.Bool("quiet", false, "suppresses logging of each DNS response. Use this to avoid Google Cloud charging you $30/month to retain the logs of your GKE-based sslip.io server") var public = flag.Bool("public", true, "allows resolution of public IP addresses. If false, only resolves private IPs including localhost (127/8, ::1), link-local (169.254/16, fe80::/10), CG-NAT (100.64/12), private (10/8, 172.16/12, 192.168/16, fc/7). Set to false if you don't want miscreants impersonating you via public IPs. If unsure, set to false") + var maxQueriesPerSec = flag.Int("max_queries_per_sec", math.MaxInt32, "maximum queries per second. This limit, in queries/second, is measured since the server was started. When the limit is reached, the server stops replying until throughput drops below the limit. Use this if AWS is gouging you for bandwidth. 300 qps is close to 100 GB / month") flag.Parse() log.Printf("%s version %s starting", os.Args[0], xip.VersionSemantic) log.Printf("blocklist URL: %s, name servers: %s, bind port: %d, quiet: %t", *blocklistURL, *nameservers, *bindPort, *quiet) - x, logmessages := xip.NewXip(*blocklistURL, strings.Split(*nameservers, ","), strings.Split(*addresses, ","), strings.Split(*delegates, ",")) + x, logmessages := xip.NewXip(*blocklistURL, strings.Split(*nameservers, ","), strings.Split(*addresses, ","), strings.Split(*delegates, ","), *maxQueriesPerSec) x.Public = *public for _, logmessage := range logmessages { log.Println(logmessage) diff --git a/xip/xip.go b/xip/xip.go index c90550e..d0a2492 100644 --- a/xip/xip.go +++ b/xip/xip.go @@ -32,6 +32,7 @@ type Xip struct { BlocklistUpdated time.Time // The most recent time the Blocklist was updated NameServers []dnsmessage.NSResource // The list of authoritative name servers (NS) Public bool // Whether to resolve public IPs; set to false if security-conscious + MaxQueriesPerSecond int // Max Queries / Second } // Metrics contains the counters of the important/interesting queries @@ -177,8 +178,8 @@ type Response struct { } // NewXip follows convention for constructors: https://go.dev/doc/effective_go#allocation_new -func NewXip(blocklistURL string, nameservers []string, addresses []string, delegates []string) (x *Xip, logmessages []string) { - x = &Xip{Metrics: Metrics{Start: time.Now()}} +func NewXip(blocklistURL string, nameservers []string, addresses []string, delegates []string, maxQueriesPerSec int) (x *Xip, logmessages []string) { + x = &Xip{Metrics: Metrics{Start: time.Now()}, MaxQueriesPerSecond: maxQueriesPerSec} // Download the blocklist logmessages = append(logmessages, x.downloadBlockList(blocklistURL)) @@ -336,6 +337,12 @@ func (x *Xip) QueryResponse(queryBytes []byte, srcAddr net.IP) (responseBytes [] var p dnsmessage.Parser var response Response + // Have we exceeded our throttle? Don't reply, but return an error + if float64(x.Metrics.Queries)/time.Since(x.Metrics.Start).Seconds() > float64(x.MaxQueriesPerSecond) { + return nil, "", fmt.Errorf( + "429 Too Many Requests: %0.2f queries per second exceeds %d queries per second limit", + float64(x.Metrics.Queries)/time.Since(x.Metrics.Start).Seconds(), x.MaxQueriesPerSecond) + } if queryHeader, err = p.Start(queryBytes); err != nil { return nil, "", err } diff --git a/xip/xip_test.go b/xip/xip_test.go index b4280aa..1c3d973 100644 --- a/xip/xip_test.go +++ b/xip/xip_test.go @@ -1,9 +1,11 @@ package xip_test import ( + "math" "math/rand" "net" "strings" + "time" "xip/testhelper" "xip/xip" @@ -79,7 +81,7 @@ var _ = Describe("Xip", func() { Describe("NSResources()", func() { When("we use the default nameservers", func() { - var x, _ = xip.NewXip("file:///", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{}) + var x, _ = xip.NewXip("file:///", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{}, math.MaxInt32) It("returns the name servers", func() { randomDomain := testhelper.Random8ByteString() + ".com." ns := x.NSResources(randomDomain) @@ -113,13 +115,13 @@ var _ = Describe("Xip", func() { When("we delegate domains to other nameservers", func() { When(`we don't use the "=" in the arguments`, func() { It("returns an informative log message", func() { - var _, logs = xip.NewXip("file://etc/blocklist-test.txt", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{"noEquals"}) + var _, logs = xip.NewXip("file://etc/blocklist-test.txt", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{"noEquals"}, math.MaxInt32) Expect(strings.Join(logs, "")).To(MatchRegexp(`"-delegates: arguments should be in the format "delegatedDomain=nameserver", not "noEquals"`)) }) }) When(`there's no "." at the end of the delegated domain or nameserver`, func() { It(`helpfully adds the "."`, func() { - var x, logs = xip.NewXip("file://etc/blocklist-test.txt", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{"a=b"}) + var x, logs = xip.NewXip("file://etc/blocklist-test.txt", []string{"ns-aws.sslip.io.", "ns-azure.sslip.io.", "ns-gce.sslip.io.", "ns-ovh.sslip.io."}, []string{}, []string{"a=b"}, math.MaxInt32) Expect(strings.Join(logs, "")).To(MatchRegexp(`Adding delegated NS record "a\.=b\."`)) ns := x.NSResources("a.") Expect(len(ns)).To(Equal(1)) @@ -128,7 +130,7 @@ var _ = Describe("Xip", func() { }) }) When("we override the default nameservers", func() { - var x, _ = xip.NewXip("file:///", []string{"mickey", "minn.ie.", "goo.fy"}, []string{}, []string{}) + var x, _ = xip.NewXip("file:///", []string{"mickey", "minn.ie.", "goo.fy"}, []string{}, []string{}, math.MaxInt32) It("returns the configured servers", func() { randomDomain := testhelper.Random8ByteString() + ".com." ns := x.NSResources(randomDomain) @@ -467,4 +469,40 @@ var _ = Describe("Xip", func() { Entry("Private internets", net.ParseIP("fc00::"), false), ) }) + + Describe("QueryResponse()", func() { + // sample query: the AAAA (IPv6) record of localhost (::1) + msg := dnsmessage.Message{ + Header: dnsmessage.Header{ + ID: 1234, // Choose a random ID + RecursionDesired: true, + }, + Questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("::1."), // Note the trailing dot + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + } + // Pack the message into a byte slice + packedMessage, err := msg.Pack() + Expect(err).ToNot(HaveOccurred()) + loopbackIP := net.ParseIP("127.0.0.1") // the querier's IP is localhost + + When("the response has been throttled (`-max-queries-per-sec` is set)", func() { + It("returns an error, not a response", func() { + x, _ := xip.NewXip("", []string{}, []string{}, []string{}, 1) + Expect(err).ToNot(HaveOccurred()) + _, _, err = x.QueryResponse(packedMessage, loopbackIP) // first query + Expect(err).ToNot(HaveOccurred()) + time.Sleep(1000 * time.Millisecond) // sleep 1 second to stay under the limit + _, _, err = x.QueryResponse(packedMessage, loopbackIP) // second query + Expect(err).ToNot(HaveOccurred()) // should succeed + _, _, err = x.QueryResponse(packedMessage, loopbackIP) // third query + Expect(err).To(HaveOccurred()) // should fail, over the limit + Expect(err.Error()).To(MatchRegexp(`429 Too Many Requests: .* queries per second exceeds 1 queries per second limit`)) + }) + }) + }) })