diff --git a/go.mod b/go.mod index f6e2d621..e53f73b5 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21 require ( code.cloudfoundry.org/bytefmt v0.0.0 + github.com/MicahParks/jwkset v0.5.17 github.com/MicahParks/keyfunc/v3 v3.3.2 github.com/abema/go-mp4 v1.2.0 github.com/alecthomas/kong v0.9.0 @@ -34,7 +35,6 @@ require ( ) require ( - github.com/MicahParks/jwkset v0.5.17 // indirect github.com/asticode/go-astikit v0.30.0 // indirect github.com/asticode/go-astits v1.13.0 // indirect github.com/benburkert/openpgp v0.0.0-20160410205803-c2471f86866c // indirect diff --git a/internal/api/api_test.go b/internal/api/api_test.go index 6e298d50..3ca9fe1e 100644 --- a/internal/api/api_test.go +++ b/internal/api/api_test.go @@ -181,7 +181,7 @@ func TestConfigGlobalPatch(t *testing.T) { require.Equal(t, float64(4096), out["readBufferCount"]) } -func TestAPIConfigGlobalPatchUnknownField(t *testing.T) { //nolint:dupl +func TestConfigGlobalPatchUnknownField(t *testing.T) { //nolint:dupl cnf := tempConf(t, "api: yes\n") api := API{ @@ -218,7 +218,7 @@ func TestAPIConfigGlobalPatchUnknownField(t *testing.T) { //nolint:dupl checkError(t, "json: unknown field \"test\"", res.Body) } -func TestAPIConfigPathDefaultsGet(t *testing.T) { +func TestConfigPathDefaultsGet(t *testing.T) { cnf := tempConf(t, "api: yes\n") api := API{ @@ -241,7 +241,7 @@ func TestAPIConfigPathDefaultsGet(t *testing.T) { require.Equal(t, "publisher", out["source"]) } -func TestAPIConfigPathDefaultsPatch(t *testing.T) { +func TestConfigPathDefaultsPatch(t *testing.T) { cnf := tempConf(t, "api: yes\n") api := API{ @@ -273,7 +273,7 @@ func TestAPIConfigPathDefaultsPatch(t *testing.T) { require.Equal(t, "mypass", out["readPass"]) } -func TestAPIConfigPathsList(t *testing.T) { +func TestConfigPathsList(t *testing.T) { cnf := tempConf(t, "api: yes\n"+ "paths:\n"+ " path1:\n"+ @@ -318,7 +318,7 @@ func TestAPIConfigPathsList(t *testing.T) { require.Equal(t, "mypass2", out.Items[1]["readPass"]) } -func TestAPIConfigPathsGet(t *testing.T) { +func TestConfigPathsGet(t *testing.T) { cnf := tempConf(t, "api: yes\n"+ "paths:\n"+ " my/path:\n"+ @@ -346,7 +346,7 @@ func TestAPIConfigPathsGet(t *testing.T) { require.Equal(t, "myuser", out["readUser"]) } -func TestAPIConfigPathsAdd(t *testing.T) { +func TestConfigPathsAdd(t *testing.T) { cnf := tempConf(t, "api: yes\n") api := API{ @@ -380,7 +380,7 @@ func TestAPIConfigPathsAdd(t *testing.T) { require.Equal(t, true, out["rpiCameraVFlip"]) } -func TestAPIConfigPathsAddUnknownField(t *testing.T) { //nolint:dupl +func TestConfigPathsAddUnknownField(t *testing.T) { //nolint:dupl cnf := tempConf(t, "api: yes\n") api := API{ @@ -417,7 +417,7 @@ func TestAPIConfigPathsAddUnknownField(t *testing.T) { //nolint:dupl checkError(t, "json: unknown field \"test\"", res.Body) } -func TestAPIConfigPathsPatch(t *testing.T) { //nolint:dupl +func TestConfigPathsPatch(t *testing.T) { //nolint:dupl cnf := tempConf(t, "api: yes\n") api := API{ @@ -457,7 +457,7 @@ func TestAPIConfigPathsPatch(t *testing.T) { //nolint:dupl require.Equal(t, true, out["rpiCameraVFlip"]) } -func TestAPIConfigPathsReplace(t *testing.T) { //nolint:dupl +func TestConfigPathsReplace(t *testing.T) { //nolint:dupl cnf := tempConf(t, "api: yes\n") api := API{ @@ -497,7 +497,7 @@ func TestAPIConfigPathsReplace(t *testing.T) { //nolint:dupl require.Equal(t, false, out["rpiCameraVFlip"]) } -func TestAPIConfigPathsDelete(t *testing.T) { +func TestConfigPathsDelete(t *testing.T) { cnf := tempConf(t, "api: yes\n") api := API{ diff --git a/internal/auth/manager_test.go b/internal/auth/manager_test.go index c97a1a58..14498cb5 100644 --- a/internal/auth/manager_test.go +++ b/internal/auth/manager_test.go @@ -2,15 +2,20 @@ package auth import ( "context" + "crypto/rand" + "crypto/rsa" "encoding/json" "net" "net/http" "testing" + "time" + "github.com/MicahParks/jwkset" "github.com/bluenviron/gortsplib/v4/pkg/auth" "github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/mediamtx/internal/conf" + "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/require" ) @@ -25,66 +30,6 @@ func mustParseCIDR(v string) net.IPNet { return *ne } -type testHTTPAuthenticator struct { - *http.Server -} - -func (ts *testHTTPAuthenticator) initialize(t *testing.T, protocol string, action string) { - firstReceived := false - - ts.Server = &http.Server{ - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - require.Equal(t, http.MethodPost, r.Method) - require.Equal(t, "/auth", r.URL.Path) - - var in struct { - IP string `json:"ip"` - User string `json:"user"` - Password string `json:"password"` - Path string `json:"path"` - Protocol string `json:"protocol"` - ID string `json:"id"` - Action string `json:"action"` - Query string `json:"query"` - } - err := json.NewDecoder(r.Body).Decode(&in) - require.NoError(t, err) - - var user string - if action == "publish" { - user = "testpublisher" - } else { - user = "testreader" - } - - if in.IP != "127.0.0.1" || - in.User != user || - in.Password != "testpass" || - in.Path != "teststream" || - in.Protocol != protocol || - (firstReceived && in.ID == "") || - in.Action != action || - (in.Query != "user=testreader&pass=testpass¶m=value" && - in.Query != "user=testpublisher&pass=testpass¶m=value" && - in.Query != "param=value") { - w.WriteHeader(http.StatusBadRequest) - return - } - - firstReceived = true - }), - } - - ln, err := net.Listen("tcp", "127.0.0.1:9120") - require.NoError(t, err) - - go ts.Server.Serve(ln) -} - -func (ts *testHTTPAuthenticator) close() { - ts.Server.Shutdown(context.Background()) -} - func TestAuthInternal(t *testing.T) { for _, outcome := range []string{ "ok", @@ -105,12 +50,10 @@ func TestAuthInternal(t *testing.T) { InternalUsers: []conf.AuthInternalUser{ { IPs: conf.IPNetworks{mustParseCIDR("127.1.1.1/32")}, - Permissions: []conf.AuthInternalUserPermission{ - { - Action: conf.AuthActionPublish, - Path: "mypath", - }, - }, + Permissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, }, }, HTTPAddress: "", @@ -207,12 +150,10 @@ func TestAuthInternalRTSPDigest(t *testing.T) { User: "myuser", Pass: "mypass", IPs: conf.IPNetworks{mustParseCIDR("127.1.1.1/32")}, - Permissions: []conf.AuthInternalUserPermission{ - { - Action: conf.AuthActionPublish, - Path: "mypath", - }, - }, + Permissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, }, }, HTTPAddress: "", @@ -249,16 +190,56 @@ func TestAuthInternalRTSPDigest(t *testing.T) { func TestAuthHTTP(t *testing.T) { for _, outcome := range []string{"ok", "fail"} { t.Run(outcome, func(t *testing.T) { + firstReceived := false + + httpServ := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/auth", r.URL.Path) + + var in struct { + IP string `json:"ip"` + User string `json:"user"` + Password string `json:"password"` + Path string `json:"path"` + Protocol string `json:"protocol"` + ID string `json:"id"` + Action string `json:"action"` + Query string `json:"query"` + } + err := json.NewDecoder(r.Body).Decode(&in) + require.NoError(t, err) + + if in.IP != "127.0.0.1" || + in.User != "testpublisher" || + in.Password != "testpass" || + in.Path != "teststream" || + in.Protocol != "rtsp" || + (firstReceived && in.ID == "") || + in.Action != "publish" || + (in.Query != "user=testreader&pass=testpass¶m=value" && + in.Query != "user=testpublisher&pass=testpass¶m=value" && + in.Query != "param=value") { + w.WriteHeader(http.StatusBadRequest) + return + } + + firstReceived = true + }), + } + + ln, err := net.Listen("tcp", "127.0.0.1:9120") + require.NoError(t, err) + + go httpServ.Serve(ln) + defer httpServ.Shutdown(context.Background()) + m := Manager{ Method: conf.AuthMethodHTTP, HTTPAddress: "http://127.0.0.1:9120/auth", RTSPAuthMethods: nil, } - au := &testHTTPAuthenticator{} - au.initialize(t, "rtsp", "publish") - defer au.close() - if outcome == "ok" { err := m.Authenticate(&Request{ User: "testpublisher", @@ -307,3 +288,82 @@ func TestAuthHTTPExclude(t *testing.T) { }) require.NoError(t, err) } + +func TestAuthJWT(t *testing.T) { + // taken from + // https://github.com/MicahParks/jwkset/blob/master/examples/http_server/main.go + + key, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + + jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{ + Metadata: jwkset.JWKMetadataOptions{ + KID: "test-key-id", + }, + }) + require.NoError(t, err) + + jwkSet := jwkset.NewMemoryStorage() + err = jwkSet.KeyWrite(context.Background(), jwk) + require.NoError(t, err) + + httpServ := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + response, err := jwkSet.JSONPublic(r.Context()) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(response) + }), + } + + ln, err := net.Listen("tcp", "localhost:4567") + require.NoError(t, err) + + go httpServ.Serve(ln) + defer httpServ.Shutdown(context.Background()) + + type customClaims struct { + jwt.RegisteredClaims + MediaMTXPermissions []conf.AuthInternalUserPermission `json:"mediamtx_permissions"` + } + + claims := customClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)), + IssuedAt: jwt.NewNumericDate(time.Now()), + NotBefore: jwt.NewNumericDate(time.Now()), + Issuer: "test", + Subject: "somebody", + ID: "1", + }, + MediaMTXPermissions: []conf.AuthInternalUserPermission{{ + Action: conf.AuthActionPublish, + Path: "mypath", + }}, + } + + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token.Header[jwkset.HeaderKID] = "test-key-id" + ss, err := token.SignedString(key) + require.NoError(t, err) + + m := Manager{ + Method: conf.AuthMethodJWT, + JWTJWKS: "http://localhost:4567/jwks", + } + + err = m.Authenticate(&Request{ + User: "", + Pass: "", + IP: net.ParseIP("127.0.0.1"), + Action: conf.AuthActionPublish, + Path: "mypath", + Protocol: ProtocolRTSP, + Query: "param=value&jwt=" + ss, + }) + require.NoError(t, err) +} diff --git a/internal/staticsources/hls/source_test.go b/internal/staticsources/hls/source_test.go index 429e821a..b4d9faef 100644 --- a/internal/staticsources/hls/source_test.go +++ b/internal/staticsources/hls/source_test.go @@ -79,10 +79,11 @@ func TestSource(t *testing.T) { require.NoError(t, err) }) + s := &http.Server{Handler: router} + ln, err := net.Listen("tcp", "localhost:5780") require.NoError(t, err) - s := &http.Server{Handler: router} go s.Serve(ln) defer s.Shutdown(context.Background())