mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 21:01:58 +08:00
CLI: Added JWT issuance and diagnostics sub commands #5230
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -50,7 +50,6 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
verifier := clusterjwt.NewVerifier(conf)
|
|
||||||
requiredScopes := []string{"cluster"}
|
requiredScopes := []string{"cluster"}
|
||||||
if resource == acl.ResourceVision {
|
if resource == acl.ResourceVision {
|
||||||
requiredScopes = []string{"vision"}
|
requiredScopes = []string{"vision"}
|
||||||
@@ -77,7 +76,7 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
|
|||||||
|
|
||||||
for _, issuer := range issuers {
|
for _, issuer := range issuers {
|
||||||
expected.Issuer = issuer
|
expected.Issuer = issuer
|
||||||
claims, err = verifier.VerifyToken(ctx, authToken, expected)
|
claims, err = get.VerifyJWT(ctx, authToken, expected)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@@ -25,12 +25,12 @@ func TestEcho(t *testing.T) {
|
|||||||
t.Logf("Response Body: %s", r.Body.String())
|
t.Logf("Response Body: %s", r.Body.String())
|
||||||
|
|
||||||
body := r.Body.String()
|
body := r.Body.String()
|
||||||
url := gjson.Get(body, "url").String()
|
bodyUrl := gjson.Get(body, "url").String()
|
||||||
method := gjson.Get(body, "method").String()
|
method := gjson.Get(body, "method").String()
|
||||||
request := gjson.Get(body, "headers.request")
|
request := gjson.Get(body, "headers.request")
|
||||||
response := gjson.Get(body, "headers.response")
|
response := gjson.Get(body, "headers.response")
|
||||||
|
|
||||||
assert.Equal(t, "/api/v1/echo", url)
|
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||||
assert.Equal(t, "GET", method)
|
assert.Equal(t, "GET", method)
|
||||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||||
@@ -49,12 +49,12 @@ func TestEcho(t *testing.T) {
|
|||||||
r := AuthenticatedRequest(app, http.MethodPost, "/api/v1/echo", authToken)
|
r := AuthenticatedRequest(app, http.MethodPost, "/api/v1/echo", authToken)
|
||||||
|
|
||||||
body := r.Body.String()
|
body := r.Body.String()
|
||||||
url := gjson.Get(body, "url").String()
|
bodyUrl := gjson.Get(body, "url").String()
|
||||||
method := gjson.Get(body, "method").String()
|
method := gjson.Get(body, "method").String()
|
||||||
request := gjson.Get(body, "headers.request")
|
request := gjson.Get(body, "headers.request")
|
||||||
response := gjson.Get(body, "headers.response")
|
response := gjson.Get(body, "headers.response")
|
||||||
|
|
||||||
assert.Equal(t, "/api/v1/echo", url)
|
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||||
assert.Equal(t, "POST", method)
|
assert.Equal(t, "POST", method)
|
||||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||||
|
@@ -25,6 +25,20 @@ var (
|
|||||||
errKeyNotFound = errors.New("jwt: key not found")
|
errKeyNotFound = errors.New("jwt: key not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// VerifierStatus captures diagnostic information about a verifier's JWKS cache state.
|
||||||
|
type VerifierStatus struct {
|
||||||
|
CacheURL string `json:"cacheUrl,omitempty"`
|
||||||
|
CacheETag string `json:"cacheEtag,omitempty"`
|
||||||
|
KeyIDs []string `json:"keyIds,omitempty"`
|
||||||
|
KeyCount int `json:"keyCount"`
|
||||||
|
CacheFetchedAt time.Time `json:"cacheFetchedAt,omitempty"`
|
||||||
|
CacheAgeSeconds int64 `json:"cacheAgeSeconds"`
|
||||||
|
CacheTTLSeconds int `json:"cacheTtlSeconds"`
|
||||||
|
CacheStale bool `json:"cacheStale"`
|
||||||
|
CachePath string `json:"cachePath,omitempty"`
|
||||||
|
JWKSURL string `json:"jwksUrl,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
|
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
|
||||||
jwksFetchMaxRetries = 3
|
jwksFetchMaxRetries = 3
|
||||||
@@ -99,15 +113,14 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
|||||||
if strings.TrimSpace(expected.Audience) == "" {
|
if strings.TrimSpace(expected.Audience) == "" {
|
||||||
return nil, errors.New("jwt: expected audience required")
|
return nil, errors.New("jwt: expected audience required")
|
||||||
}
|
}
|
||||||
if len(expected.Scope) == 0 {
|
|
||||||
return nil, errors.New("jwt: expected scope required")
|
jwksUrl := strings.TrimSpace(expected.JWKSURL)
|
||||||
|
|
||||||
|
if jwksUrl == "" && v.conf != nil {
|
||||||
|
jwksUrl = strings.TrimSpace(v.conf.JWKSUrl())
|
||||||
}
|
}
|
||||||
|
|
||||||
url := strings.TrimSpace(expected.JWKSURL)
|
if jwksUrl == "" {
|
||||||
if url == "" && v.conf != nil {
|
|
||||||
url = strings.TrimSpace(v.conf.JWKSUrl())
|
|
||||||
}
|
|
||||||
if url == "" {
|
|
||||||
return nil, errors.New("jwt: jwks url not configured")
|
return nil, errors.New("jwt: jwks url not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,16 +139,111 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
|||||||
claims := &Claims{}
|
claims := &Claims{}
|
||||||
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||||
kid, _ := token.Header["kid"].(string)
|
kid, _ := token.Header["kid"].(string)
|
||||||
|
|
||||||
if kid == "" {
|
if kid == "" {
|
||||||
return nil, errors.New("jwt: missing kid header")
|
return nil, errors.New("jwt: missing kid header")
|
||||||
}
|
}
|
||||||
pk, err := v.publicKeyForKid(ctx, url, kid, false)
|
|
||||||
|
pk, err := v.publicKeyForKid(ctx, jwksUrl, kid, false)
|
||||||
|
|
||||||
if errors.Is(err, errKeyNotFound) {
|
if errors.Is(err, errKeyNotFound) {
|
||||||
pk, err = v.publicKeyForKid(ctx, url, kid, true)
|
pk, err = v.publicKeyForKid(ctx, jwksUrl, kid, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return pk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := parser.ParseWithClaims(tokenString, claims, keyFunc); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||||
|
return nil, errors.New("jwt: missing temporal claims")
|
||||||
|
}
|
||||||
|
|
||||||
|
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||||
|
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||||
|
}
|
||||||
|
|
||||||
|
scopeSet := map[string]struct{}{}
|
||||||
|
|
||||||
|
for _, s := range strings.Fields(claims.Scope) {
|
||||||
|
scopeSet[s] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, req := range expected.Scope {
|
||||||
|
if _, ok := scopeSet[req]; !ok {
|
||||||
|
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyTokenWithKeys verifies a token using the provided JWKS keys without performing HTTP fetches.
|
||||||
|
func VerifyTokenWithKeys(tokenString string, expected ExpectedClaims, keys []PublicJWK, leeway time.Duration) (*Claims, error) {
|
||||||
|
if strings.TrimSpace(tokenString) == "" {
|
||||||
|
return nil, errors.New("jwt: token is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil, errors.New("jwt: no jwks keys provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if leeway <= 0 {
|
||||||
|
leeway = 60 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
keyMap := make(map[string]ed25519.PublicKey, len(keys))
|
||||||
|
|
||||||
|
for _, jwk := range keys {
|
||||||
|
if jwk.Kid == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
raw, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(raw) != ed25519.PublicKeySize {
|
||||||
|
return nil, fmt.Errorf("jwt: invalid public key length %d", len(raw))
|
||||||
|
}
|
||||||
|
pk := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||||
|
copy(pk, raw)
|
||||||
|
keyMap[jwk.Kid] = pk
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(keyMap) == 0 {
|
||||||
|
return nil, errors.New("jwt: no valid jwks keys provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
options := []gojwt.ParserOption{
|
||||||
|
gojwt.WithLeeway(leeway),
|
||||||
|
gojwt.WithValidMethods([]string{gojwt.SigningMethodEdDSA.Alg()}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if iss := strings.TrimSpace(expected.Issuer); iss != "" {
|
||||||
|
options = append(options, gojwt.WithIssuer(iss))
|
||||||
|
}
|
||||||
|
|
||||||
|
if aud := strings.TrimSpace(expected.Audience); aud != "" {
|
||||||
|
options = append(options, gojwt.WithAudience(aud))
|
||||||
|
}
|
||||||
|
|
||||||
|
parser := gojwt.NewParser(options...)
|
||||||
|
claims := &Claims{}
|
||||||
|
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||||
|
kid, _ := token.Header["kid"].(string)
|
||||||
|
if kid == "" {
|
||||||
|
return nil, errors.New("jwt: missing kid header")
|
||||||
|
}
|
||||||
|
pk, ok := keyMap[kid]
|
||||||
|
if !ok {
|
||||||
|
return nil, errKeyNotFound
|
||||||
|
}
|
||||||
return pk, nil
|
return pk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,29 +254,70 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
|||||||
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||||
return nil, errors.New("jwt: missing temporal claims")
|
return nil, errors.New("jwt: missing temporal claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||||
return nil, errors.New("jwt: token ttl exceeds maximum")
|
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||||
}
|
}
|
||||||
|
|
||||||
scopeSet := map[string]struct{}{}
|
if len(expected.Scope) > 0 {
|
||||||
for _, s := range strings.Fields(claims.Scope) {
|
scopeSet := map[string]struct{}{}
|
||||||
scopeSet[s] = struct{}{}
|
for _, s := range strings.Fields(claims.Scope) {
|
||||||
}
|
scopeSet[s] = struct{}{}
|
||||||
for _, req := range expected.Scope {
|
}
|
||||||
if _, ok := scopeSet[req]; !ok {
|
for _, req := range expected.Scope {
|
||||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
if _, ok := scopeSet[req]; !ok {
|
||||||
|
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Status returns diagnostic information about the verifier's current JWKS cache.
|
||||||
|
func (v *Verifier) Status(ttl time.Duration) VerifierStatus {
|
||||||
|
status := VerifierStatus{}
|
||||||
|
|
||||||
|
if ttl > 0 {
|
||||||
|
status.CacheTTLSeconds = int(ttl / time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
|
||||||
|
status.CacheURL = v.cache.URL
|
||||||
|
status.CacheETag = v.cache.ETag
|
||||||
|
status.JWKSURL = v.cache.URL
|
||||||
|
status.KeyCount = len(v.cache.Keys)
|
||||||
|
status.KeyIDs = make([]string, 0, len(v.cache.Keys))
|
||||||
|
|
||||||
|
for _, key := range v.cache.Keys {
|
||||||
|
status.KeyIDs = append(status.KeyIDs, key.Kid)
|
||||||
|
}
|
||||||
|
|
||||||
|
status.CachePath = v.cachePath
|
||||||
|
|
||||||
|
if v.cache.FetchedAt > 0 {
|
||||||
|
fetched := time.Unix(v.cache.FetchedAt, 0).UTC()
|
||||||
|
status.CacheFetchedAt = fetched
|
||||||
|
age := time.Since(fetched)
|
||||||
|
status.CacheAgeSeconds = int64(age.Seconds())
|
||||||
|
if ttl > 0 && age > ttl {
|
||||||
|
status.CacheStale = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
|
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
|
||||||
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
|
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
|
||||||
keys, err := v.keysForURL(ctx, url, force)
|
keys, err := v.keysForURL(ctx, url, force)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, k := range keys {
|
for _, k := range keys {
|
||||||
if k.Kid != kid {
|
if k.Kid != kid {
|
||||||
continue
|
continue
|
||||||
@@ -184,12 +333,14 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
|
|||||||
copy(pk, raw)
|
copy(pk, raw)
|
||||||
return pk, nil
|
return pk, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, errKeyNotFound
|
return nil, errKeyNotFound
|
||||||
}
|
}
|
||||||
|
|
||||||
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
|
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
|
||||||
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
|
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
|
||||||
ttl := 300 * time.Second
|
ttl := 300 * time.Second
|
||||||
|
|
||||||
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
|
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
|
||||||
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
|
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
|
||||||
}
|
}
|
||||||
@@ -250,13 +401,16 @@ func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, f
|
|||||||
if force || cache.URL != url || len(cache.Keys) == 0 {
|
if force || cache.URL != url || len(cache.Keys) == 0 {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
age := v.now().Unix() - cache.FetchedAt
|
age := v.now().Unix() - cache.FetchedAt
|
||||||
if age < 0 {
|
if age < 0 {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if time.Duration(age)*time.Second > ttl {
|
if time.Duration(age)*time.Second > ttl {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return append([]PublicJWK(nil), cache.Keys...), true
|
return append([]PublicJWK(nil), cache.Keys...), true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,17 +424,21 @@ type jwksFetchResult struct {
|
|||||||
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
|
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
|
||||||
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
|
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if etag != "" {
|
if etag != "" {
|
||||||
req.Header.Set("If-None-Match", etag)
|
req.Header.Set("If-None-Match", etag)
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := v.httpClient.Do(req)
|
resp, err := v.httpClient.Do(req)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
|
||||||
switch resp.StatusCode {
|
switch resp.StatusCode {
|
||||||
@@ -331,6 +489,7 @@ func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK
|
|||||||
Keys: append([]PublicJWK(nil), result.keys...),
|
Keys: append([]PublicJWK(nil), result.keys...),
|
||||||
FetchedAt: result.fetchedAt,
|
FetchedAt: result.fetchedAt,
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = v.saveCacheLocked()
|
_ = v.saveCacheLocked()
|
||||||
return append([]PublicJWK(nil), v.cache.Keys...), true
|
return append([]PublicJWK(nil), v.cache.Keys...), true
|
||||||
}
|
}
|
||||||
@@ -347,7 +506,7 @@ func (v *Verifier) loadCache() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var entry cacheEntry
|
var entry cacheEntry
|
||||||
if err := json.Unmarshal(b, &entry); err != nil {
|
if err = json.Unmarshal(b, &entry); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,13 +519,17 @@ func (v *Verifier) saveCacheLocked() error {
|
|||||||
if v.cachePath == "" {
|
if v.cachePath == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := fs.MkdirAll(filepath.Dir(v.cachePath)); err != nil {
|
if err := fs.MkdirAll(filepath.Dir(v.cachePath)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(v.cache)
|
data, err := json.Marshal(v.cache)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
|
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,11 +540,13 @@ func backoffDuration(attempt int) time.Duration {
|
|||||||
}
|
}
|
||||||
|
|
||||||
base := jwksFetchBaseDelay << (attempt - 1)
|
base := jwksFetchBaseDelay << (attempt - 1)
|
||||||
|
|
||||||
if base > jwksFetchMaxDelay {
|
if base > jwksFetchMaxDelay {
|
||||||
base = jwksFetchMaxDelay
|
base = jwksFetchMaxDelay
|
||||||
}
|
}
|
||||||
|
|
||||||
jitterRange := base / 2
|
jitterRange := base / 2
|
||||||
|
|
||||||
if jitterRange > 0 {
|
if jitterRange > 0 {
|
||||||
base += time.Duration(randInt63n(int64(jitterRange) + 1))
|
base += time.Duration(randInt63n(int64(jitterRange) + 1))
|
||||||
}
|
}
|
||||||
|
@@ -103,6 +103,56 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestVerifyTokenWithKeys(t *testing.T) {
|
||||||
|
portalCfg := newTestConfig(t)
|
||||||
|
clusterUUID := rnd.UUIDv7()
|
||||||
|
portalCfg.Options().ClusterUUID = clusterUUID
|
||||||
|
|
||||||
|
mgr, err := NewManager(portalCfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
mgr.now = func() time.Time { return time.Date(2025, 9, 24, 10, 30, 0, 0, time.UTC) }
|
||||||
|
_, err = mgr.EnsureActiveKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
issuer := NewIssuer(mgr)
|
||||||
|
issuer.now = func() time.Time { return time.Now().UTC() }
|
||||||
|
|
||||||
|
spec := ClaimsSpec{
|
||||||
|
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||||
|
Subject: "portal:client-test",
|
||||||
|
Audience: "node:1234",
|
||||||
|
Scope: []string{"cluster"},
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := issuer.Issue(spec)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
keys := mgr.JWKS().Keys
|
||||||
|
claims, err := VerifyTokenWithKeys(token, ExpectedClaims{
|
||||||
|
Issuer: spec.Issuer,
|
||||||
|
Audience: spec.Audience,
|
||||||
|
Scope: []string{"cluster"},
|
||||||
|
}, keys, 60*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, spec.Subject, claims.Subject)
|
||||||
|
|
||||||
|
// Ensure scope filtering is honored when expected scope is empty.
|
||||||
|
claims, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||||
|
Issuer: spec.Issuer,
|
||||||
|
Audience: spec.Audience,
|
||||||
|
}, keys, 60*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, spec.Subject, claims.Subject)
|
||||||
|
|
||||||
|
// Missing scope should fail when explicitly required.
|
||||||
|
_, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||||
|
Issuer: spec.Issuer,
|
||||||
|
Audience: spec.Audience,
|
||||||
|
Scope: []string{"vision"},
|
||||||
|
}, keys, 60*time.Second)
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestIssuerClampTTL(t *testing.T) {
|
func TestIssuerClampTTL(t *testing.T) {
|
||||||
portalCfg := newTestConfig(t)
|
portalCfg := newTestConfig(t)
|
||||||
mgr, err := NewManager(portalCfg)
|
mgr, err := NewManager(portalCfg)
|
||||||
|
@@ -15,6 +15,7 @@ var AuthCommands = &cli.Command{
|
|||||||
AuthShowCommand,
|
AuthShowCommand,
|
||||||
AuthRemoveCommand,
|
AuthRemoveCommand,
|
||||||
AuthResetCommand,
|
AuthResetCommand,
|
||||||
|
AuthJWTCommands,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
16
internal/commands/auth_jwt.go
Normal file
16
internal/commands/auth_jwt.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import "github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
// AuthJWTCommands groups JWT-related auth helpers under photoprism auth jwt.
|
||||||
|
var AuthJWTCommands = &cli.Command{
|
||||||
|
Name: "jwt",
|
||||||
|
Usage: "JWT issuance and diagnostics",
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
AuthJWTIssueCommand,
|
||||||
|
AuthJWTInspectCommand,
|
||||||
|
AuthJWTKeysCommand,
|
||||||
|
AuthJWTStatusCommand,
|
||||||
|
},
|
||||||
|
}
|
154
internal/commands/auth_jwt_inspect.go
Normal file
154
internal/commands/auth_jwt_inspect.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/pkg/clean"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthJWTInspectCommand inspects and verifies portal-issued JWTs.
|
||||||
|
var AuthJWTInspectCommand = &cli.Command{
|
||||||
|
Name: "inspect",
|
||||||
|
Usage: "Decodes and verifies a portal JWT",
|
||||||
|
ArgsUsage: "<token>",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
&cli.StringFlag{Name: "file", Aliases: []string{"f"}, Usage: "read token from file"},
|
||||||
|
&cli.StringFlag{Name: "expect-audience", Usage: "expected audience (e.g., node:<uuid>)"},
|
||||||
|
&cli.StringSliceFlag{Name: "require-scope", Usage: "require specific scope(s)"},
|
||||||
|
&cli.BoolFlag{Name: "skip-verify", Usage: "decode without signature verification"},
|
||||||
|
JsonFlag(),
|
||||||
|
},
|
||||||
|
Action: authJWTInspectAction,
|
||||||
|
}
|
||||||
|
|
||||||
|
// authJWTInspectAction decodes and optionally verifies a portal-issued JWT.
|
||||||
|
func authJWTInspectAction(ctx *cli.Context) error {
|
||||||
|
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||||
|
if err := requirePortal(conf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := readTokenInput(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, claims, err := decodeJWTClaims(token)
|
||||||
|
if err != nil {
|
||||||
|
return cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
var verified bool
|
||||||
|
tokenScopes := clean.Scopes(claims.Scope)
|
||||||
|
|
||||||
|
if !ctx.Bool("skip-verify") {
|
||||||
|
expected := clusterjwt.ExpectedClaims{}
|
||||||
|
if clusterUUID := strings.TrimSpace(conf.ClusterUUID()); clusterUUID != "" {
|
||||||
|
expected.Issuer = fmt.Sprintf("portal:%s", clusterUUID)
|
||||||
|
} else if portal := strings.TrimSpace(conf.PortalUrl()); portal != "" {
|
||||||
|
expected.Issuer = strings.TrimRight(portal, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectAud := strings.TrimSpace(ctx.String("expect-audience")); expectAud != "" {
|
||||||
|
expected.Audience = expectAud
|
||||||
|
} else if len(claims.Audience) > 0 {
|
||||||
|
expected.Audience = claims.Audience[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
if required := ctx.StringSlice("require-scope"); len(required) > 0 {
|
||||||
|
scopes, scopeErr := normalizeScopes(required)
|
||||||
|
if scopeErr != nil {
|
||||||
|
return scopeErr
|
||||||
|
}
|
||||||
|
expected.Scope = scopes
|
||||||
|
} else {
|
||||||
|
expected.Scope = tokenScopes
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := verifyPortalToken(conf, token, expected); err != nil {
|
||||||
|
return cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
verified = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Bool("json") {
|
||||||
|
payload := map[string]any{
|
||||||
|
"token": token,
|
||||||
|
"verified": verified,
|
||||||
|
"header": header,
|
||||||
|
"claims": claims,
|
||||||
|
}
|
||||||
|
return printJSON(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("JWT header:")
|
||||||
|
for k, v := range header {
|
||||||
|
fmt.Printf(" %s: %v\n", k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\nJWT claims:")
|
||||||
|
fmt.Printf(" issuer: %s\n", claims.Issuer)
|
||||||
|
fmt.Printf(" subject: %s\n", claims.Subject)
|
||||||
|
fmt.Printf(" audience: %s\n", strings.Join(claims.Audience, " "))
|
||||||
|
fmt.Printf(" scope: %s\n", strings.Join(tokenScopes, " "))
|
||||||
|
if claims.IssuedAt != nil {
|
||||||
|
fmt.Printf(" issuedAt: %s\n", claims.IssuedAt.Time.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if claims.ExpiresAt != nil {
|
||||||
|
fmt.Printf(" expiresAt: %s\n", claims.ExpiresAt.Time.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if claims.NotBefore != nil {
|
||||||
|
fmt.Printf(" notBefore: %s\n", claims.NotBefore.Time.UTC().Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
if claims.ID != "" {
|
||||||
|
fmt.Printf(" jti: %s\n", claims.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if verified {
|
||||||
|
fmt.Println("\nSignature: verified")
|
||||||
|
} else {
|
||||||
|
fmt.Println("\nSignature: not verified (skipped)")
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nToken:\n%s\n\n", token)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// readTokenInput loads the token from CLI args, file, or STDIN.
|
||||||
|
func readTokenInput(ctx *cli.Context) (string, error) {
|
||||||
|
if file := strings.TrimSpace(ctx.String("file")); file != "" {
|
||||||
|
data, err := os.ReadFile(file)
|
||||||
|
if err != nil {
|
||||||
|
return "", cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
return strings.TrimSpace(string(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Args().Len() == 0 {
|
||||||
|
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := strings.TrimSpace(ctx.Args().First())
|
||||||
|
if token == "-" {
|
||||||
|
data, err := io.ReadAll(os.Stdin)
|
||||||
|
if err != nil {
|
||||||
|
return "", cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
token = strings.TrimSpace(string(data))
|
||||||
|
}
|
||||||
|
if token == "" {
|
||||||
|
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
117
internal/commands/auth_jwt_issue.go
Normal file
117
internal/commands/auth_jwt_issue.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthJWTIssueCommand issues portal-signed JWTs for cluster nodes.
|
||||||
|
var AuthJWTIssueCommand = &cli.Command{
|
||||||
|
Name: "issue",
|
||||||
|
Usage: "Issues a portal-signed JWT for a node",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
&cli.StringFlag{Name: "node", Aliases: []string{"n"}, Usage: "target node uuid, client id, or DNS label", Required: true},
|
||||||
|
&cli.StringSliceFlag{Name: "scope", Aliases: []string{"s"}, Usage: "token scope", Value: cli.NewStringSlice("cluster")},
|
||||||
|
&cli.DurationFlag{Name: "ttl", Usage: "token lifetime", Value: clusterjwt.TokenTTL},
|
||||||
|
&cli.StringFlag{Name: "subject", Usage: "token subject (default portal:<clusterUUID>)"},
|
||||||
|
JsonFlag(),
|
||||||
|
},
|
||||||
|
Action: authJWTIssueAction,
|
||||||
|
}
|
||||||
|
|
||||||
|
// authJWTIssueAction handles CLI issuance of portal-signed JWTs for nodes.
|
||||||
|
func authJWTIssueAction(ctx *cli.Context) error {
|
||||||
|
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||||
|
if err := requirePortal(conf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
node, err := resolveNode(conf, ctx.String("node"))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
scopes, err := normalizeScopes(ctx.StringSlice("scope"), "cluster")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := ctx.Duration("ttl")
|
||||||
|
if ttl <= 0 {
|
||||||
|
ttl = clusterjwt.TokenTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||||
|
if clusterUUID == "" {
|
||||||
|
return cli.Exit(fmt.Errorf("cluster uuid not configured"), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
subject := strings.TrimSpace(ctx.String("subject"))
|
||||||
|
if subject == "" {
|
||||||
|
subject = fmt.Sprintf("portal:%s", clusterUUID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token string
|
||||||
|
if subject == fmt.Sprintf("portal:%s", clusterUUID) {
|
||||||
|
token, err = get.IssuePortalJWTForNode(node.UUID, scopes, ttl)
|
||||||
|
} else {
|
||||||
|
spec := clusterjwt.ClaimsSpec{
|
||||||
|
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||||
|
Subject: subject,
|
||||||
|
Audience: fmt.Sprintf("node:%s", node.UUID),
|
||||||
|
Scope: scopes,
|
||||||
|
TTL: ttl,
|
||||||
|
}
|
||||||
|
token, err = get.IssuePortalJWT(spec)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
header, claims, err := decodeJWTClaims(token)
|
||||||
|
if err != nil {
|
||||||
|
return cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Bool("json") {
|
||||||
|
payload := map[string]any{
|
||||||
|
"token": token,
|
||||||
|
"header": header,
|
||||||
|
"claims": claims,
|
||||||
|
"node": map[string]string{
|
||||||
|
"uuid": node.UUID,
|
||||||
|
"clientId": node.ClientID,
|
||||||
|
"name": node.Name,
|
||||||
|
"role": string(node.Role),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return printJSON(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
expires := "unknown"
|
||||||
|
if claims.ExpiresAt != nil {
|
||||||
|
expires = claims.ExpiresAt.Time.UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
audience := strings.Join(claims.Audience, " ")
|
||||||
|
if audience == "" {
|
||||||
|
audience = "(none)"
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("\nIssued JWT for node %s (%s)\n", node.Name, node.UUID)
|
||||||
|
fmt.Printf("Scopes: %s\n", strings.Join(scopes, " "))
|
||||||
|
fmt.Printf("Expires: %s\n", expires)
|
||||||
|
fmt.Printf("Audience: %s\n", audience)
|
||||||
|
fmt.Printf("Subject: %s\n", claims.Subject)
|
||||||
|
fmt.Printf("Key ID: %v\n", header["kid"])
|
||||||
|
fmt.Printf("\n%s\n", token)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
107
internal/commands/auth_jwt_keys.go
Normal file
107
internal/commands/auth_jwt_keys.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthJWTKeysCommand groups JWT key management helpers.
|
||||||
|
var AuthJWTKeysCommand = &cli.Command{
|
||||||
|
Name: "keys",
|
||||||
|
Usage: "JWT signing key helpers",
|
||||||
|
Subcommands: []*cli.Command{
|
||||||
|
AuthJWTKeysListCommand,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthJWTKeysListCommand lists JWT signing keys.
|
||||||
|
var AuthJWTKeysListCommand = &cli.Command{
|
||||||
|
Name: "ls",
|
||||||
|
Usage: "Lists JWT signing keys",
|
||||||
|
Aliases: []string{"list"},
|
||||||
|
ArgsUsage: "",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
JsonFlag(),
|
||||||
|
},
|
||||||
|
Action: authJWTKeysListAction,
|
||||||
|
}
|
||||||
|
|
||||||
|
// authJWTKeysListAction lists portal signing keys with metadata.
|
||||||
|
func authJWTKeysListAction(ctx *cli.Context) error {
|
||||||
|
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||||
|
if err := requirePortal(conf); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := get.JWTManager()
|
||||||
|
if manager == nil {
|
||||||
|
return cli.Exit(errors.New("jwt manager not available"), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := manager.AllKeys()
|
||||||
|
active, _ := manager.ActiveKey()
|
||||||
|
activeKid := ""
|
||||||
|
if active != nil {
|
||||||
|
activeKid = active.Kid
|
||||||
|
}
|
||||||
|
|
||||||
|
type keyInfo struct {
|
||||||
|
Kid string `json:"kid"`
|
||||||
|
CreatedAt string `json:"createdAt"`
|
||||||
|
NotAfter string `json:"notAfter,omitempty"`
|
||||||
|
Active bool `json:"active"`
|
||||||
|
}
|
||||||
|
|
||||||
|
rows := make([]keyInfo, 0, len(keys))
|
||||||
|
for _, k := range keys {
|
||||||
|
info := keyInfo{Kid: k.Kid, Active: k.Kid == activeKid}
|
||||||
|
if k.CreatedAt > 0 {
|
||||||
|
info.CreatedAt = time.Unix(k.CreatedAt, 0).UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
if k.NotAfter > 0 {
|
||||||
|
info.NotAfter = time.Unix(k.NotAfter, 0).UTC().Format(time.RFC3339)
|
||||||
|
}
|
||||||
|
rows = append(rows, info)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.Bool("json") {
|
||||||
|
payload := map[string]any{
|
||||||
|
"keys": rows,
|
||||||
|
}
|
||||||
|
return printJSON(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rows) == 0 {
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("No signing keys found.")
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Println("JWT signing keys:")
|
||||||
|
for _, row := range rows {
|
||||||
|
status := ""
|
||||||
|
if row.Active {
|
||||||
|
status = " (active)"
|
||||||
|
}
|
||||||
|
parts := []string{fmt.Sprintf("KID: %s%s", row.Kid, status)}
|
||||||
|
if row.CreatedAt != "" {
|
||||||
|
parts = append(parts, fmt.Sprintf("created %s", row.CreatedAt))
|
||||||
|
}
|
||||||
|
if row.NotAfter != "" {
|
||||||
|
parts = append(parts, fmt.Sprintf("expires %s", row.NotAfter))
|
||||||
|
}
|
||||||
|
fmt.Printf("- %s\n", strings.Join(parts, ", "))
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
67
internal/commands/auth_jwt_status.go
Normal file
67
internal/commands/auth_jwt_status.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthJWTStatusCommand reports verifier cache diagnostics.
|
||||||
|
var AuthJWTStatusCommand = &cli.Command{
|
||||||
|
Name: "status",
|
||||||
|
Usage: "Shows JWT verifier cache status",
|
||||||
|
Flags: []cli.Flag{
|
||||||
|
JsonFlag(),
|
||||||
|
},
|
||||||
|
Action: authJWTStatusAction,
|
||||||
|
}
|
||||||
|
|
||||||
|
// authJWTStatusAction prints JWKS cache diagnostics for the current node.
|
||||||
|
func authJWTStatusAction(ctx *cli.Context) error {
|
||||||
|
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||||
|
verifier := get.JWTVerifier()
|
||||||
|
if verifier == nil {
|
||||||
|
return cli.Exit(errors.New("jwt verifier not available"), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := time.Duration(conf.JWKSCacheTTL()) * time.Second
|
||||||
|
status := verifier.Status(ttl)
|
||||||
|
status.JWKSURL = strings.TrimSpace(conf.JWKSUrl())
|
||||||
|
|
||||||
|
if ctx.Bool("json") {
|
||||||
|
return printJSON(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println()
|
||||||
|
fmt.Printf("JWKS URL: %s\n", status.JWKSURL)
|
||||||
|
fmt.Printf("Cache Path: %s\n", status.CachePath)
|
||||||
|
fmt.Printf("Cache URL: %s\n", status.CacheURL)
|
||||||
|
fmt.Printf("Cache ETag: %s\n", status.CacheETag)
|
||||||
|
fmt.Printf("Cached Keys: %d\n", status.KeyCount)
|
||||||
|
if len(status.KeyIDs) > 0 {
|
||||||
|
fmt.Printf("Key IDs: %s\n", strings.Join(status.KeyIDs, ", "))
|
||||||
|
}
|
||||||
|
if !status.CacheFetchedAt.IsZero() {
|
||||||
|
fmt.Printf("Last Fetch: %s\n", status.CacheFetchedAt.Format(time.RFC3339))
|
||||||
|
} else {
|
||||||
|
fmt.Println("Last Fetch: never")
|
||||||
|
}
|
||||||
|
fmt.Printf("Cache Age: %ds\n", status.CacheAgeSeconds)
|
||||||
|
if status.CacheTTLSeconds > 0 {
|
||||||
|
fmt.Printf("Cache TTL: %ds\n", status.CacheTTLSeconds)
|
||||||
|
}
|
||||||
|
if status.CacheStale {
|
||||||
|
fmt.Println("Cache Status: STALE")
|
||||||
|
} else {
|
||||||
|
fmt.Println("Cache Status: fresh")
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
94
internal/commands/auth_jwt_test.go
Normal file
94
internal/commands/auth_jwt_test.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||||
|
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||||
|
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||||
|
"github.com/photoprism/photoprism/pkg/rnd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAuthJWTCommands(t *testing.T) {
|
||||||
|
conf := get.Config()
|
||||||
|
|
||||||
|
origEdition := conf.Options().Edition
|
||||||
|
origRole := conf.Options().NodeRole
|
||||||
|
origUUID := conf.Options().ClusterUUID
|
||||||
|
origPortal := conf.Options().PortalUrl
|
||||||
|
origJWKS := conf.JWKSUrl()
|
||||||
|
|
||||||
|
conf.Options().Edition = config.Portal
|
||||||
|
conf.Options().NodeRole = string(cluster.RolePortal)
|
||||||
|
conf.Options().ClusterUUID = "11111111-1111-4111-8111-111111111111"
|
||||||
|
conf.Options().PortalUrl = "https://portal.test"
|
||||||
|
conf.SetJWKSUrl("https://portal.test/.well-known/jwks.json")
|
||||||
|
|
||||||
|
get.SetConfig(conf)
|
||||||
|
conf.RegisterDb()
|
||||||
|
|
||||||
|
require.True(t, conf.IsPortal())
|
||||||
|
|
||||||
|
manager := get.JWTManager()
|
||||||
|
require.NotNil(t, manager)
|
||||||
|
_, err := manager.EnsureActiveKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
nodeUUID := rnd.UUID()
|
||||||
|
node := ®.Node{}
|
||||||
|
node.UUID = nodeUUID
|
||||||
|
node.Name = "pp-node-01"
|
||||||
|
node.Role = string(cluster.RoleInstance)
|
||||||
|
require.NoError(t, registry.Put(node))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
conf.Options().Edition = origEdition
|
||||||
|
conf.Options().NodeRole = origRole
|
||||||
|
conf.Options().ClusterUUID = origUUID
|
||||||
|
conf.Options().PortalUrl = origPortal
|
||||||
|
conf.SetJWKSUrl(origJWKS)
|
||||||
|
get.SetConfig(conf)
|
||||||
|
conf.RegisterDb()
|
||||||
|
})
|
||||||
|
|
||||||
|
output, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, output, "Issued JWT")
|
||||||
|
|
||||||
|
jsonOut, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--json"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonOut), &payload))
|
||||||
|
require.NotEmpty(t, payload.Token)
|
||||||
|
|
||||||
|
inspectOut, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", payload.Token})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, inspectOut, "\"verified\": true")
|
||||||
|
|
||||||
|
inspectStrict, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", "--expect-audience", "node:" + nodeUUID, "--require-scope", "cluster", payload.Token})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, inspectStrict, "\"verified\": true")
|
||||||
|
|
||||||
|
keysOut, err := RunWithTestContext(AuthJWTKeysListCommand, []string{"ls", "--json"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, keysOut, "\"keys\"")
|
||||||
|
|
||||||
|
statusOut, err := RunWithTestContext(AuthJWTStatusCommand, []string{"status"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, statusOut, "JWKS URL")
|
||||||
|
assert.Contains(t, statusOut, "Cached Keys")
|
||||||
|
|
||||||
|
// invalid scope should fail
|
||||||
|
_, err = RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--scope", "unknown"})
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
@@ -19,8 +19,9 @@ type healthResponse struct {
|
|||||||
// ClusterHealthCommand prints a minimal health response (Portal-only).
|
// ClusterHealthCommand prints a minimal health response (Portal-only).
|
||||||
var ClusterHealthCommand = &cli.Command{
|
var ClusterHealthCommand = &cli.Command{
|
||||||
Name: "health",
|
Name: "health",
|
||||||
Usage: "Shows cluster health (Portal-only)",
|
Usage: "Shows cluster health status",
|
||||||
Flags: report.CliFlags,
|
Flags: report.CliFlags,
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterHealthAction,
|
Action: clusterHealthAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -14,8 +14,9 @@ import (
|
|||||||
|
|
||||||
// ClusterNodesCommands groups node subcommands.
|
// ClusterNodesCommands groups node subcommands.
|
||||||
var ClusterNodesCommands = &cli.Command{
|
var ClusterNodesCommands = &cli.Command{
|
||||||
Name: "nodes",
|
Name: "nodes",
|
||||||
Usage: "Node registry subcommands",
|
Usage: "Node registry subcommands",
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Subcommands: []*cli.Command{
|
Subcommands: []*cli.Command{
|
||||||
ClusterNodesListCommand,
|
ClusterNodesListCommand,
|
||||||
ClusterNodesShowCommand,
|
ClusterNodesShowCommand,
|
||||||
@@ -28,9 +29,10 @@ var ClusterNodesCommands = &cli.Command{
|
|||||||
// ClusterNodesListCommand lists registered nodes.
|
// ClusterNodesListCommand lists registered nodes.
|
||||||
var ClusterNodesListCommand = &cli.Command{
|
var ClusterNodesListCommand = &cli.Command{
|
||||||
Name: "ls",
|
Name: "ls",
|
||||||
Usage: "Lists registered cluster nodes (Portal-only)",
|
Usage: "Lists registered cluster nodes",
|
||||||
Flags: append(report.CliFlags, CountFlag, OffsetFlag),
|
Flags: append(report.CliFlags, CountFlag, OffsetFlag),
|
||||||
ArgsUsage: "",
|
ArgsUsage: "",
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterNodesListAction,
|
Action: clusterNodesListAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -22,9 +22,10 @@ var (
|
|||||||
// ClusterNodesModCommand updates node fields.
|
// ClusterNodesModCommand updates node fields.
|
||||||
var ClusterNodesModCommand = &cli.Command{
|
var ClusterNodesModCommand = &cli.Command{
|
||||||
Name: "mod",
|
Name: "mod",
|
||||||
Usage: "Updates node properties (Portal-only)",
|
Usage: "Updates node properties",
|
||||||
ArgsUsage: "<id|name>",
|
ArgsUsage: "<id|name>",
|
||||||
Flags: []cli.Flag{nodesModRoleFlag, nodesModInternal, nodesModLabel, &cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"}},
|
Flags: []cli.Flag{nodesModRoleFlag, nodesModInternal, nodesModLabel, &cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"}},
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterNodesModAction,
|
Action: clusterNodesModAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -14,12 +14,13 @@ import (
|
|||||||
// ClusterNodesRemoveCommand deletes a node from the registry.
|
// ClusterNodesRemoveCommand deletes a node from the registry.
|
||||||
var ClusterNodesRemoveCommand = &cli.Command{
|
var ClusterNodesRemoveCommand = &cli.Command{
|
||||||
Name: "rm",
|
Name: "rm",
|
||||||
Usage: "Deletes a node from the registry (Portal-only)",
|
Usage: "Deletes a node from the registry",
|
||||||
ArgsUsage: "<id|name>",
|
ArgsUsage: "<id|name>",
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
&cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"},
|
&cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"},
|
||||||
&cli.BoolFlag{Name: "all-ids", Usage: "delete all records that share the same UUID (admin cleanup)"},
|
&cli.BoolFlag{Name: "all-ids", Usage: "delete all records that share the same UUID (admin cleanup)"},
|
||||||
},
|
},
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterNodesRemoveAction,
|
Action: clusterNodesRemoveAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2,6 +2,7 @@ package commands
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
@@ -106,11 +107,13 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
|||||||
}
|
}
|
||||||
b, _ := json.Marshal(payload)
|
b, _ := json.Marshal(payload)
|
||||||
|
|
||||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||||
|
|
||||||
var resp cluster.RegisterResponse
|
var resp cluster.RegisterResponse
|
||||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||||
// Map common HTTP errors similarly to register command
|
// Map common HTTP errors similarly to register command
|
||||||
if he, ok := err.(*httpError); ok {
|
var he *httpError
|
||||||
|
if errors.As(err, &he) {
|
||||||
switch he.Status {
|
switch he.Status {
|
||||||
case 401, 403:
|
case 401, 403:
|
||||||
return cli.Exit(fmt.Errorf("%s", he.Error()), 4)
|
return cli.Exit(fmt.Errorf("%s", he.Error()), 4)
|
||||||
@@ -151,6 +154,7 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
|||||||
fmt.Printf("DSN: %s\n", resp.Database.DSN)
|
fmt.Printf("DSN: %s\n", resp.Database.DSN)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@@ -15,9 +15,10 @@ import (
|
|||||||
// ClusterNodesShowCommand shows node details.
|
// ClusterNodesShowCommand shows node details.
|
||||||
var ClusterNodesShowCommand = &cli.Command{
|
var ClusterNodesShowCommand = &cli.Command{
|
||||||
Name: "show",
|
Name: "show",
|
||||||
Usage: "Shows node details (Portal-only)",
|
Usage: "Shows node details",
|
||||||
ArgsUsage: "<id|name>",
|
ArgsUsage: "<id|name>",
|
||||||
Flags: report.CliFlags,
|
Flags: report.CliFlags,
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterNodesShowAction,
|
Action: clusterNodesShowAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -24,7 +24,7 @@ import (
|
|||||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||||
)
|
)
|
||||||
|
|
||||||
// flags for register
|
// Supported cluster node register flags.
|
||||||
var (
|
var (
|
||||||
regNameFlag = &cli.StringFlag{Name: "name", Usage: "node `NAME` (lowercase letters, digits, hyphens)"}
|
regNameFlag = &cli.StringFlag{Name: "name", Usage: "node `NAME` (lowercase letters, digits, hyphens)"}
|
||||||
regRoleFlag = &cli.StringFlag{Name: "role", Usage: "node `ROLE` (instance, service)", Value: "instance"}
|
regRoleFlag = &cli.StringFlag{Name: "role", Usage: "node `ROLE` (instance, service)", Value: "instance"}
|
||||||
@@ -42,7 +42,7 @@ var (
|
|||||||
// ClusterRegisterCommand registers a node with the Portal via HTTP.
|
// ClusterRegisterCommand registers a node with the Portal via HTTP.
|
||||||
var ClusterRegisterCommand = &cli.Command{
|
var ClusterRegisterCommand = &cli.Command{
|
||||||
Name: "register",
|
Name: "register",
|
||||||
Usage: "Registers/rotates a node via Portal (HTTP)",
|
Usage: "Registers a node or updates its credentials within a cluster",
|
||||||
Flags: append(append([]cli.Flag{regNameFlag, regRoleFlag, regIntUrlFlag, regLabelFlag, regRotateDatabase, regRotateSec, regPortalURL, regPortalTok, regWriteConf, regForceFlag, regDryRun}, report.CliFlags...)),
|
Flags: append(append([]cli.Flag{regNameFlag, regRoleFlag, regIntUrlFlag, regLabelFlag, regRotateDatabase, regRotateSec, regPortalURL, regPortalTok, regWriteConf, regForceFlag, regDryRun}, report.CliFlags...)),
|
||||||
Action: clusterRegisterAction,
|
Action: clusterRegisterAction,
|
||||||
}
|
}
|
||||||
@@ -52,15 +52,18 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
|||||||
// Resolve inputs
|
// Resolve inputs
|
||||||
name := clean.DNSLabel(ctx.String("name"))
|
name := clean.DNSLabel(ctx.String("name"))
|
||||||
derivedName := false
|
derivedName := false
|
||||||
|
|
||||||
if name == "" { // default from config if set
|
if name == "" { // default from config if set
|
||||||
name = clean.DNSLabel(conf.NodeName())
|
name = clean.DNSLabel(conf.NodeName())
|
||||||
if name != "" {
|
if name != "" {
|
||||||
derivedName = true
|
derivedName = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if name == "" {
|
if name == "" {
|
||||||
return cli.Exit(fmt.Errorf("node name is required (use --name or set node-name)"), 2)
|
return cli.Exit(fmt.Errorf("node name is required (use --name or set node-name)"), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeRole := clean.TypeLowerDash(ctx.String("role"))
|
nodeRole := clean.TypeLowerDash(ctx.String("role"))
|
||||||
switch nodeRole {
|
switch nodeRole {
|
||||||
case "instance", "service":
|
case "instance", "service":
|
||||||
@@ -76,7 +79,6 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
|||||||
derivedPortal = true
|
derivedPortal = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
|
||||||
|
|
||||||
// Derive advertise/site URLs when omitted.
|
// Derive advertise/site URLs when omitted.
|
||||||
advertise := ctx.String("advertise-url")
|
advertise := ctx.String("advertise-url")
|
||||||
@@ -93,17 +95,20 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
|||||||
RotateDatabase: ctx.Bool("rotate"),
|
RotateDatabase: ctx.Bool("rotate"),
|
||||||
RotateSecret: ctx.Bool("rotate-secret"),
|
RotateSecret: ctx.Bool("rotate-secret"),
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we already have client credentials (e.g., re-register), include them so the
|
// If we already have client credentials (e.g., re-register), include them so the
|
||||||
// portal can verify and authorize UUID/name moves or metadata updates.
|
// portal can verify and authorize UUID/name moves or metadata updates.
|
||||||
if id, secret := strings.TrimSpace(conf.NodeClientID()), strings.TrimSpace(conf.NodeClientSecret()); id != "" && secret != "" {
|
if id, secret := strings.TrimSpace(conf.NodeClientID()), strings.TrimSpace(conf.NodeClientSecret()); id != "" && secret != "" {
|
||||||
payload.ClientID = id
|
payload.ClientID = id
|
||||||
payload.ClientSecret = secret
|
payload.ClientSecret = secret
|
||||||
}
|
}
|
||||||
|
|
||||||
if site != "" && site != advertise {
|
if site != "" && site != advertise {
|
||||||
payload.SiteUrl = site
|
payload.SiteUrl = site
|
||||||
}
|
}
|
||||||
b, _ := json.Marshal(payload)
|
b, _ := json.Marshal(payload)
|
||||||
|
|
||||||
|
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
||||||
if ctx.Bool("dry-run") {
|
if ctx.Bool("dry-run") {
|
||||||
if ctx.Bool("json") {
|
if ctx.Bool("json") {
|
||||||
out := map[string]any{"portalUrl": portalURL, "payload": payload}
|
out := map[string]any{"portalUrl": portalURL, "payload": payload}
|
||||||
@@ -140,18 +145,22 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
|||||||
if portalURL == "" {
|
if portalURL == "" {
|
||||||
return cli.Exit(fmt.Errorf("portal URL is required (use --portal-url or set portal-url)"), 2)
|
return cli.Exit(fmt.Errorf("portal URL is required (use --portal-url or set portal-url)"), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
token := ctx.String("join-token")
|
token := ctx.String("join-token")
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
token = conf.JoinToken()
|
token = conf.JoinToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
if token == "" {
|
if token == "" {
|
||||||
return cli.Exit(fmt.Errorf("portal token is required (use --join-token or set join-token)"), 2)
|
return cli.Exit(fmt.Errorf("portal token is required (use --join-token or set join-token)"), 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
// POST with bounded backoff on 429
|
// POST with bounded backoff on 429
|
||||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||||
|
|
||||||
var resp cluster.RegisterResponse
|
var resp cluster.RegisterResponse
|
||||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||||
var httpErr *httpError
|
var httpErr *httpError
|
||||||
if errors.As(err, &httpErr) && httpErr.Status == http.StatusTooManyRequests {
|
if errors.As(err, &httpErr) && httpErr.Status == http.StatusTooManyRequests {
|
||||||
return cli.Exit(fmt.Errorf("portal rate-limited registration attempts"), 6)
|
return cli.Exit(fmt.Errorf("portal rate-limited registration attempts"), 6)
|
||||||
@@ -179,13 +188,17 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
|||||||
} else {
|
} else {
|
||||||
// Human-readable: node row and credentials if present (UUID first as primary identifier)
|
// Human-readable: node row and credentials if present (UUID first as primary identifier)
|
||||||
cols := []string{"UUID", "ClientID", "Name", "Role", "DB Driver", "DB Name", "DB User", "Host", "Port"}
|
cols := []string{"UUID", "ClientID", "Name", "Role", "DB Driver", "DB Name", "DB User", "Host", "Port"}
|
||||||
|
|
||||||
var dbName, dbUser string
|
var dbName, dbUser string
|
||||||
|
|
||||||
if resp.Database.Name != "" {
|
if resp.Database.Name != "" {
|
||||||
dbName = resp.Database.Name
|
dbName = resp.Database.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Database.User != "" {
|
if resp.Database.User != "" {
|
||||||
dbUser = resp.Database.User
|
dbUser = resp.Database.User
|
||||||
}
|
}
|
||||||
|
|
||||||
rows := [][]string{{resp.Node.UUID, resp.Node.ClientID, resp.Node.Name, resp.Node.Role, resp.Database.Driver, dbName, dbUser, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port)}}
|
rows := [][]string{{resp.Node.UUID, resp.Node.ClientID, resp.Node.Name, resp.Node.Role, resp.Database.Driver, dbName, dbUser, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port)}}
|
||||||
out, _ := report.RenderFormat(rows, cols, report.CliFormat(ctx))
|
out, _ := report.RenderFormat(rows, cols, report.CliFormat(ctx))
|
||||||
fmt.Printf("\n%s\n", out)
|
fmt.Printf("\n%s\n", out)
|
||||||
|
@@ -13,11 +13,12 @@ import (
|
|||||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ClusterSummaryCommand prints a minimal cluster summary (Portal-only).
|
// ClusterSummaryCommand prints a minimal cluster summary.
|
||||||
var ClusterSummaryCommand = &cli.Command{
|
var ClusterSummaryCommand = &cli.Command{
|
||||||
Name: "summary",
|
Name: "summary",
|
||||||
Usage: "Shows cluster summary (Portal-only)",
|
Usage: "Shows cluster summary",
|
||||||
Flags: report.CliFlags,
|
Flags: report.CliFlags,
|
||||||
|
Hidden: true, // Required for cluster-management only.
|
||||||
Action: clusterSummaryAction,
|
Action: clusterSummaryAction,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -97,7 +97,6 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
|||||||
|
|
||||||
// Ensure DB connection is open for each command run (some commands call Shutdown).
|
// Ensure DB connection is open for each command run (some commands call Shutdown).
|
||||||
if c := get.Config(); c != nil {
|
if c := get.Config(); c != nil {
|
||||||
_ = c.Init() // safe to call; re-opens DB if needed
|
|
||||||
c.RegisterDb() // (re)register provider
|
c.RegisterDb() // (re)register provider
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -110,5 +109,11 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
|||||||
err = cmd.Run(ctx, args...)
|
err = cmd.Run(ctx, args...)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Re-open the database after the command completed so follow-up checks
|
||||||
|
// (potentially issued by the test itself) have an active connection.
|
||||||
|
if c := get.Config(); c != nil {
|
||||||
|
c.RegisterDb()
|
||||||
|
}
|
||||||
|
|
||||||
return output, err
|
return output, err
|
||||||
}
|
}
|
||||||
|
@@ -81,9 +81,10 @@ func TestDownloadImpl_FileMethod_AutoSkipsRemux(t *testing.T) {
|
|||||||
if conf == nil {
|
if conf == nil {
|
||||||
t.Fatalf("missing test config")
|
t.Fatalf("missing test config")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure DB is initialized and registered (bypassing CLI InitConfig)
|
// Ensure DB is initialized and registered (bypassing CLI InitConfig)
|
||||||
_ = conf.Init()
|
|
||||||
conf.RegisterDb()
|
conf.RegisterDb()
|
||||||
|
|
||||||
// Override yt-dlp after config init (config may set dl.YtDlpBin)
|
// Override yt-dlp after config init (config may set dl.YtDlpBin)
|
||||||
dl.YtDlpBin = fake
|
dl.YtDlpBin = fake
|
||||||
t.Logf("using yt-dlp binary: %s", dl.YtDlpBin)
|
t.Logf("using yt-dlp binary: %s", dl.YtDlpBin)
|
||||||
@@ -125,7 +126,6 @@ func TestDownloadImpl_FileMethod_Skip_NoRemux(t *testing.T) {
|
|||||||
if conf == nil {
|
if conf == nil {
|
||||||
t.Fatalf("missing test config")
|
t.Fatalf("missing test config")
|
||||||
}
|
}
|
||||||
_ = conf.Init()
|
|
||||||
conf.RegisterDb()
|
conf.RegisterDb()
|
||||||
dl.YtDlpBin = fake
|
dl.YtDlpBin = fake
|
||||||
|
|
||||||
@@ -196,8 +196,9 @@ func TestDownloadImpl_FileMethod_Always_RemuxFails(t *testing.T) {
|
|||||||
if conf == nil {
|
if conf == nil {
|
||||||
t.Fatalf("missing test config")
|
t.Fatalf("missing test config")
|
||||||
}
|
}
|
||||||
_ = conf.Init()
|
|
||||||
conf.RegisterDb()
|
conf.RegisterDb()
|
||||||
|
|
||||||
dl.YtDlpBin = fake
|
dl.YtDlpBin = fake
|
||||||
|
|
||||||
err := runDownload(conf, DownloadOpts{
|
err := runDownload(conf, DownloadOpts{
|
||||||
|
8
internal/commands/flags.go
Normal file
8
internal/commands/flags.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import "github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
// JsonFlag returns the shared CLI flag definition for JSON output across commands.
|
||||||
|
func JsonFlag() *cli.BoolFlag {
|
||||||
|
return &cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"}
|
||||||
|
}
|
173
internal/commands/jwt_helpers.go
Normal file
173
internal/commands/jwt_helpers.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package commands
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
|
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||||
|
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||||
|
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||||
|
"github.com/photoprism/photoprism/pkg/clean"
|
||||||
|
)
|
||||||
|
|
||||||
|
var allowedJWTScope = func() map[string]struct{} {
|
||||||
|
out := make(map[string]struct{}, len(acl.ResourceNames))
|
||||||
|
for _, res := range acl.ResourceNames {
|
||||||
|
out[res.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}()
|
||||||
|
|
||||||
|
// requirePortal returns a CLI error when the active configuration is not a portal node.
|
||||||
|
func requirePortal(conf *config.Config) error {
|
||||||
|
if conf == nil || !conf.IsPortal() {
|
||||||
|
return cli.Exit(errors.New("command requires a Portal node"), 2)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveNode finds a node by UUID, client ID, or DNS label using the portal registry.
|
||||||
|
func resolveNode(conf *config.Config, identifier string) (*reg.Node, error) {
|
||||||
|
if err := requirePortal(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
key := strings.TrimSpace(identifier)
|
||||||
|
if key == "" {
|
||||||
|
return nil, cli.Exit(errors.New("node identifier required"), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||||
|
if err != nil {
|
||||||
|
return nil, cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if node, err := registry.FindByNodeUUID(key); err == nil && node != nil {
|
||||||
|
return node, nil
|
||||||
|
}
|
||||||
|
if node, err := registry.FindByClientID(key); err == nil && node != nil {
|
||||||
|
return node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
name := clean.DNSLabel(key)
|
||||||
|
if name == "" {
|
||||||
|
return nil, cli.Exit(errors.New("invalid node identifier"), 2)
|
||||||
|
}
|
||||||
|
|
||||||
|
node, err := registry.FindByName(name)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, reg.ErrNotFound) {
|
||||||
|
return nil, cli.Exit(fmt.Errorf("node %q not found", identifier), 3)
|
||||||
|
}
|
||||||
|
return nil, cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
return node, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeJWTClaims decodes the compact JWT and returns header and claims without verifying the signature.
|
||||||
|
func decodeJWTClaims(token string) (map[string]any, *clusterjwt.Claims, error) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, nil, errors.New("jwt: token must contain three segments")
|
||||||
|
}
|
||||||
|
|
||||||
|
decode := func(segment string) ([]byte, error) {
|
||||||
|
return base64.RawURLEncoding.DecodeString(segment)
|
||||||
|
}
|
||||||
|
|
||||||
|
headerBytes, err := decode(parts[0])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
payloadBytes, err := decode(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var header map[string]any
|
||||||
|
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := &clusterjwt.Claims{}
|
||||||
|
if err := json.Unmarshal(payloadBytes, claims); err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return header, claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyPortalToken verifies a JWT using the portal's in-memory key manager.
|
||||||
|
func verifyPortalToken(conf *config.Config, token string, expected clusterjwt.ExpectedClaims) (*clusterjwt.Claims, error) {
|
||||||
|
if err := requirePortal(conf); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := get.JWTManager()
|
||||||
|
if manager == nil {
|
||||||
|
return nil, cli.Exit(errors.New("jwt issuer not available"), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := manager.JWKS()
|
||||||
|
if jwks == nil || len(jwks.Keys) == 0 {
|
||||||
|
return nil, cli.Exit(errors.New("jwks key set is empty"), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
leeway := time.Duration(conf.JWTLeeway()) * time.Second
|
||||||
|
if leeway <= 0 {
|
||||||
|
leeway = 60 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := clusterjwt.VerifyTokenWithKeys(token, expected, jwks.Keys, leeway)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// normalizeScopes trims and de-duplicates scope values, falling back to defaults when necessary.
|
||||||
|
func normalizeScopes(values []string, defaults ...string) ([]string, error) {
|
||||||
|
src := values
|
||||||
|
if len(src) == 0 {
|
||||||
|
src = defaults
|
||||||
|
}
|
||||||
|
out := make([]string, 0, len(src))
|
||||||
|
seen := make(map[string]struct{}, len(src))
|
||||||
|
for _, raw := range src {
|
||||||
|
for _, parsed := range clean.Scopes(raw) {
|
||||||
|
scope := clean.Scope(parsed)
|
||||||
|
if scope == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[scope]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := allowedJWTScope[scope]; !ok {
|
||||||
|
return nil, cli.Exit(fmt.Errorf("unsupported scope %q", scope), 2)
|
||||||
|
}
|
||||||
|
seen[scope] = struct{}{}
|
||||||
|
out = append(out, scope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
return nil, cli.Exit(errors.New("at least one scope is required"), 2)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// printJSON pretty-prints the payload as JSON.
|
||||||
|
func printJSON(payload any) error {
|
||||||
|
data, err := json.MarshalIndent(payload, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return cli.Exit(err, 1)
|
||||||
|
}
|
||||||
|
fmt.Printf("%s\n", data)
|
||||||
|
return nil
|
||||||
|
}
|
@@ -16,7 +16,7 @@ var ShowCommandsCommand = &cli.Command{
|
|||||||
Name: "commands",
|
Name: "commands",
|
||||||
Usage: "Displays a structured catalog of CLI commands",
|
Usage: "Displays a structured catalog of CLI commands",
|
||||||
Flags: []cli.Flag{
|
Flags: []cli.Flag{
|
||||||
&cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"},
|
JsonFlag(),
|
||||||
&cli.BoolFlag{Name: "all", Usage: "include hidden commands and flags"},
|
&cli.BoolFlag{Name: "all", Usage: "include hidden commands and flags"},
|
||||||
&cli.BoolFlag{Name: "short", Usage: "omit flags in Markdown output"},
|
&cli.BoolFlag{Name: "short", Usage: "omit flags in Markdown output"},
|
||||||
&cli.IntFlag{Name: "base-heading", Value: 2, Usage: "base Markdown heading level"},
|
&cli.IntFlag{Name: "base-heading", Value: 2, Usage: "base Markdown heading level"},
|
||||||
|
@@ -6,6 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -43,9 +44,9 @@ func statusAction(ctx *cli.Context) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
url := fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
endpointUrl := buildStatusEndpoint(conf)
|
||||||
|
|
||||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
req, err := http.NewRequest(http.MethodGet, endpointUrl, nil)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -53,12 +54,12 @@ func statusAction(ctx *cli.Context) error {
|
|||||||
|
|
||||||
var status string
|
var status string
|
||||||
|
|
||||||
if resp, err := client.Do(req); err != nil {
|
if resp, reqErr := client.Do(req); reqErr != nil {
|
||||||
return fmt.Errorf("cannot connect to %s:%d", conf.HttpHost(), conf.HttpPort())
|
return fmt.Errorf("cannot connect to %s:%d", conf.HttpHost(), conf.HttpPort())
|
||||||
} else if resp.StatusCode != 200 {
|
} else if resp.StatusCode != 200 {
|
||||||
return fmt.Errorf("server running at %s:%d, bad status %d\n", conf.HttpHost(), conf.HttpPort(), resp.StatusCode)
|
return fmt.Errorf("server running at %s:%d, bad status %d\n", conf.HttpHost(), conf.HttpPort(), resp.StatusCode)
|
||||||
} else if body, err := io.ReadAll(resp.Body); err != nil {
|
} else if body, readErr := io.ReadAll(resp.Body); readErr != nil {
|
||||||
return err
|
return readErr
|
||||||
} else {
|
} else {
|
||||||
status = string(body)
|
status = string(body)
|
||||||
}
|
}
|
||||||
@@ -73,3 +74,21 @@ func statusAction(ctx *cli.Context) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildStatusEndpoint returns the status endpoint URL, preferring the public
|
||||||
|
// SiteUrl (which carries the correct scheme) and falling back to the local
|
||||||
|
// HTTP host/port. When a Unix socket is configured, an http+unix style URL is
|
||||||
|
// used so the custom transport can dial the socket.
|
||||||
|
func buildStatusEndpoint(conf *config.Config) string {
|
||||||
|
if socket := conf.HttpSocket(); socket != nil {
|
||||||
|
return fmt.Sprintf("%s://%s/api/v1/status", socket.Scheme, strings.TrimPrefix(socket.Path, "/"))
|
||||||
|
}
|
||||||
|
|
||||||
|
siteUrl := strings.TrimRight(conf.SiteUrl(), "/")
|
||||||
|
|
||||||
|
if siteUrl != "" {
|
||||||
|
return siteUrl + "/api/v1/status"
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
||||||
|
}
|
||||||
|
@@ -2,6 +2,8 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
|
"net"
|
||||||
|
urlpkg "net/url"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -222,7 +224,36 @@ func (c *Config) SetJWKSUrl(url string) {
|
|||||||
if c == nil || c.options == nil {
|
if c == nil || c.options == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.options.JWKSUrl = strings.TrimSpace(url)
|
|
||||||
|
trimmed := strings.TrimSpace(url)
|
||||||
|
if trimmed == "" {
|
||||||
|
c.options.JWKSUrl = ""
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed, err := urlpkg.Parse(trimmed)
|
||||||
|
if err != nil || parsed == nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||||
|
log.Warnf("config: ignoring JWKS URL %q (%v)", trimmed, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
scheme := strings.ToLower(parsed.Scheme)
|
||||||
|
host := parsed.Hostname()
|
||||||
|
|
||||||
|
switch scheme {
|
||||||
|
case "https":
|
||||||
|
// Always allowed.
|
||||||
|
case "http":
|
||||||
|
if !isLoopbackHost(host) {
|
||||||
|
log.Warnf("config: rejecting JWKS URL %q (http only allowed for localhost/loopback)", trimmed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
log.Warnf("config: rejecting JWKS URL %q (unsupported scheme)", trimmed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.options.JWKSUrl = trimmed
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWKSCacheTTL returns the JWKS cache lifetime in seconds (default 300, max 3600).
|
// JWKSCacheTTL returns the JWKS cache lifetime in seconds (default 300, max 3600).
|
||||||
@@ -261,6 +292,23 @@ func (c *Config) AdvertiseUrl() string {
|
|||||||
return c.SiteUrl()
|
return c.SiteUrl()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isLoopbackHost returns true when host represents localhost or a loopback IP.
|
||||||
|
func isLoopbackHost(host string) bool {
|
||||||
|
if host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.EqualFold(host, "localhost") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip := net.ParseIP(host); ip != nil {
|
||||||
|
return ip.IsLoopback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// SaveClusterUUID writes or updates the ClusterUUID key in options.yml without
|
// SaveClusterUUID writes or updates the ClusterUUID key in options.yml without
|
||||||
// touching unrelated keys. Creates the file and directories if needed.
|
// touching unrelated keys. Creates the file and directories if needed.
|
||||||
func (c *Config) SaveClusterUUID(uuid string) error {
|
func (c *Config) SaveClusterUUID(uuid string) error {
|
||||||
|
@@ -73,15 +73,77 @@ func TestConfig_Cluster(t *testing.T) {
|
|||||||
c.Options().NodeRole = ""
|
c.Options().NodeRole = ""
|
||||||
})
|
})
|
||||||
t.Run("JWKSUrlSetter", func(t *testing.T) {
|
t.Run("JWKSUrlSetter", func(t *testing.T) {
|
||||||
c := NewConfig(CliTestContext())
|
const existing = "https://existing.example/.well-known/jwks.json"
|
||||||
c.options.JWKSUrl = ""
|
tests := []struct {
|
||||||
assert.Equal(t, "", c.JWKSUrl())
|
name string
|
||||||
|
prev string
|
||||||
|
input string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "TrimHTTPS",
|
||||||
|
prev: "",
|
||||||
|
input: " https://portal.example/.well-known/jwks.json ",
|
||||||
|
expect: "https://portal.example/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CaseInsensitiveScheme",
|
||||||
|
prev: "",
|
||||||
|
input: "HTTPS://portal.example/.well-known/jwks.json",
|
||||||
|
expect: "HTTPS://portal.example/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AllowHTTPOnLocalhost",
|
||||||
|
prev: "",
|
||||||
|
input: "http://localhost:2342/.well-known/jwks.json",
|
||||||
|
expect: "http://localhost:2342/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AllowHTTPOnLoopbackIPv4",
|
||||||
|
prev: "",
|
||||||
|
input: "http://127.0.0.1/.well-known/jwks.json",
|
||||||
|
expect: "http://127.0.0.1/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "AllowHTTPOnLoopbackIPv6",
|
||||||
|
prev: "",
|
||||||
|
input: "http://[::1]/.well-known/jwks.json",
|
||||||
|
expect: "http://[::1]/.well-known/jwks.json",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RejectHTTPNonLoopback",
|
||||||
|
prev: existing,
|
||||||
|
input: "http://portal.example/.well-known/jwks.json",
|
||||||
|
expect: existing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RejectUnsupportedScheme",
|
||||||
|
prev: existing,
|
||||||
|
input: "ftp://portal.example/.well-known/jwks.json",
|
||||||
|
expect: existing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RejectMalformedURL",
|
||||||
|
prev: existing,
|
||||||
|
input: "://not-a-url",
|
||||||
|
expect: existing,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ClearValue",
|
||||||
|
prev: existing,
|
||||||
|
input: "",
|
||||||
|
expect: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
c.SetJWKSUrl(" https://portal.example/.well-known/jwks.json ")
|
for _, tc := range tests {
|
||||||
assert.Equal(t, "https://portal.example/.well-known/jwks.json", c.JWKSUrl())
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
c := NewConfig(CliTestContext())
|
||||||
c.SetJWKSUrl("")
|
c.options.JWKSUrl = tc.prev
|
||||||
assert.Equal(t, "", c.JWKSUrl())
|
c.SetJWKSUrl(tc.input)
|
||||||
|
assert.Equal(t, tc.expect, c.JWKSUrl())
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
t.Run("Paths", func(t *testing.T) {
|
t.Run("Paths", func(t *testing.T) {
|
||||||
c := NewConfig(CliTestContext())
|
c := NewConfig(CliTestContext())
|
||||||
|
@@ -342,12 +342,18 @@ func (c *Config) SetDbOptions() {
|
|||||||
case Postgres:
|
case Postgres:
|
||||||
// Ignore for now.
|
// Ignore for now.
|
||||||
case SQLite3:
|
case SQLite3:
|
||||||
// Not required as unicode is default.
|
// Not required as Unicode is default.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterDb sets the database options and connection provider.
|
// RegisterDb opens a database connection if needed,
|
||||||
|
// sets the database options and connection provider.
|
||||||
func (c *Config) RegisterDb() {
|
func (c *Config) RegisterDb() {
|
||||||
|
if err := c.connectDb(); err != nil {
|
||||||
|
log.Errorf("config: %s (register db)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
c.SetDbOptions()
|
c.SetDbOptions()
|
||||||
entity.SetDbProvider(c)
|
entity.SetDbProvider(c)
|
||||||
}
|
}
|
||||||
@@ -456,6 +462,11 @@ func (c *Config) connectDb() error {
|
|||||||
mutex.Db.Lock()
|
mutex.Db.Lock()
|
||||||
defer mutex.Db.Unlock()
|
defer mutex.Db.Unlock()
|
||||||
|
|
||||||
|
// Database connection already exists.
|
||||||
|
if c.db != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get database driver and data source name.
|
// Get database driver and data source name.
|
||||||
dbDriver := c.DatabaseDriver()
|
dbDriver := c.DatabaseDriver()
|
||||||
dbDsn := c.DatabaseDSN()
|
dbDsn := c.DatabaseDSN()
|
||||||
|
@@ -227,13 +227,16 @@ func TestDownloadPlaylistEntry(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Download the same file but with the direct link
|
// Download the same file but with the direct link
|
||||||
url := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
dlUrl := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
||||||
|
|
||||||
stderrBuf = &bytes.Buffer{}
|
stderrBuf = &bytes.Buffer{}
|
||||||
r, err = NewMetadata(context.Background(), url, Options{
|
|
||||||
|
r, err = NewMetadata(context.Background(), dlUrl, Options{
|
||||||
StderrFn: func(cmd *exec.Cmd) io.Writer {
|
StderrFn: func(cmd *exec.Cmd) io.Writer {
|
||||||
return stderrBuf
|
return stderrBuf
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@@ -1,30 +1,42 @@
|
|||||||
package get
|
package get
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/photoprism/photoprism/internal/auth/jwt"
|
"github.com/photoprism/photoprism/internal/auth/jwt"
|
||||||
"github.com/photoprism/photoprism/pkg/clean"
|
"github.com/photoprism/photoprism/pkg/clean"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
onceJWTManager sync.Once
|
onceJWTManager sync.Once
|
||||||
onceJWTIssuer sync.Once
|
onceJWTIssuer sync.Once
|
||||||
|
onceJWTVerifier sync.Once
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// initJWTManager lazily initializes the shared portal key manager for JWT issuance.
|
||||||
func initJWTManager() {
|
func initJWTManager() {
|
||||||
conf := Config()
|
if conf == nil {
|
||||||
if conf == nil || !conf.IsPortal() {
|
return
|
||||||
|
} else if !conf.IsPortal() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := jwt.NewManager(conf)
|
manager, err := jwt.NewManager(conf)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("jwt: manager init failed (%s)", clean.Error(err))
|
log.Warnf("jwt: manager init failed (%s)", clean.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, err := manager.EnsureActiveKey(); err != nil {
|
|
||||||
|
if _, err = manager.EnsureActiveKey(); err != nil {
|
||||||
log.Warnf("jwt: ensure signing key failed (%s)", clean.Error(err))
|
log.Warnf("jwt: ensure signing key failed (%s)", clean.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
services.JWTManager = manager
|
services.JWTManager = manager
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,6 +46,7 @@ func JWTManager() *jwt.Manager {
|
|||||||
return services.JWTManager
|
return services.JWTManager
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initJWTIssuer lazily binds the shared issuer to the active portal key manager.
|
||||||
func initJWTIssuer() {
|
func initJWTIssuer() {
|
||||||
manager := JWTManager()
|
manager := JWTManager()
|
||||||
if manager == nil {
|
if manager == nil {
|
||||||
@@ -50,9 +63,98 @@ func JWTIssuer() *jwt.Issuer {
|
|||||||
|
|
||||||
// JWTVerifier returns a verifier bound to the current config.
|
// JWTVerifier returns a verifier bound to the current config.
|
||||||
func JWTVerifier() *jwt.Verifier {
|
func JWTVerifier() *jwt.Verifier {
|
||||||
conf := Config()
|
onceJWTVerifier.Do(initJWTVerifier)
|
||||||
if conf == nil {
|
return services.JWTVerifier
|
||||||
return nil
|
}
|
||||||
}
|
|
||||||
return jwt.NewVerifier(conf)
|
// VerifyJWT verifies a token using the shared verifier instance.
|
||||||
|
func VerifyJWT(ctx context.Context, token string, expected jwt.ExpectedClaims) (*jwt.Claims, error) {
|
||||||
|
verifier := JWTVerifier()
|
||||||
|
if verifier == nil {
|
||||||
|
return nil, errors.New("jwt: verifier not available")
|
||||||
|
}
|
||||||
|
return verifier.VerifyToken(ctx, token, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
// initJWTVerifier lazily constructs the shared verifier for the current configuration.
|
||||||
|
func initJWTVerifier() {
|
||||||
|
if conf != nil {
|
||||||
|
services.JWTVerifier = jwt.NewVerifier(conf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetJWTVerifier clears the cached verifier so it can be rebuilt for a new configuration.
|
||||||
|
func resetJWTVerifier() {
|
||||||
|
services.JWTVerifier = nil
|
||||||
|
onceJWTVerifier = sync.Once{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetJWTIssuer clears the cached issuer so it can be recreated for a new configuration.
|
||||||
|
func resetJWTIssuer() {
|
||||||
|
services.JWTIssuer = nil
|
||||||
|
onceJWTIssuer = sync.Once{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetJWTManager clears the cached key manager so subsequent calls reload keys for the active configuration.
|
||||||
|
func resetJWTManager() {
|
||||||
|
services.JWTManager = nil
|
||||||
|
onceJWTManager = sync.Once{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resetJWT clears all cached JWT helpers.
|
||||||
|
func resetJWT() {
|
||||||
|
resetJWTVerifier()
|
||||||
|
resetJWTIssuer()
|
||||||
|
resetJWTManager()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuePortalJWT signs a token using the shared portal issuer with the provided claims.
|
||||||
|
func IssuePortalJWT(spec jwt.ClaimsSpec) (string, error) {
|
||||||
|
if issuer := JWTIssuer(); issuer == nil {
|
||||||
|
return "", errors.New("jwt: issuer not available")
|
||||||
|
} else {
|
||||||
|
return issuer.Issue(spec)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuePortalJWTForNode issues a portal-signed JWT targeting the specified node UUID.
|
||||||
|
func IssuePortalJWTForNode(nodeUUID string, scopes []string, ttl time.Duration) (string, error) {
|
||||||
|
if conf == nil {
|
||||||
|
return "", errors.New("jwt: missing config")
|
||||||
|
} else if !conf.IsPortal() {
|
||||||
|
return "", errors.New("jwt: not supported on nodes")
|
||||||
|
}
|
||||||
|
|
||||||
|
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||||
|
if clusterUUID == "" {
|
||||||
|
return "", errors.New("jwt: cluster uuid not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
nodeUUID = strings.TrimSpace(nodeUUID)
|
||||||
|
if nodeUUID == "" {
|
||||||
|
return "", errors.New("jwt: node uuid required")
|
||||||
|
}
|
||||||
|
if len(scopes) == 0 {
|
||||||
|
return "", errors.New("jwt: at least one scope is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := make([]string, 0, len(scopes))
|
||||||
|
for _, s := range scopes {
|
||||||
|
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||||
|
normalized = append(normalized, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(normalized) == 0 {
|
||||||
|
return "", errors.New("jwt: at least one scope is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
spec := jwt.ClaimsSpec{
|
||||||
|
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||||
|
Subject: fmt.Sprintf("portal:%s", clusterUUID),
|
||||||
|
Audience: fmt.Sprintf("node:%s", nodeUUID),
|
||||||
|
Scope: normalized,
|
||||||
|
TTL: ttl,
|
||||||
|
}
|
||||||
|
|
||||||
|
return IssuePortalJWT(spec)
|
||||||
}
|
}
|
||||||
|
39
internal/photoprism/get/jwt_test.go
Normal file
39
internal/photoprism/get/jwt_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package get
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/photoprism/photoprism/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWTVerifierReuse(t *testing.T) {
|
||||||
|
verifier1 := JWTVerifier()
|
||||||
|
require.NotNil(t, verifier1)
|
||||||
|
|
||||||
|
verifier2 := JWTVerifier()
|
||||||
|
require.NotNil(t, verifier2)
|
||||||
|
|
||||||
|
assert.Same(t, verifier1, verifier2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTVerifierResetOnConfigChange(t *testing.T) {
|
||||||
|
orig := Config()
|
||||||
|
verifier1 := JWTVerifier()
|
||||||
|
require.NotNil(t, verifier1)
|
||||||
|
|
||||||
|
tempConf := config.NewMinimalTestConfigWithDb("jwt-reset", t.TempDir())
|
||||||
|
SetConfig(tempConf)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
SetConfig(orig)
|
||||||
|
tempConf.CloseDb()
|
||||||
|
orig.RegisterDb()
|
||||||
|
})
|
||||||
|
|
||||||
|
verifier2 := JWTVerifier()
|
||||||
|
require.NotNil(t, verifier2)
|
||||||
|
|
||||||
|
assert.NotSame(t, verifier1, verifier2)
|
||||||
|
}
|
@@ -57,13 +57,17 @@ var services struct {
|
|||||||
OIDC *oidc.Client
|
OIDC *oidc.Client
|
||||||
JWTManager *clusterjwt.Manager
|
JWTManager *clusterjwt.Manager
|
||||||
JWTIssuer *clusterjwt.Issuer
|
JWTIssuer *clusterjwt.Issuer
|
||||||
|
JWTVerifier *clusterjwt.Verifier
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetConfig(c *config.Config) {
|
func SetConfig(c *config.Config) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
log.Panic("panic: argument is nil in get.SetConfig(c *config.Config)")
|
log.Panic("panic: argument is nil in get.SetConfig(c *config.Config)")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
resetJWT()
|
||||||
|
|
||||||
conf = c
|
conf = c
|
||||||
|
|
||||||
photoprism.SetConfig(c)
|
photoprism.SetConfig(c)
|
||||||
@@ -72,6 +76,7 @@ func SetConfig(c *config.Config) {
|
|||||||
func Config() *config.Config {
|
func Config() *config.Config {
|
||||||
if conf == nil {
|
if conf == nil {
|
||||||
log.Panic("panic: conf is nil in get.Config()")
|
log.Panic("panic: conf is nil in get.Config()")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return conf
|
return conf
|
||||||
|
@@ -229,9 +229,9 @@ func persistRegistration(c *config.Config, r *cluster.RegisterResponse, wantRota
|
|||||||
updates["NodeClientSecret"] = r.Secrets.ClientSecret
|
updates["NodeClientSecret"] = r.Secrets.ClientSecret
|
||||||
}
|
}
|
||||||
|
|
||||||
if url := strings.TrimSpace(r.JWKSUrl); url != "" {
|
if jwksUrl := strings.TrimSpace(r.JWKSUrl); jwksUrl != "" {
|
||||||
updates["JWKSUrl"] = url
|
updates["JWKSUrl"] = jwksUrl
|
||||||
c.SetJWKSUrl(url)
|
c.SetJWKSUrl(jwksUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Persist NodeUUID from portal response if provided and not set locally.
|
// Persist NodeUUID from portal response if provided and not set locally.
|
||||||
|
@@ -221,13 +221,13 @@ func (c *Config) ReSync(token string) (err error) {
|
|||||||
// interrupt reading of the Response.Body.
|
// interrupt reading of the Response.Body.
|
||||||
client := &http.Client{Timeout: 60 * time.Second}
|
client := &http.Client{Timeout: 60 * time.Second}
|
||||||
|
|
||||||
url := ServiceURL
|
endpointUrl := ServiceURL
|
||||||
method := http.MethodPost
|
method := http.MethodPost
|
||||||
|
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
|
|
||||||
if c.Key != "" {
|
if c.Key != "" {
|
||||||
url = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
endpointUrl = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
||||||
method = http.MethodPut
|
method = http.MethodPut
|
||||||
log.Tracef("config: requesting updated keys for maps and places")
|
log.Tracef("config: requesting updated keys for maps and places")
|
||||||
} else {
|
} else {
|
||||||
@@ -239,7 +239,7 @@ func (c *Config) ReSync(token string) (err error) {
|
|||||||
|
|
||||||
if j, err = json.Marshal(NewRequest(c.Version, c.Serial, c.Env, c.PartnerID, token)); err != nil {
|
if j, err = json.Marshal(NewRequest(c.Version, c.Serial, c.Env, c.PartnerID, token)); err != nil {
|
||||||
return err
|
return err
|
||||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
} else if req, err = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -67,17 +67,17 @@ func (c *Config) SendFeedback(frm form.Feedback) (err error) {
|
|||||||
// interrupt reading of the Response.Body.
|
// interrupt reading of the Response.Body.
|
||||||
client := &http.Client{Timeout: 60 * time.Second}
|
client := &http.Client{Timeout: 60 * time.Second}
|
||||||
|
|
||||||
url := fmt.Sprintf(FeedbackURL, c.Key)
|
endpointUrl := fmt.Sprintf(FeedbackURL, c.Key)
|
||||||
method := http.MethodPost
|
method := http.MethodPost
|
||||||
|
|
||||||
var req *http.Request
|
var req *http.Request
|
||||||
|
|
||||||
log.Debugf("sending feedback to %s", ApiHost())
|
log.Debugf("sending feedback to %s", ApiHost())
|
||||||
|
|
||||||
if j, err := json.Marshal(feedback); err != nil {
|
if j, reqErr := json.Marshal(feedback); reqErr != nil {
|
||||||
return err
|
return reqErr
|
||||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
} else if req, reqErr = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); reqErr != nil {
|
||||||
return err
|
return reqErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set user agent.
|
// Set user agent.
|
||||||
|
Reference in New Issue
Block a user