diff --git a/cmd/api/api.go b/cmd/api/api.go index cb4f1215..ccd88530 100644 --- a/cmd/api/api.go +++ b/cmd/api/api.go @@ -8,6 +8,7 @@ import ( "net/http" "os" "strconv" + "strings" "sync" ) @@ -15,6 +16,8 @@ func Init() { var cfg struct { Mod struct { Listen string `yaml:"listen"` + Username string `yaml:"username"` + Password string `yaml:"password"` BasePath string `yaml:"base_path"` StaticDir string `yaml:"static_dir"` Origin string `yaml:"origin"` @@ -52,14 +55,18 @@ func Init() { log.Info().Str("addr", cfg.Mod.Listen).Msg("[api] listen") s := http.Server{} - s.Handler = http.DefaultServeMux - - if log.Trace().Enabled() { - s.Handler = middlewareLog(s.Handler) - } + s.Handler = http.DefaultServeMux // 4th if cfg.Mod.Origin == "*" { - s.Handler = middlewareCORS(s.Handler) + s.Handler = middlewareCORS(s.Handler) // 3rd + } + + if cfg.Mod.Username != "" { + s.Handler = middlewareAuth(cfg.Mod.Username, cfg.Mod.Password, s.Handler) // 2nd + } + + if log.Trace().Enabled() { + s.Handler = middlewareLog(s.Handler) // 1st } go func() { @@ -87,7 +94,22 @@ var log zerolog.Logger func middlewareLog(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - log.Trace().Msgf("[api] %s %s", r.Method, r.URL) + log.Trace().Msgf("[api] %s %s %s", r.Method, r.URL, r.RemoteAddr) + next.ServeHTTP(w, r) + }) +} + +func middlewareAuth(username, password string, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.RemoteAddr, "127.") && !strings.HasPrefix(r.RemoteAddr, "[::1]") { + user, pass, ok := r.BasicAuth() + if !ok || user != username || pass != password { + w.Header().Set("Www-Authenticate", `Basic realm="go2rtc"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + } + next.ServeHTTP(w, r) }) }