Feat: Add JWKS rotation API endpoint (#4463)

Co-authored-by: aler9 <46489434+aler9@users.noreply.github.com>
This commit is contained in:
Dan Nicholls
2025-05-10 21:44:02 +10:00
committed by GitHub
parent defee1eed9
commit 7360981aa7
6 changed files with 191 additions and 25 deletions

View File

@@ -1026,6 +1026,22 @@ components:
$ref: '#/components/schemas/WebRTCSession' $ref: '#/components/schemas/WebRTCSession'
paths: paths:
/v3/auth/jwks/refresh:
post:
operationId: authJwksRefresh
tags: [Authentication]
summary: Manually refreshes the JWT JWKS.
responses:
'200':
description: the request was successful.
'500':
description: server error.
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
/v3/config/global/get: /v3/config/global/get:
get: get:
operationId: configGlobalGet operationId: configGlobalGet

View File

@@ -79,6 +79,7 @@ func recordingsOfPath(
type apiAuthManager interface { type apiAuthManager interface {
Authenticate(req *auth.Request) error Authenticate(req *auth.Request) error
RefreshJWTJWKS()
} }
type apiParent interface { type apiParent interface {
@@ -121,6 +122,8 @@ func (a *API) Initialize() error {
group := router.Group("/v3") group := router.Group("/v3")
group.POST("/auth/jwks/refresh", a.onAuthJwksRefresh)
group.GET("/config/global/get", a.onConfigGlobalGet) group.GET("/config/global/get", a.onConfigGlobalGet)
group.PATCH("/config/global/patch", a.onConfigGlobalPatch) group.PATCH("/config/global/patch", a.onConfigGlobalPatch)
@@ -536,6 +539,11 @@ func (a *API) onConfigPathsDelete(ctx *gin.Context) {
ctx.Status(http.StatusOK) ctx.Status(http.StatusOK)
} }
func (a *API) onAuthJwksRefresh(ctx *gin.Context) {
a.AuthManager.RefreshJWTJWKS()
ctx.Status(http.StatusOK)
}
func (a *API) onPathsList(ctx *gin.Context) { func (a *API) onPathsList(ctx *gin.Context) {
data, err := a.PathManager.APIPathsList() data, err := a.PathManager.APIPathsList()
if err != nil { if err != nil {

View File

@@ -11,6 +11,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/bluenviron/mediamtx/internal/auth"
"github.com/bluenviron/mediamtx/internal/conf" "github.com/bluenviron/mediamtx/internal/conf"
"github.com/bluenviron/mediamtx/internal/logger" "github.com/bluenviron/mediamtx/internal/logger"
"github.com/bluenviron/mediamtx/internal/test" "github.com/bluenviron/mediamtx/internal/test"
@@ -718,3 +719,41 @@ func TestRecordingsDeleteSegment(t *testing.T) {
defer res.Body.Close() defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode) require.Equal(t, http.StatusOK, res.StatusCode)
} }
func TestAuthJWKSRefresh(t *testing.T) {
ok := false
api := API{
Address: "localhost:9997",
ReadTimeout: conf.Duration(10 * time.Second),
AuthManager: &test.AuthManager{
AuthenticateImpl: func(_ *auth.Request) error {
return nil
},
RefreshJWTJWKSImpl: func() {
ok = true
},
},
Parent: &testParent{},
}
err := api.Initialize()
require.NoError(t, err)
defer api.Close()
tr := &http.Transport{}
defer tr.CloseIdleConnections()
hc := &http.Client{Transport: tr}
u, err := url.Parse("http://localhost:9997/v3/auth/jwks/refresh")
require.NoError(t, err)
req, err := http.NewRequest(http.MethodPost, u.String(), nil)
require.NoError(t, err)
res, err := hc.Do(req)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusOK, res.StatusCode)
require.True(t, ok)
}

View File

@@ -267,3 +267,11 @@ func (m *Manager) pullJWTJWKS() (jwt.Keyfunc, error) {
return m.jwtKeyFunc.Keyfunc, nil return m.jwtKeyFunc.Keyfunc, nil
} }
// RefreshJWTJWKS refreshes the JWT JWKS.
func (m *Manager) RefreshJWTJWKS() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.jwtLastRefresh = time.Time{}
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand" "crypto/rand"
"crypto/rsa" "crypto/rsa"
"encoding/json" "encoding/json"
"fmt"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@@ -326,19 +327,19 @@ func TestAuthJWT(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err) require.NoError(t, err)
jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{
Metadata: jwkset.JWKMetadataOptions{
KID: "test-key-id",
},
})
require.NoError(t, err)
jwkSet := jwkset.NewMemoryStorage()
err = jwkSet.KeyWrite(context.Background(), jwk)
require.NoError(t, err)
httpServ := &http.Server{ httpServ := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwk, err2 := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{
Metadata: jwkset.JWKMetadataOptions{
KID: "test-key-id",
},
})
require.NoError(t, err2)
jwkSet := jwkset.NewMemoryStorage()
err2 = jwkSet.KeyWrite(context.Background(), jwk)
require.NoError(t, err2)
response, err2 := jwkSet.JSONPublic(r.Context()) response, err2 := jwkSet.JSONPublic(r.Context())
if err2 != nil { if err2 != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
@@ -420,19 +421,19 @@ func TestAuthJWTAsString(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 1024) key, err := rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err) require.NoError(t, err)
jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{
Metadata: jwkset.JWKMetadataOptions{
KID: "test-key-id",
},
})
require.NoError(t, err)
jwkSet := jwkset.NewMemoryStorage()
err = jwkSet.KeyWrite(context.Background(), jwk)
require.NoError(t, err)
httpServ := &http.Server{ httpServ := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
jwk, err2 := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{
Metadata: jwkset.JWKMetadataOptions{
KID: "test-key-id",
},
})
require.NoError(t, err2)
jwkSet := jwkset.NewMemoryStorage()
err2 = jwkSet.KeyWrite(context.Background(), jwk)
require.NoError(t, err2)
response, err2 := jwkSet.JSONPublic(r.Context()) response, err2 := jwkSet.JSONPublic(r.Context())
if err2 != nil { if err2 != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
@@ -515,3 +516,89 @@ func TestAuthJWTExclude(t *testing.T) {
}) })
require.NoError(t, err) require.NoError(t, err)
} }
func TestAuthJWTRefresh(t *testing.T) {
// taken from
// https://github.com/MicahParks/jwkset/blob/master/examples/http_server/main.go
var key *rsa.PrivateKey
httpServ := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Println("AA")
jwk, err := jwkset.NewJWKFromKey(key, jwkset.JWKOptions{
Metadata: jwkset.JWKMetadataOptions{
KID: "test-key-id",
},
})
require.NoError(t, err)
jwkSet := jwkset.NewMemoryStorage()
err = jwkSet.KeyWrite(context.Background(), jwk)
require.NoError(t, err)
response, err2 := jwkSet.JSONPublic(r.Context())
if err2 != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write(response)
}),
}
ln, err := net.Listen("tcp", "localhost:4567")
require.NoError(t, err)
go httpServ.Serve(ln)
defer httpServ.Shutdown(context.Background())
m := Manager{
Method: conf.AuthMethodJWT,
JWTJWKS: "http://localhost:4567/jwks",
JWTClaimKey: "my_permission_key",
}
for i := 0; i < 2; i++ {
key, err = rsa.GenerateKey(rand.Reader, 1024)
require.NoError(t, err)
type customClaims struct {
jwt.RegisteredClaims
MediaMTXPermissions []conf.AuthInternalUserPermission `json:"my_permission_key"`
}
claims := customClaims{
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
IssuedAt: jwt.NewNumericDate(time.Now()),
NotBefore: jwt.NewNumericDate(time.Now()),
Issuer: "test",
Subject: "somebody",
ID: "1",
},
MediaMTXPermissions: []conf.AuthInternalUserPermission{{
Action: conf.AuthActionPublish,
Path: "mypath",
}},
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header[jwkset.HeaderKID] = "test-key-id"
ss, err := token.SignedString(key)
require.NoError(t, err)
err = m.Authenticate(&Request{
IP: net.ParseIP("127.0.0.1"),
Action: conf.AuthActionPublish,
Path: "mypath",
Protocol: ProtocolRTSP,
Query: "param=value&jwt=" + ss,
})
require.NoError(t, err)
m.RefreshJWTJWKS()
}
}

View File

@@ -4,17 +4,25 @@ import "github.com/bluenviron/mediamtx/internal/auth"
// AuthManager is a dummy auth manager. // AuthManager is a dummy auth manager.
type AuthManager struct { type AuthManager struct {
fnc func(req *auth.Request) error AuthenticateImpl func(req *auth.Request) error
RefreshJWTJWKSImpl func()
} }
// Authenticate replicates auth.Manager.Replicate // Authenticate replicates auth.Manager.Replicate
func (m *AuthManager) Authenticate(req *auth.Request) error { func (m *AuthManager) Authenticate(req *auth.Request) error {
return m.fnc(req) return m.AuthenticateImpl(req)
}
// RefreshJWTJWKS is a function that simulates a JWKS refresh.
func (m *AuthManager) RefreshJWTJWKS() {
m.RefreshJWTJWKSImpl()
} }
// NilAuthManager is an auth manager that accepts everything. // NilAuthManager is an auth manager that accepts everything.
var NilAuthManager = &AuthManager{ var NilAuthManager = &AuthManager{
fnc: func(_ *auth.Request) error { AuthenticateImpl: func(_ *auth.Request) error {
return nil return nil
}, },
RefreshJWTJWKSImpl: func() {
},
} }