diff --git a/main.go b/main.go index 763ca29..a8f1990 100644 --- a/main.go +++ b/main.go @@ -28,7 +28,15 @@ func main() { quiet := flag.Bool("quiet", false, "do not log informational messages (takes precedence over debug)") json := flag.Bool("json", false, "enable JSON logging") test := flag.Bool("test", false, "test crontab (does not run jobs)") - prometheusListen := flag.String("prometheus-listen-address", "", "give a valid ip:port address to expose Prometheus metrics at /metrics") + prometheusListen := flag.String( + "prometheus-listen-address", + "", + fmt.Sprintf( + "give a valid ip[:port] address to expose Prometheus metrics at /metrics (port defaults to %s), "+ + "use 0.0.0.0 for all network interfaces.", + prometheus_metrics.DefaultPort, + ), + ) splitLogs := flag.Bool("split-logs", false, "split log output into stdout/stderr") passthroughLogs := flag.Bool("passthrough-logs", false, "passthrough logs from commands, do not wrap them in Supercronic logging") sentry := flag.String("sentry-dsn", "", "enable Sentry error logging, using provided DSN") diff --git a/prometheus_metrics/prommetrics.go b/prometheus_metrics/prommetrics.go index 7ee91ae..4a5ff38 100644 --- a/prometheus_metrics/prommetrics.go +++ b/prometheus_metrics/prommetrics.go @@ -2,6 +2,7 @@ package prometheus_metrics import ( "context" + "fmt" "net" "net/http" @@ -11,7 +12,8 @@ import ( ) const ( - namespace = "supercronic" + DefaultPort = "9746" + namespace = "supercronic" ) func genMetricName(name string) string { @@ -99,7 +101,40 @@ func (p *PrometheusMetrics) Reset() { p.CronsExecutionTimeHistogram.Reset() } +func getAddr(listenAddr string) (string, error) { + if listenAddr == "" { + return "", fmt.Errorf("Not address provided") + } + + // If the address is fine as-is, use it + _, _, err1 := net.SplitHostPort(listenAddr) + if err1 == nil { + return listenAddr, nil + } + + // Otherwise, try to add the port + listenAddrWithPort := net.JoinHostPort(listenAddr, DefaultPort) + _, _, err2 := net.SplitHostPort(listenAddrWithPort) + if err2 == nil { + return listenAddrWithPort, nil + } + + return "", fmt.Errorf( + "%s is not a valid address (%v), and neither is %s after adding the default port (%v)", + listenAddr, + err1, + listenAddrWithPort, + err2, + ) + +} + func InitHTTPServer(listenAddr string, shutdownContext context.Context) (func() error, error) { + addr, err := getAddr(listenAddr) + if err != nil { + return nil, err + } + promSrv := &http.Server{} http.Handle("/metrics", promhttp.Handler()) @@ -118,7 +153,7 @@ func InitHTTPServer(listenAddr string, shutdownContext context.Context) (func() return promSrv.Shutdown(shutdownContext) } - listener, err := net.Listen("tcp", listenAddr) + listener, err := net.Listen("tcp", addr) if err != nil { return shutdownClosure, err } diff --git a/prometheus_metrics/prommetrics_test.go b/prometheus_metrics/prommetrics_test.go new file mode 100644 index 0000000..7bbdb59 --- /dev/null +++ b/prometheus_metrics/prommetrics_test.go @@ -0,0 +1,45 @@ +package prometheus_metrics + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetAddr(t *testing.T) { + addr, err := getAddr("127.0.0.1:123") + if assert.Nil(t, err) { + assert.Equal(t, "127.0.0.1:123", addr) + } + + addr, err = getAddr("127.0.0.1") + if assert.Nil(t, err) { + assert.Equal(t, "127.0.0.1:9746", addr) + } + + addr, err = getAddr("[127.0.0.1]") + if assert.Nil(t, err) { + assert.Equal(t, "[127.0.0.1]:9746", addr) + } + + addr, err = getAddr("[::]:123") + if assert.Nil(t, err) { + assert.Equal(t, "[::]:123", addr) + } + + addr, err = getAddr("::") + if assert.Nil(t, err) { + assert.Equal(t, "[::]:9746", addr) + } + + addr, err = getAddr("0.0.0.0") + if assert.Nil(t, err) { + assert.Equal(t, "0.0.0.0:9746", addr) + } + + _, err = getAddr("") + assert.NotNil(t, err) + + _, err = getAddr("[::]") + assert.NotNil(t, err) +}