CLI: Added JWT issuance and diagnostics sub commands #5230

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2025-09-26 02:38:49 +02:00
parent 566eed05e0
commit 32c054da7a
35 changed files with 1354 additions and 84 deletions

View File

@@ -50,7 +50,6 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
return nil
}
verifier := clusterjwt.NewVerifier(conf)
requiredScopes := []string{"cluster"}
if resource == acl.ResourceVision {
requiredScopes = []string{"vision"}
@@ -77,7 +76,7 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
for _, issuer := range issuers {
expected.Issuer = issuer
claims, err = verifier.VerifyToken(ctx, authToken, expected)
claims, err = get.VerifyJWT(ctx, authToken, expected)
if err == nil {
break
}

View File

@@ -25,12 +25,12 @@ func TestEcho(t *testing.T) {
t.Logf("Response Body: %s", r.Body.String())
body := r.Body.String()
url := gjson.Get(body, "url").String()
bodyUrl := gjson.Get(body, "url").String()
method := gjson.Get(body, "method").String()
request := gjson.Get(body, "headers.request")
response := gjson.Get(body, "headers.response")
assert.Equal(t, "/api/v1/echo", url)
assert.Equal(t, "/api/v1/echo", bodyUrl)
assert.Equal(t, "GET", method)
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
@@ -49,12 +49,12 @@ func TestEcho(t *testing.T) {
r := AuthenticatedRequest(app, http.MethodPost, "/api/v1/echo", authToken)
body := r.Body.String()
url := gjson.Get(body, "url").String()
bodyUrl := gjson.Get(body, "url").String()
method := gjson.Get(body, "method").String()
request := gjson.Get(body, "headers.request")
response := gjson.Get(body, "headers.response")
assert.Equal(t, "/api/v1/echo", url)
assert.Equal(t, "/api/v1/echo", bodyUrl)
assert.Equal(t, "POST", method)
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())

View File

@@ -25,6 +25,20 @@ var (
errKeyNotFound = errors.New("jwt: key not found")
)
// VerifierStatus captures diagnostic information about a verifier's JWKS cache state.
type VerifierStatus struct {
CacheURL string `json:"cacheUrl,omitempty"`
CacheETag string `json:"cacheEtag,omitempty"`
KeyIDs []string `json:"keyIds,omitempty"`
KeyCount int `json:"keyCount"`
CacheFetchedAt time.Time `json:"cacheFetchedAt,omitempty"`
CacheAgeSeconds int64 `json:"cacheAgeSeconds"`
CacheTTLSeconds int `json:"cacheTtlSeconds"`
CacheStale bool `json:"cacheStale"`
CachePath string `json:"cachePath,omitempty"`
JWKSURL string `json:"jwksUrl,omitempty"`
}
const (
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
jwksFetchMaxRetries = 3
@@ -99,15 +113,14 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
if strings.TrimSpace(expected.Audience) == "" {
return nil, errors.New("jwt: expected audience required")
}
if len(expected.Scope) == 0 {
return nil, errors.New("jwt: expected scope required")
jwksUrl := strings.TrimSpace(expected.JWKSURL)
if jwksUrl == "" && v.conf != nil {
jwksUrl = strings.TrimSpace(v.conf.JWKSUrl())
}
url := strings.TrimSpace(expected.JWKSURL)
if url == "" && v.conf != nil {
url = strings.TrimSpace(v.conf.JWKSUrl())
}
if url == "" {
if jwksUrl == "" {
return nil, errors.New("jwt: jwks url not configured")
}
@@ -126,16 +139,111 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
claims := &Claims{}
keyFunc := func(token *gojwt.Token) (interface{}, error) {
kid, _ := token.Header["kid"].(string)
if kid == "" {
return nil, errors.New("jwt: missing kid header")
}
pk, err := v.publicKeyForKid(ctx, url, kid, false)
pk, err := v.publicKeyForKid(ctx, jwksUrl, kid, false)
if errors.Is(err, errKeyNotFound) {
pk, err = v.publicKeyForKid(ctx, url, kid, true)
pk, err = v.publicKeyForKid(ctx, jwksUrl, kid, true)
}
if err != nil {
return nil, err
}
return pk, nil
}
if _, err := parser.ParseWithClaims(tokenString, claims, keyFunc); err != nil {
return nil, err
}
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
return nil, errors.New("jwt: missing temporal claims")
}
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
return nil, errors.New("jwt: token ttl exceeds maximum")
}
scopeSet := map[string]struct{}{}
for _, s := range strings.Fields(claims.Scope) {
scopeSet[s] = struct{}{}
}
for _, req := range expected.Scope {
if _, ok := scopeSet[req]; !ok {
return nil, fmt.Errorf("jwt: missing scope %s", req)
}
}
return claims, nil
}
// VerifyTokenWithKeys verifies a token using the provided JWKS keys without performing HTTP fetches.
func VerifyTokenWithKeys(tokenString string, expected ExpectedClaims, keys []PublicJWK, leeway time.Duration) (*Claims, error) {
if strings.TrimSpace(tokenString) == "" {
return nil, errors.New("jwt: token is empty")
}
if len(keys) == 0 {
return nil, errors.New("jwt: no jwks keys provided")
}
if leeway <= 0 {
leeway = 60 * time.Second
}
keyMap := make(map[string]ed25519.PublicKey, len(keys))
for _, jwk := range keys {
if jwk.Kid == "" {
continue
}
raw, err := base64.RawURLEncoding.DecodeString(jwk.X)
if err != nil {
return nil, err
}
if len(raw) != ed25519.PublicKeySize {
return nil, fmt.Errorf("jwt: invalid public key length %d", len(raw))
}
pk := make(ed25519.PublicKey, ed25519.PublicKeySize)
copy(pk, raw)
keyMap[jwk.Kid] = pk
}
if len(keyMap) == 0 {
return nil, errors.New("jwt: no valid jwks keys provided")
}
options := []gojwt.ParserOption{
gojwt.WithLeeway(leeway),
gojwt.WithValidMethods([]string{gojwt.SigningMethodEdDSA.Alg()}),
}
if iss := strings.TrimSpace(expected.Issuer); iss != "" {
options = append(options, gojwt.WithIssuer(iss))
}
if aud := strings.TrimSpace(expected.Audience); aud != "" {
options = append(options, gojwt.WithAudience(aud))
}
parser := gojwt.NewParser(options...)
claims := &Claims{}
keyFunc := func(token *gojwt.Token) (interface{}, error) {
kid, _ := token.Header["kid"].(string)
if kid == "" {
return nil, errors.New("jwt: missing kid header")
}
pk, ok := keyMap[kid]
if !ok {
return nil, errKeyNotFound
}
return pk, nil
}
@@ -146,29 +254,70 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
return nil, errors.New("jwt: missing temporal claims")
}
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
return nil, errors.New("jwt: token ttl exceeds maximum")
}
scopeSet := map[string]struct{}{}
for _, s := range strings.Fields(claims.Scope) {
scopeSet[s] = struct{}{}
}
for _, req := range expected.Scope {
if _, ok := scopeSet[req]; !ok {
return nil, fmt.Errorf("jwt: missing scope %s", req)
if len(expected.Scope) > 0 {
scopeSet := map[string]struct{}{}
for _, s := range strings.Fields(claims.Scope) {
scopeSet[s] = struct{}{}
}
for _, req := range expected.Scope {
if _, ok := scopeSet[req]; !ok {
return nil, fmt.Errorf("jwt: missing scope %s", req)
}
}
}
return claims, nil
}
// Status returns diagnostic information about the verifier's current JWKS cache.
func (v *Verifier) Status(ttl time.Duration) VerifierStatus {
status := VerifierStatus{}
if ttl > 0 {
status.CacheTTLSeconds = int(ttl / time.Second)
}
v.mu.Lock()
defer v.mu.Unlock()
status.CacheURL = v.cache.URL
status.CacheETag = v.cache.ETag
status.JWKSURL = v.cache.URL
status.KeyCount = len(v.cache.Keys)
status.KeyIDs = make([]string, 0, len(v.cache.Keys))
for _, key := range v.cache.Keys {
status.KeyIDs = append(status.KeyIDs, key.Kid)
}
status.CachePath = v.cachePath
if v.cache.FetchedAt > 0 {
fetched := time.Unix(v.cache.FetchedAt, 0).UTC()
status.CacheFetchedAt = fetched
age := time.Since(fetched)
status.CacheAgeSeconds = int64(age.Seconds())
if ttl > 0 && age > ttl {
status.CacheStale = true
}
}
return status
}
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
keys, err := v.keysForURL(ctx, url, force)
if err != nil {
return nil, err
}
for _, k := range keys {
if k.Kid != kid {
continue
@@ -184,12 +333,14 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
copy(pk, raw)
return pk, nil
}
return nil, errKeyNotFound
}
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
ttl := 300 * time.Second
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
}
@@ -250,13 +401,16 @@ func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, f
if force || cache.URL != url || len(cache.Keys) == 0 {
return nil, false
}
age := v.now().Unix() - cache.FetchedAt
if age < 0 {
return nil, false
}
if time.Duration(age)*time.Second > ttl {
return nil, false
}
return append([]PublicJWK(nil), cache.Keys...), true
}
@@ -270,17 +424,21 @@ type jwksFetchResult struct {
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := v.httpClient.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
switch resp.StatusCode {
@@ -331,6 +489,7 @@ func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK
Keys: append([]PublicJWK(nil), result.keys...),
FetchedAt: result.fetchedAt,
}
_ = v.saveCacheLocked()
return append([]PublicJWK(nil), v.cache.Keys...), true
}
@@ -347,7 +506,7 @@ func (v *Verifier) loadCache() error {
}
var entry cacheEntry
if err := json.Unmarshal(b, &entry); err != nil {
if err = json.Unmarshal(b, &entry); err != nil {
return err
}
@@ -360,13 +519,17 @@ func (v *Verifier) saveCacheLocked() error {
if v.cachePath == "" {
return nil
}
if err := fs.MkdirAll(filepath.Dir(v.cachePath)); err != nil {
return err
}
data, err := json.Marshal(v.cache)
if err != nil {
return err
}
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
}
@@ -377,11 +540,13 @@ func backoffDuration(attempt int) time.Duration {
}
base := jwksFetchBaseDelay << (attempt - 1)
if base > jwksFetchMaxDelay {
base = jwksFetchMaxDelay
}
jitterRange := base / 2
if jitterRange > 0 {
base += time.Duration(randInt63n(int64(jitterRange) + 1))
}

View File

@@ -103,6 +103,56 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
require.Error(t, err)
}
func TestVerifyTokenWithKeys(t *testing.T) {
portalCfg := newTestConfig(t)
clusterUUID := rnd.UUIDv7()
portalCfg.Options().ClusterUUID = clusterUUID
mgr, err := NewManager(portalCfg)
require.NoError(t, err)
mgr.now = func() time.Time { return time.Date(2025, 9, 24, 10, 30, 0, 0, time.UTC) }
_, err = mgr.EnsureActiveKey()
require.NoError(t, err)
issuer := NewIssuer(mgr)
issuer.now = func() time.Time { return time.Now().UTC() }
spec := ClaimsSpec{
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
Subject: "portal:client-test",
Audience: "node:1234",
Scope: []string{"cluster"},
}
token, err := issuer.Issue(spec)
require.NoError(t, err)
keys := mgr.JWKS().Keys
claims, err := VerifyTokenWithKeys(token, ExpectedClaims{
Issuer: spec.Issuer,
Audience: spec.Audience,
Scope: []string{"cluster"},
}, keys, 60*time.Second)
require.NoError(t, err)
require.Equal(t, spec.Subject, claims.Subject)
// Ensure scope filtering is honored when expected scope is empty.
claims, err = VerifyTokenWithKeys(token, ExpectedClaims{
Issuer: spec.Issuer,
Audience: spec.Audience,
}, keys, 60*time.Second)
require.NoError(t, err)
require.Equal(t, spec.Subject, claims.Subject)
// Missing scope should fail when explicitly required.
_, err = VerifyTokenWithKeys(token, ExpectedClaims{
Issuer: spec.Issuer,
Audience: spec.Audience,
Scope: []string{"vision"},
}, keys, 60*time.Second)
require.Error(t, err)
}
func TestIssuerClampTTL(t *testing.T) {
portalCfg := newTestConfig(t)
mgr, err := NewManager(portalCfg)

View File

@@ -15,6 +15,7 @@ var AuthCommands = &cli.Command{
AuthShowCommand,
AuthRemoveCommand,
AuthResetCommand,
AuthJWTCommands,
},
}

View File

@@ -0,0 +1,16 @@
package commands
import "github.com/urfave/cli/v2"
// AuthJWTCommands groups JWT-related auth helpers under photoprism auth jwt.
var AuthJWTCommands = &cli.Command{
Name: "jwt",
Usage: "JWT issuance and diagnostics",
Hidden: true, // Required for cluster-management only.
Subcommands: []*cli.Command{
AuthJWTIssueCommand,
AuthJWTInspectCommand,
AuthJWTKeysCommand,
AuthJWTStatusCommand,
},
}

View File

@@ -0,0 +1,154 @@
package commands
import (
"errors"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/urfave/cli/v2"
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/pkg/clean"
)
// AuthJWTInspectCommand inspects and verifies portal-issued JWTs.
var AuthJWTInspectCommand = &cli.Command{
Name: "inspect",
Usage: "Decodes and verifies a portal JWT",
ArgsUsage: "<token>",
Flags: []cli.Flag{
&cli.StringFlag{Name: "file", Aliases: []string{"f"}, Usage: "read token from file"},
&cli.StringFlag{Name: "expect-audience", Usage: "expected audience (e.g., node:<uuid>)"},
&cli.StringSliceFlag{Name: "require-scope", Usage: "require specific scope(s)"},
&cli.BoolFlag{Name: "skip-verify", Usage: "decode without signature verification"},
JsonFlag(),
},
Action: authJWTInspectAction,
}
// authJWTInspectAction decodes and optionally verifies a portal-issued JWT.
func authJWTInspectAction(ctx *cli.Context) error {
return CallWithDependencies(ctx, func(conf *config.Config) error {
if err := requirePortal(conf); err != nil {
return err
}
token, err := readTokenInput(ctx)
if err != nil {
return err
}
header, claims, err := decodeJWTClaims(token)
if err != nil {
return cli.Exit(err, 1)
}
var verified bool
tokenScopes := clean.Scopes(claims.Scope)
if !ctx.Bool("skip-verify") {
expected := clusterjwt.ExpectedClaims{}
if clusterUUID := strings.TrimSpace(conf.ClusterUUID()); clusterUUID != "" {
expected.Issuer = fmt.Sprintf("portal:%s", clusterUUID)
} else if portal := strings.TrimSpace(conf.PortalUrl()); portal != "" {
expected.Issuer = strings.TrimRight(portal, "/")
}
if expectAud := strings.TrimSpace(ctx.String("expect-audience")); expectAud != "" {
expected.Audience = expectAud
} else if len(claims.Audience) > 0 {
expected.Audience = claims.Audience[0]
}
if required := ctx.StringSlice("require-scope"); len(required) > 0 {
scopes, scopeErr := normalizeScopes(required)
if scopeErr != nil {
return scopeErr
}
expected.Scope = scopes
} else {
expected.Scope = tokenScopes
}
if _, err := verifyPortalToken(conf, token, expected); err != nil {
return cli.Exit(err, 1)
}
verified = true
}
if ctx.Bool("json") {
payload := map[string]any{
"token": token,
"verified": verified,
"header": header,
"claims": claims,
}
return printJSON(payload)
}
fmt.Println()
fmt.Println("JWT header:")
for k, v := range header {
fmt.Printf(" %s: %v\n", k, v)
}
fmt.Println("\nJWT claims:")
fmt.Printf(" issuer: %s\n", claims.Issuer)
fmt.Printf(" subject: %s\n", claims.Subject)
fmt.Printf(" audience: %s\n", strings.Join(claims.Audience, " "))
fmt.Printf(" scope: %s\n", strings.Join(tokenScopes, " "))
if claims.IssuedAt != nil {
fmt.Printf(" issuedAt: %s\n", claims.IssuedAt.Time.UTC().Format(time.RFC3339))
}
if claims.ExpiresAt != nil {
fmt.Printf(" expiresAt: %s\n", claims.ExpiresAt.Time.UTC().Format(time.RFC3339))
}
if claims.NotBefore != nil {
fmt.Printf(" notBefore: %s\n", claims.NotBefore.Time.UTC().Format(time.RFC3339))
}
if claims.ID != "" {
fmt.Printf(" jti: %s\n", claims.ID)
}
if verified {
fmt.Println("\nSignature: verified")
} else {
fmt.Println("\nSignature: not verified (skipped)")
}
fmt.Printf("\nToken:\n%s\n\n", token)
return nil
})
}
// readTokenInput loads the token from CLI args, file, or STDIN.
func readTokenInput(ctx *cli.Context) (string, error) {
if file := strings.TrimSpace(ctx.String("file")); file != "" {
data, err := os.ReadFile(file)
if err != nil {
return "", cli.Exit(err, 1)
}
return strings.TrimSpace(string(data)), nil
}
if ctx.Args().Len() == 0 {
return "", cli.Exit(errors.New("token argument required"), 2)
}
token := strings.TrimSpace(ctx.Args().First())
if token == "-" {
data, err := io.ReadAll(os.Stdin)
if err != nil {
return "", cli.Exit(err, 1)
}
token = strings.TrimSpace(string(data))
}
if token == "" {
return "", cli.Exit(errors.New("token argument required"), 2)
}
return token, nil
}

View File

@@ -0,0 +1,117 @@
package commands
import (
"fmt"
"strings"
"time"
"github.com/urfave/cli/v2"
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// AuthJWTIssueCommand issues portal-signed JWTs for cluster nodes.
var AuthJWTIssueCommand = &cli.Command{
Name: "issue",
Usage: "Issues a portal-signed JWT for a node",
Flags: []cli.Flag{
&cli.StringFlag{Name: "node", Aliases: []string{"n"}, Usage: "target node uuid, client id, or DNS label", Required: true},
&cli.StringSliceFlag{Name: "scope", Aliases: []string{"s"}, Usage: "token scope", Value: cli.NewStringSlice("cluster")},
&cli.DurationFlag{Name: "ttl", Usage: "token lifetime", Value: clusterjwt.TokenTTL},
&cli.StringFlag{Name: "subject", Usage: "token subject (default portal:<clusterUUID>)"},
JsonFlag(),
},
Action: authJWTIssueAction,
}
// authJWTIssueAction handles CLI issuance of portal-signed JWTs for nodes.
func authJWTIssueAction(ctx *cli.Context) error {
return CallWithDependencies(ctx, func(conf *config.Config) error {
if err := requirePortal(conf); err != nil {
return err
}
node, err := resolveNode(conf, ctx.String("node"))
if err != nil {
return err
}
scopes, err := normalizeScopes(ctx.StringSlice("scope"), "cluster")
if err != nil {
return err
}
ttl := ctx.Duration("ttl")
if ttl <= 0 {
ttl = clusterjwt.TokenTTL
}
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
if clusterUUID == "" {
return cli.Exit(fmt.Errorf("cluster uuid not configured"), 1)
}
subject := strings.TrimSpace(ctx.String("subject"))
if subject == "" {
subject = fmt.Sprintf("portal:%s", clusterUUID)
}
var token string
if subject == fmt.Sprintf("portal:%s", clusterUUID) {
token, err = get.IssuePortalJWTForNode(node.UUID, scopes, ttl)
} else {
spec := clusterjwt.ClaimsSpec{
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
Subject: subject,
Audience: fmt.Sprintf("node:%s", node.UUID),
Scope: scopes,
TTL: ttl,
}
token, err = get.IssuePortalJWT(spec)
}
if err != nil {
return cli.Exit(err, 1)
}
header, claims, err := decodeJWTClaims(token)
if err != nil {
return cli.Exit(err, 1)
}
if ctx.Bool("json") {
payload := map[string]any{
"token": token,
"header": header,
"claims": claims,
"node": map[string]string{
"uuid": node.UUID,
"clientId": node.ClientID,
"name": node.Name,
"role": string(node.Role),
},
}
return printJSON(payload)
}
expires := "unknown"
if claims.ExpiresAt != nil {
expires = claims.ExpiresAt.Time.UTC().Format(time.RFC3339)
}
audience := strings.Join(claims.Audience, " ")
if audience == "" {
audience = "(none)"
}
fmt.Printf("\nIssued JWT for node %s (%s)\n", node.Name, node.UUID)
fmt.Printf("Scopes: %s\n", strings.Join(scopes, " "))
fmt.Printf("Expires: %s\n", expires)
fmt.Printf("Audience: %s\n", audience)
fmt.Printf("Subject: %s\n", claims.Subject)
fmt.Printf("Key ID: %v\n", header["kid"])
fmt.Printf("\n%s\n", token)
return nil
})
}

View File

@@ -0,0 +1,107 @@
package commands
import (
"errors"
"fmt"
"strings"
"time"
"github.com/urfave/cli/v2"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// AuthJWTKeysCommand groups JWT key management helpers.
var AuthJWTKeysCommand = &cli.Command{
Name: "keys",
Usage: "JWT signing key helpers",
Subcommands: []*cli.Command{
AuthJWTKeysListCommand,
},
}
// AuthJWTKeysListCommand lists JWT signing keys.
var AuthJWTKeysListCommand = &cli.Command{
Name: "ls",
Usage: "Lists JWT signing keys",
Aliases: []string{"list"},
ArgsUsage: "",
Flags: []cli.Flag{
JsonFlag(),
},
Action: authJWTKeysListAction,
}
// authJWTKeysListAction lists portal signing keys with metadata.
func authJWTKeysListAction(ctx *cli.Context) error {
return CallWithDependencies(ctx, func(conf *config.Config) error {
if err := requirePortal(conf); err != nil {
return err
}
manager := get.JWTManager()
if manager == nil {
return cli.Exit(errors.New("jwt manager not available"), 1)
}
keys := manager.AllKeys()
active, _ := manager.ActiveKey()
activeKid := ""
if active != nil {
activeKid = active.Kid
}
type keyInfo struct {
Kid string `json:"kid"`
CreatedAt string `json:"createdAt"`
NotAfter string `json:"notAfter,omitempty"`
Active bool `json:"active"`
}
rows := make([]keyInfo, 0, len(keys))
for _, k := range keys {
info := keyInfo{Kid: k.Kid, Active: k.Kid == activeKid}
if k.CreatedAt > 0 {
info.CreatedAt = time.Unix(k.CreatedAt, 0).UTC().Format(time.RFC3339)
}
if k.NotAfter > 0 {
info.NotAfter = time.Unix(k.NotAfter, 0).UTC().Format(time.RFC3339)
}
rows = append(rows, info)
}
if ctx.Bool("json") {
payload := map[string]any{
"keys": rows,
}
return printJSON(payload)
}
if len(rows) == 0 {
fmt.Println()
fmt.Println("No signing keys found.")
fmt.Println()
return nil
}
fmt.Println()
fmt.Println("JWT signing keys:")
for _, row := range rows {
status := ""
if row.Active {
status = " (active)"
}
parts := []string{fmt.Sprintf("KID: %s%s", row.Kid, status)}
if row.CreatedAt != "" {
parts = append(parts, fmt.Sprintf("created %s", row.CreatedAt))
}
if row.NotAfter != "" {
parts = append(parts, fmt.Sprintf("expires %s", row.NotAfter))
}
fmt.Printf("- %s\n", strings.Join(parts, ", "))
}
fmt.Println()
return nil
})
}

View File

@@ -0,0 +1,67 @@
package commands
import (
"errors"
"fmt"
"strings"
"time"
"github.com/urfave/cli/v2"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/photoprism/get"
)
// AuthJWTStatusCommand reports verifier cache diagnostics.
var AuthJWTStatusCommand = &cli.Command{
Name: "status",
Usage: "Shows JWT verifier cache status",
Flags: []cli.Flag{
JsonFlag(),
},
Action: authJWTStatusAction,
}
// authJWTStatusAction prints JWKS cache diagnostics for the current node.
func authJWTStatusAction(ctx *cli.Context) error {
return CallWithDependencies(ctx, func(conf *config.Config) error {
verifier := get.JWTVerifier()
if verifier == nil {
return cli.Exit(errors.New("jwt verifier not available"), 1)
}
ttl := time.Duration(conf.JWKSCacheTTL()) * time.Second
status := verifier.Status(ttl)
status.JWKSURL = strings.TrimSpace(conf.JWKSUrl())
if ctx.Bool("json") {
return printJSON(status)
}
fmt.Println()
fmt.Printf("JWKS URL: %s\n", status.JWKSURL)
fmt.Printf("Cache Path: %s\n", status.CachePath)
fmt.Printf("Cache URL: %s\n", status.CacheURL)
fmt.Printf("Cache ETag: %s\n", status.CacheETag)
fmt.Printf("Cached Keys: %d\n", status.KeyCount)
if len(status.KeyIDs) > 0 {
fmt.Printf("Key IDs: %s\n", strings.Join(status.KeyIDs, ", "))
}
if !status.CacheFetchedAt.IsZero() {
fmt.Printf("Last Fetch: %s\n", status.CacheFetchedAt.Format(time.RFC3339))
} else {
fmt.Println("Last Fetch: never")
}
fmt.Printf("Cache Age: %ds\n", status.CacheAgeSeconds)
if status.CacheTTLSeconds > 0 {
fmt.Printf("Cache TTL: %ds\n", status.CacheTTLSeconds)
}
if status.CacheStale {
fmt.Println("Cache Status: STALE")
} else {
fmt.Println("Cache Status: fresh")
}
fmt.Println()
return nil
})
}

View File

@@ -0,0 +1,94 @@
package commands
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/photoprism/get"
"github.com/photoprism/photoprism/internal/service/cluster"
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
"github.com/photoprism/photoprism/pkg/rnd"
)
func TestAuthJWTCommands(t *testing.T) {
conf := get.Config()
origEdition := conf.Options().Edition
origRole := conf.Options().NodeRole
origUUID := conf.Options().ClusterUUID
origPortal := conf.Options().PortalUrl
origJWKS := conf.JWKSUrl()
conf.Options().Edition = config.Portal
conf.Options().NodeRole = string(cluster.RolePortal)
conf.Options().ClusterUUID = "11111111-1111-4111-8111-111111111111"
conf.Options().PortalUrl = "https://portal.test"
conf.SetJWKSUrl("https://portal.test/.well-known/jwks.json")
get.SetConfig(conf)
conf.RegisterDb()
require.True(t, conf.IsPortal())
manager := get.JWTManager()
require.NotNil(t, manager)
_, err := manager.EnsureActiveKey()
require.NoError(t, err)
registry, err := reg.NewClientRegistryWithConfig(conf)
require.NoError(t, err)
nodeUUID := rnd.UUID()
node := &reg.Node{}
node.UUID = nodeUUID
node.Name = "pp-node-01"
node.Role = string(cluster.RoleInstance)
require.NoError(t, registry.Put(node))
t.Cleanup(func() {
conf.Options().Edition = origEdition
conf.Options().NodeRole = origRole
conf.Options().ClusterUUID = origUUID
conf.Options().PortalUrl = origPortal
conf.SetJWKSUrl(origJWKS)
get.SetConfig(conf)
conf.RegisterDb()
})
output, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID})
require.NoError(t, err)
assert.Contains(t, output, "Issued JWT")
jsonOut, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--json"})
require.NoError(t, err)
var payload struct {
Token string `json:"token"`
}
require.NoError(t, json.Unmarshal([]byte(jsonOut), &payload))
require.NotEmpty(t, payload.Token)
inspectOut, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", payload.Token})
require.NoError(t, err)
assert.Contains(t, inspectOut, "\"verified\": true")
inspectStrict, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", "--expect-audience", "node:" + nodeUUID, "--require-scope", "cluster", payload.Token})
require.NoError(t, err)
assert.Contains(t, inspectStrict, "\"verified\": true")
keysOut, err := RunWithTestContext(AuthJWTKeysListCommand, []string{"ls", "--json"})
require.NoError(t, err)
assert.Contains(t, keysOut, "\"keys\"")
statusOut, err := RunWithTestContext(AuthJWTStatusCommand, []string{"status"})
require.NoError(t, err)
assert.Contains(t, statusOut, "JWKS URL")
assert.Contains(t, statusOut, "Cached Keys")
// invalid scope should fail
_, err = RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--scope", "unknown"})
require.Error(t, err)
}

View File

@@ -19,8 +19,9 @@ type healthResponse struct {
// ClusterHealthCommand prints a minimal health response (Portal-only).
var ClusterHealthCommand = &cli.Command{
Name: "health",
Usage: "Shows cluster health (Portal-only)",
Usage: "Shows cluster health status",
Flags: report.CliFlags,
Hidden: true, // Required for cluster-management only.
Action: clusterHealthAction,
}

View File

@@ -14,8 +14,9 @@ import (
// ClusterNodesCommands groups node subcommands.
var ClusterNodesCommands = &cli.Command{
Name: "nodes",
Usage: "Node registry subcommands",
Name: "nodes",
Usage: "Node registry subcommands",
Hidden: true, // Required for cluster-management only.
Subcommands: []*cli.Command{
ClusterNodesListCommand,
ClusterNodesShowCommand,
@@ -28,9 +29,10 @@ var ClusterNodesCommands = &cli.Command{
// ClusterNodesListCommand lists registered nodes.
var ClusterNodesListCommand = &cli.Command{
Name: "ls",
Usage: "Lists registered cluster nodes (Portal-only)",
Usage: "Lists registered cluster nodes",
Flags: append(report.CliFlags, CountFlag, OffsetFlag),
ArgsUsage: "",
Hidden: true, // Required for cluster-management only.
Action: clusterNodesListAction,
}

View File

@@ -22,9 +22,10 @@ var (
// ClusterNodesModCommand updates node fields.
var ClusterNodesModCommand = &cli.Command{
Name: "mod",
Usage: "Updates node properties (Portal-only)",
Usage: "Updates node properties",
ArgsUsage: "<id|name>",
Flags: []cli.Flag{nodesModRoleFlag, nodesModInternal, nodesModLabel, &cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"}},
Hidden: true, // Required for cluster-management only.
Action: clusterNodesModAction,
}

View File

@@ -14,12 +14,13 @@ import (
// ClusterNodesRemoveCommand deletes a node from the registry.
var ClusterNodesRemoveCommand = &cli.Command{
Name: "rm",
Usage: "Deletes a node from the registry (Portal-only)",
Usage: "Deletes a node from the registry",
ArgsUsage: "<id|name>",
Flags: []cli.Flag{
&cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"},
&cli.BoolFlag{Name: "all-ids", Usage: "delete all records that share the same UUID (admin cleanup)"},
},
Hidden: true, // Required for cluster-management only.
Action: clusterNodesRemoveAction,
}

View File

@@ -2,6 +2,7 @@ package commands
import (
"encoding/json"
"errors"
"fmt"
"os"
@@ -106,11 +107,13 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
}
b, _ := json.Marshal(payload)
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
var resp cluster.RegisterResponse
if err := postWithBackoff(url, token, b, &resp); err != nil {
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
// Map common HTTP errors similarly to register command
if he, ok := err.(*httpError); ok {
var he *httpError
if errors.As(err, &he) {
switch he.Status {
case 401, 403:
return cli.Exit(fmt.Errorf("%s", he.Error()), 4)
@@ -151,6 +154,7 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
fmt.Printf("DSN: %s\n", resp.Database.DSN)
}
}
return nil
})
}

View File

@@ -15,9 +15,10 @@ import (
// ClusterNodesShowCommand shows node details.
var ClusterNodesShowCommand = &cli.Command{
Name: "show",
Usage: "Shows node details (Portal-only)",
Usage: "Shows node details",
ArgsUsage: "<id|name>",
Flags: report.CliFlags,
Hidden: true, // Required for cluster-management only.
Action: clusterNodesShowAction,
}

View File

@@ -24,7 +24,7 @@ import (
"github.com/photoprism/photoprism/pkg/txt/report"
)
// flags for register
// Supported cluster node register flags.
var (
regNameFlag = &cli.StringFlag{Name: "name", Usage: "node `NAME` (lowercase letters, digits, hyphens)"}
regRoleFlag = &cli.StringFlag{Name: "role", Usage: "node `ROLE` (instance, service)", Value: "instance"}
@@ -42,7 +42,7 @@ var (
// ClusterRegisterCommand registers a node with the Portal via HTTP.
var ClusterRegisterCommand = &cli.Command{
Name: "register",
Usage: "Registers/rotates a node via Portal (HTTP)",
Usage: "Registers a node or updates its credentials within a cluster",
Flags: append(append([]cli.Flag{regNameFlag, regRoleFlag, regIntUrlFlag, regLabelFlag, regRotateDatabase, regRotateSec, regPortalURL, regPortalTok, regWriteConf, regForceFlag, regDryRun}, report.CliFlags...)),
Action: clusterRegisterAction,
}
@@ -52,15 +52,18 @@ func clusterRegisterAction(ctx *cli.Context) error {
// Resolve inputs
name := clean.DNSLabel(ctx.String("name"))
derivedName := false
if name == "" { // default from config if set
name = clean.DNSLabel(conf.NodeName())
if name != "" {
derivedName = true
}
}
if name == "" {
return cli.Exit(fmt.Errorf("node name is required (use --name or set node-name)"), 2)
}
nodeRole := clean.TypeLowerDash(ctx.String("role"))
switch nodeRole {
case "instance", "service":
@@ -76,7 +79,6 @@ func clusterRegisterAction(ctx *cli.Context) error {
derivedPortal = true
}
}
// In dry-run, we allow empty portalURL (will print derived/empty values).
// Derive advertise/site URLs when omitted.
advertise := ctx.String("advertise-url")
@@ -93,17 +95,20 @@ func clusterRegisterAction(ctx *cli.Context) error {
RotateDatabase: ctx.Bool("rotate"),
RotateSecret: ctx.Bool("rotate-secret"),
}
// If we already have client credentials (e.g., re-register), include them so the
// portal can verify and authorize UUID/name moves or metadata updates.
if id, secret := strings.TrimSpace(conf.NodeClientID()), strings.TrimSpace(conf.NodeClientSecret()); id != "" && secret != "" {
payload.ClientID = id
payload.ClientSecret = secret
}
if site != "" && site != advertise {
payload.SiteUrl = site
}
b, _ := json.Marshal(payload)
// In dry-run, we allow empty portalURL (will print derived/empty values).
if ctx.Bool("dry-run") {
if ctx.Bool("json") {
out := map[string]any{"portalUrl": portalURL, "payload": payload}
@@ -140,18 +145,22 @@ func clusterRegisterAction(ctx *cli.Context) error {
if portalURL == "" {
return cli.Exit(fmt.Errorf("portal URL is required (use --portal-url or set portal-url)"), 2)
}
token := ctx.String("join-token")
if token == "" {
token = conf.JoinToken()
}
if token == "" {
return cli.Exit(fmt.Errorf("portal token is required (use --join-token or set join-token)"), 2)
}
// POST with bounded backoff on 429
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
var resp cluster.RegisterResponse
if err := postWithBackoff(url, token, b, &resp); err != nil {
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
var httpErr *httpError
if errors.As(err, &httpErr) && httpErr.Status == http.StatusTooManyRequests {
return cli.Exit(fmt.Errorf("portal rate-limited registration attempts"), 6)
@@ -179,13 +188,17 @@ func clusterRegisterAction(ctx *cli.Context) error {
} else {
// Human-readable: node row and credentials if present (UUID first as primary identifier)
cols := []string{"UUID", "ClientID", "Name", "Role", "DB Driver", "DB Name", "DB User", "Host", "Port"}
var dbName, dbUser string
if resp.Database.Name != "" {
dbName = resp.Database.Name
}
if resp.Database.User != "" {
dbUser = resp.Database.User
}
rows := [][]string{{resp.Node.UUID, resp.Node.ClientID, resp.Node.Name, resp.Node.Role, resp.Database.Driver, dbName, dbUser, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port)}}
out, _ := report.RenderFormat(rows, cols, report.CliFormat(ctx))
fmt.Printf("\n%s\n", out)

View File

@@ -13,11 +13,12 @@ import (
"github.com/photoprism/photoprism/pkg/txt/report"
)
// ClusterSummaryCommand prints a minimal cluster summary (Portal-only).
// ClusterSummaryCommand prints a minimal cluster summary.
var ClusterSummaryCommand = &cli.Command{
Name: "summary",
Usage: "Shows cluster summary (Portal-only)",
Usage: "Shows cluster summary",
Flags: report.CliFlags,
Hidden: true, // Required for cluster-management only.
Action: clusterSummaryAction,
}

View File

@@ -97,7 +97,6 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
// Ensure DB connection is open for each command run (some commands call Shutdown).
if c := get.Config(); c != nil {
_ = c.Init() // safe to call; re-opens DB if needed
c.RegisterDb() // (re)register provider
}
@@ -110,5 +109,11 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
err = cmd.Run(ctx, args...)
})
// Re-open the database after the command completed so follow-up checks
// (potentially issued by the test itself) have an active connection.
if c := get.Config(); c != nil {
c.RegisterDb()
}
return output, err
}

View File

@@ -81,9 +81,10 @@ func TestDownloadImpl_FileMethod_AutoSkipsRemux(t *testing.T) {
if conf == nil {
t.Fatalf("missing test config")
}
// Ensure DB is initialized and registered (bypassing CLI InitConfig)
_ = conf.Init()
conf.RegisterDb()
// Override yt-dlp after config init (config may set dl.YtDlpBin)
dl.YtDlpBin = fake
t.Logf("using yt-dlp binary: %s", dl.YtDlpBin)
@@ -125,7 +126,6 @@ func TestDownloadImpl_FileMethod_Skip_NoRemux(t *testing.T) {
if conf == nil {
t.Fatalf("missing test config")
}
_ = conf.Init()
conf.RegisterDb()
dl.YtDlpBin = fake
@@ -196,8 +196,9 @@ func TestDownloadImpl_FileMethod_Always_RemuxFails(t *testing.T) {
if conf == nil {
t.Fatalf("missing test config")
}
_ = conf.Init()
conf.RegisterDb()
dl.YtDlpBin = fake
err := runDownload(conf, DownloadOpts{

View File

@@ -0,0 +1,8 @@
package commands
import "github.com/urfave/cli/v2"
// JsonFlag returns the shared CLI flag definition for JSON output across commands.
func JsonFlag() *cli.BoolFlag {
return &cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"}
}

View File

@@ -0,0 +1,173 @@
package commands
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/urfave/cli/v2"
"github.com/photoprism/photoprism/internal/auth/acl"
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/photoprism/get"
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
"github.com/photoprism/photoprism/pkg/clean"
)
var allowedJWTScope = func() map[string]struct{} {
out := make(map[string]struct{}, len(acl.ResourceNames))
for _, res := range acl.ResourceNames {
out[res.String()] = struct{}{}
}
return out
}()
// requirePortal returns a CLI error when the active configuration is not a portal node.
func requirePortal(conf *config.Config) error {
if conf == nil || !conf.IsPortal() {
return cli.Exit(errors.New("command requires a Portal node"), 2)
}
return nil
}
// resolveNode finds a node by UUID, client ID, or DNS label using the portal registry.
func resolveNode(conf *config.Config, identifier string) (*reg.Node, error) {
if err := requirePortal(conf); err != nil {
return nil, err
}
key := strings.TrimSpace(identifier)
if key == "" {
return nil, cli.Exit(errors.New("node identifier required"), 2)
}
registry, err := reg.NewClientRegistryWithConfig(conf)
if err != nil {
return nil, cli.Exit(err, 1)
}
if node, err := registry.FindByNodeUUID(key); err == nil && node != nil {
return node, nil
}
if node, err := registry.FindByClientID(key); err == nil && node != nil {
return node, nil
}
name := clean.DNSLabel(key)
if name == "" {
return nil, cli.Exit(errors.New("invalid node identifier"), 2)
}
node, err := registry.FindByName(name)
if err != nil {
if errors.Is(err, reg.ErrNotFound) {
return nil, cli.Exit(fmt.Errorf("node %q not found", identifier), 3)
}
return nil, cli.Exit(err, 1)
}
return node, nil
}
// decodeJWTClaims decodes the compact JWT and returns header and claims without verifying the signature.
func decodeJWTClaims(token string) (map[string]any, *clusterjwt.Claims, error) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return nil, nil, errors.New("jwt: token must contain three segments")
}
decode := func(segment string) ([]byte, error) {
return base64.RawURLEncoding.DecodeString(segment)
}
headerBytes, err := decode(parts[0])
if err != nil {
return nil, nil, err
}
payloadBytes, err := decode(parts[1])
if err != nil {
return nil, nil, err
}
var header map[string]any
if err := json.Unmarshal(headerBytes, &header); err != nil {
return nil, nil, err
}
claims := &clusterjwt.Claims{}
if err := json.Unmarshal(payloadBytes, claims); err != nil {
return nil, nil, err
}
return header, claims, nil
}
// verifyPortalToken verifies a JWT using the portal's in-memory key manager.
func verifyPortalToken(conf *config.Config, token string, expected clusterjwt.ExpectedClaims) (*clusterjwt.Claims, error) {
if err := requirePortal(conf); err != nil {
return nil, err
}
manager := get.JWTManager()
if manager == nil {
return nil, cli.Exit(errors.New("jwt issuer not available"), 1)
}
jwks := manager.JWKS()
if jwks == nil || len(jwks.Keys) == 0 {
return nil, cli.Exit(errors.New("jwks key set is empty"), 1)
}
leeway := time.Duration(conf.JWTLeeway()) * time.Second
if leeway <= 0 {
leeway = 60 * time.Second
}
claims, err := clusterjwt.VerifyTokenWithKeys(token, expected, jwks.Keys, leeway)
if err != nil {
return nil, err
}
return claims, nil
}
// normalizeScopes trims and de-duplicates scope values, falling back to defaults when necessary.
func normalizeScopes(values []string, defaults ...string) ([]string, error) {
src := values
if len(src) == 0 {
src = defaults
}
out := make([]string, 0, len(src))
seen := make(map[string]struct{}, len(src))
for _, raw := range src {
for _, parsed := range clean.Scopes(raw) {
scope := clean.Scope(parsed)
if scope == "" {
continue
}
if _, exists := seen[scope]; exists {
continue
}
if _, ok := allowedJWTScope[scope]; !ok {
return nil, cli.Exit(fmt.Errorf("unsupported scope %q", scope), 2)
}
seen[scope] = struct{}{}
out = append(out, scope)
}
}
if len(out) == 0 {
return nil, cli.Exit(errors.New("at least one scope is required"), 2)
}
return out, nil
}
// printJSON pretty-prints the payload as JSON.
func printJSON(payload any) error {
data, err := json.MarshalIndent(payload, "", " ")
if err != nil {
return cli.Exit(err, 1)
}
fmt.Printf("%s\n", data)
return nil
}

View File

@@ -16,7 +16,7 @@ var ShowCommandsCommand = &cli.Command{
Name: "commands",
Usage: "Displays a structured catalog of CLI commands",
Flags: []cli.Flag{
&cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"},
JsonFlag(),
&cli.BoolFlag{Name: "all", Usage: "include hidden commands and flags"},
&cli.BoolFlag{Name: "short", Usage: "omit flags in Markdown output"},
&cli.IntFlag{Name: "base-heading", Value: 2, Usage: "base Markdown heading level"},

View File

@@ -6,6 +6,7 @@ import (
"io"
"net"
"net/http"
"strings"
"time"
"github.com/tidwall/gjson"
@@ -43,9 +44,9 @@ func statusAction(ctx *cli.Context) error {
}
}
url := fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
endpointUrl := buildStatusEndpoint(conf)
req, err := http.NewRequest(http.MethodGet, url, nil)
req, err := http.NewRequest(http.MethodGet, endpointUrl, nil)
if err != nil {
return err
@@ -53,12 +54,12 @@ func statusAction(ctx *cli.Context) error {
var status string
if resp, err := client.Do(req); err != nil {
if resp, reqErr := client.Do(req); reqErr != nil {
return fmt.Errorf("cannot connect to %s:%d", conf.HttpHost(), conf.HttpPort())
} else if resp.StatusCode != 200 {
return fmt.Errorf("server running at %s:%d, bad status %d\n", conf.HttpHost(), conf.HttpPort(), resp.StatusCode)
} else if body, err := io.ReadAll(resp.Body); err != nil {
return err
} else if body, readErr := io.ReadAll(resp.Body); readErr != nil {
return readErr
} else {
status = string(body)
}
@@ -73,3 +74,21 @@ func statusAction(ctx *cli.Context) error {
return nil
}
// buildStatusEndpoint returns the status endpoint URL, preferring the public
// SiteUrl (which carries the correct scheme) and falling back to the local
// HTTP host/port. When a Unix socket is configured, an http+unix style URL is
// used so the custom transport can dial the socket.
func buildStatusEndpoint(conf *config.Config) string {
if socket := conf.HttpSocket(); socket != nil {
return fmt.Sprintf("%s://%s/api/v1/status", socket.Scheme, strings.TrimPrefix(socket.Path, "/"))
}
siteUrl := strings.TrimRight(conf.SiteUrl(), "/")
if siteUrl != "" {
return siteUrl + "/api/v1/status"
}
return fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
}

View File

@@ -2,6 +2,8 @@ package config
import (
"errors"
"net"
urlpkg "net/url"
"os"
"path/filepath"
"strings"
@@ -222,7 +224,36 @@ func (c *Config) SetJWKSUrl(url string) {
if c == nil || c.options == nil {
return
}
c.options.JWKSUrl = strings.TrimSpace(url)
trimmed := strings.TrimSpace(url)
if trimmed == "" {
c.options.JWKSUrl = ""
return
}
parsed, err := urlpkg.Parse(trimmed)
if err != nil || parsed == nil || parsed.Scheme == "" || parsed.Host == "" {
log.Warnf("config: ignoring JWKS URL %q (%v)", trimmed, err)
return
}
scheme := strings.ToLower(parsed.Scheme)
host := parsed.Hostname()
switch scheme {
case "https":
// Always allowed.
case "http":
if !isLoopbackHost(host) {
log.Warnf("config: rejecting JWKS URL %q (http only allowed for localhost/loopback)", trimmed)
return
}
default:
log.Warnf("config: rejecting JWKS URL %q (unsupported scheme)", trimmed)
return
}
c.options.JWKSUrl = trimmed
}
// JWKSCacheTTL returns the JWKS cache lifetime in seconds (default 300, max 3600).
@@ -261,6 +292,23 @@ func (c *Config) AdvertiseUrl() string {
return c.SiteUrl()
}
// isLoopbackHost returns true when host represents localhost or a loopback IP.
func isLoopbackHost(host string) bool {
if host == "" {
return false
}
if strings.EqualFold(host, "localhost") {
return true
}
if ip := net.ParseIP(host); ip != nil {
return ip.IsLoopback()
}
return false
}
// SaveClusterUUID writes or updates the ClusterUUID key in options.yml without
// touching unrelated keys. Creates the file and directories if needed.
func (c *Config) SaveClusterUUID(uuid string) error {

View File

@@ -73,15 +73,77 @@ func TestConfig_Cluster(t *testing.T) {
c.Options().NodeRole = ""
})
t.Run("JWKSUrlSetter", func(t *testing.T) {
c := NewConfig(CliTestContext())
c.options.JWKSUrl = ""
assert.Equal(t, "", c.JWKSUrl())
const existing = "https://existing.example/.well-known/jwks.json"
tests := []struct {
name string
prev string
input string
expect string
}{
{
name: "TrimHTTPS",
prev: "",
input: " https://portal.example/.well-known/jwks.json ",
expect: "https://portal.example/.well-known/jwks.json",
},
{
name: "CaseInsensitiveScheme",
prev: "",
input: "HTTPS://portal.example/.well-known/jwks.json",
expect: "HTTPS://portal.example/.well-known/jwks.json",
},
{
name: "AllowHTTPOnLocalhost",
prev: "",
input: "http://localhost:2342/.well-known/jwks.json",
expect: "http://localhost:2342/.well-known/jwks.json",
},
{
name: "AllowHTTPOnLoopbackIPv4",
prev: "",
input: "http://127.0.0.1/.well-known/jwks.json",
expect: "http://127.0.0.1/.well-known/jwks.json",
},
{
name: "AllowHTTPOnLoopbackIPv6",
prev: "",
input: "http://[::1]/.well-known/jwks.json",
expect: "http://[::1]/.well-known/jwks.json",
},
{
name: "RejectHTTPNonLoopback",
prev: existing,
input: "http://portal.example/.well-known/jwks.json",
expect: existing,
},
{
name: "RejectUnsupportedScheme",
prev: existing,
input: "ftp://portal.example/.well-known/jwks.json",
expect: existing,
},
{
name: "RejectMalformedURL",
prev: existing,
input: "://not-a-url",
expect: existing,
},
{
name: "ClearValue",
prev: existing,
input: "",
expect: "",
},
}
c.SetJWKSUrl(" https://portal.example/.well-known/jwks.json ")
assert.Equal(t, "https://portal.example/.well-known/jwks.json", c.JWKSUrl())
c.SetJWKSUrl("")
assert.Equal(t, "", c.JWKSUrl())
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
c := NewConfig(CliTestContext())
c.options.JWKSUrl = tc.prev
c.SetJWKSUrl(tc.input)
assert.Equal(t, tc.expect, c.JWKSUrl())
})
}
})
t.Run("Paths", func(t *testing.T) {
c := NewConfig(CliTestContext())

View File

@@ -342,12 +342,18 @@ func (c *Config) SetDbOptions() {
case Postgres:
// Ignore for now.
case SQLite3:
// Not required as unicode is default.
// Not required as Unicode is default.
}
}
// RegisterDb sets the database options and connection provider.
// RegisterDb opens a database connection if needed,
// sets the database options and connection provider.
func (c *Config) RegisterDb() {
if err := c.connectDb(); err != nil {
log.Errorf("config: %s (register db)")
return
}
c.SetDbOptions()
entity.SetDbProvider(c)
}
@@ -456,6 +462,11 @@ func (c *Config) connectDb() error {
mutex.Db.Lock()
defer mutex.Db.Unlock()
// Database connection already exists.
if c.db != nil {
return nil
}
// Get database driver and data source name.
dbDriver := c.DatabaseDriver()
dbDsn := c.DatabaseDSN()

View File

@@ -227,13 +227,16 @@ func TestDownloadPlaylistEntry(t *testing.T) {
}
// Download the same file but with the direct link
url := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
dlUrl := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
stderrBuf = &bytes.Buffer{}
r, err = NewMetadata(context.Background(), url, Options{
r, err = NewMetadata(context.Background(), dlUrl, Options{
StderrFn: func(cmd *exec.Cmd) io.Writer {
return stderrBuf
},
})
if err != nil {
t.Fatal(err)
}

View File

@@ -1,30 +1,42 @@
package get
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/photoprism/photoprism/internal/auth/jwt"
"github.com/photoprism/photoprism/pkg/clean"
)
var (
onceJWTManager sync.Once
onceJWTIssuer sync.Once
onceJWTManager sync.Once
onceJWTIssuer sync.Once
onceJWTVerifier sync.Once
)
// initJWTManager lazily initializes the shared portal key manager for JWT issuance.
func initJWTManager() {
conf := Config()
if conf == nil || !conf.IsPortal() {
if conf == nil {
return
} else if !conf.IsPortal() {
return
}
manager, err := jwt.NewManager(conf)
if err != nil {
log.Warnf("jwt: manager init failed (%s)", clean.Error(err))
return
}
if _, err := manager.EnsureActiveKey(); err != nil {
if _, err = manager.EnsureActiveKey(); err != nil {
log.Warnf("jwt: ensure signing key failed (%s)", clean.Error(err))
}
services.JWTManager = manager
}
@@ -34,6 +46,7 @@ func JWTManager() *jwt.Manager {
return services.JWTManager
}
// initJWTIssuer lazily binds the shared issuer to the active portal key manager.
func initJWTIssuer() {
manager := JWTManager()
if manager == nil {
@@ -50,9 +63,98 @@ func JWTIssuer() *jwt.Issuer {
// JWTVerifier returns a verifier bound to the current config.
func JWTVerifier() *jwt.Verifier {
conf := Config()
if conf == nil {
return nil
}
return jwt.NewVerifier(conf)
onceJWTVerifier.Do(initJWTVerifier)
return services.JWTVerifier
}
// VerifyJWT verifies a token using the shared verifier instance.
func VerifyJWT(ctx context.Context, token string, expected jwt.ExpectedClaims) (*jwt.Claims, error) {
verifier := JWTVerifier()
if verifier == nil {
return nil, errors.New("jwt: verifier not available")
}
return verifier.VerifyToken(ctx, token, expected)
}
// initJWTVerifier lazily constructs the shared verifier for the current configuration.
func initJWTVerifier() {
if conf != nil {
services.JWTVerifier = jwt.NewVerifier(conf)
}
}
// resetJWTVerifier clears the cached verifier so it can be rebuilt for a new configuration.
func resetJWTVerifier() {
services.JWTVerifier = nil
onceJWTVerifier = sync.Once{}
}
// resetJWTIssuer clears the cached issuer so it can be recreated for a new configuration.
func resetJWTIssuer() {
services.JWTIssuer = nil
onceJWTIssuer = sync.Once{}
}
// resetJWTManager clears the cached key manager so subsequent calls reload keys for the active configuration.
func resetJWTManager() {
services.JWTManager = nil
onceJWTManager = sync.Once{}
}
// resetJWT clears all cached JWT helpers.
func resetJWT() {
resetJWTVerifier()
resetJWTIssuer()
resetJWTManager()
}
// IssuePortalJWT signs a token using the shared portal issuer with the provided claims.
func IssuePortalJWT(spec jwt.ClaimsSpec) (string, error) {
if issuer := JWTIssuer(); issuer == nil {
return "", errors.New("jwt: issuer not available")
} else {
return issuer.Issue(spec)
}
}
// IssuePortalJWTForNode issues a portal-signed JWT targeting the specified node UUID.
func IssuePortalJWTForNode(nodeUUID string, scopes []string, ttl time.Duration) (string, error) {
if conf == nil {
return "", errors.New("jwt: missing config")
} else if !conf.IsPortal() {
return "", errors.New("jwt: not supported on nodes")
}
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
if clusterUUID == "" {
return "", errors.New("jwt: cluster uuid not configured")
}
nodeUUID = strings.TrimSpace(nodeUUID)
if nodeUUID == "" {
return "", errors.New("jwt: node uuid required")
}
if len(scopes) == 0 {
return "", errors.New("jwt: at least one scope is required")
}
normalized := make([]string, 0, len(scopes))
for _, s := range scopes {
if trimmed := strings.TrimSpace(s); trimmed != "" {
normalized = append(normalized, trimmed)
}
}
if len(normalized) == 0 {
return "", errors.New("jwt: at least one scope is required")
}
spec := jwt.ClaimsSpec{
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
Subject: fmt.Sprintf("portal:%s", clusterUUID),
Audience: fmt.Sprintf("node:%s", nodeUUID),
Scope: normalized,
TTL: ttl,
}
return IssuePortalJWT(spec)
}

View File

@@ -0,0 +1,39 @@
package get
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/config"
)
func TestJWTVerifierReuse(t *testing.T) {
verifier1 := JWTVerifier()
require.NotNil(t, verifier1)
verifier2 := JWTVerifier()
require.NotNil(t, verifier2)
assert.Same(t, verifier1, verifier2)
}
func TestJWTVerifierResetOnConfigChange(t *testing.T) {
orig := Config()
verifier1 := JWTVerifier()
require.NotNil(t, verifier1)
tempConf := config.NewMinimalTestConfigWithDb("jwt-reset", t.TempDir())
SetConfig(tempConf)
t.Cleanup(func() {
SetConfig(orig)
tempConf.CloseDb()
orig.RegisterDb()
})
verifier2 := JWTVerifier()
require.NotNil(t, verifier2)
assert.NotSame(t, verifier1, verifier2)
}

View File

@@ -57,13 +57,17 @@ var services struct {
OIDC *oidc.Client
JWTManager *clusterjwt.Manager
JWTIssuer *clusterjwt.Issuer
JWTVerifier *clusterjwt.Verifier
}
func SetConfig(c *config.Config) {
if c == nil {
log.Panic("panic: argument is nil in get.SetConfig(c *config.Config)")
return
}
resetJWT()
conf = c
photoprism.SetConfig(c)
@@ -72,6 +76,7 @@ func SetConfig(c *config.Config) {
func Config() *config.Config {
if conf == nil {
log.Panic("panic: conf is nil in get.Config()")
return nil
}
return conf

View File

@@ -229,9 +229,9 @@ func persistRegistration(c *config.Config, r *cluster.RegisterResponse, wantRota
updates["NodeClientSecret"] = r.Secrets.ClientSecret
}
if url := strings.TrimSpace(r.JWKSUrl); url != "" {
updates["JWKSUrl"] = url
c.SetJWKSUrl(url)
if jwksUrl := strings.TrimSpace(r.JWKSUrl); jwksUrl != "" {
updates["JWKSUrl"] = jwksUrl
c.SetJWKSUrl(jwksUrl)
}
// Persist NodeUUID from portal response if provided and not set locally.

View File

@@ -221,13 +221,13 @@ func (c *Config) ReSync(token string) (err error) {
// interrupt reading of the Response.Body.
client := &http.Client{Timeout: 60 * time.Second}
url := ServiceURL
endpointUrl := ServiceURL
method := http.MethodPost
var req *http.Request
if c.Key != "" {
url = fmt.Sprintf(ServiceURL+"/%s", c.Key)
endpointUrl = fmt.Sprintf(ServiceURL+"/%s", c.Key)
method = http.MethodPut
log.Tracef("config: requesting updated keys for maps and places")
} else {
@@ -239,7 +239,7 @@ func (c *Config) ReSync(token string) (err error) {
if j, err = json.Marshal(NewRequest(c.Version, c.Serial, c.Env, c.PartnerID, token)); err != nil {
return err
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
} else if req, err = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); err != nil {
return err
}

View File

@@ -67,17 +67,17 @@ func (c *Config) SendFeedback(frm form.Feedback) (err error) {
// interrupt reading of the Response.Body.
client := &http.Client{Timeout: 60 * time.Second}
url := fmt.Sprintf(FeedbackURL, c.Key)
endpointUrl := fmt.Sprintf(FeedbackURL, c.Key)
method := http.MethodPost
var req *http.Request
log.Debugf("sending feedback to %s", ApiHost())
if j, err := json.Marshal(feedback); err != nil {
return err
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
return err
if j, reqErr := json.Marshal(feedback); reqErr != nil {
return reqErr
} else if req, reqErr = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); reqErr != nil {
return reqErr
}
// Set user agent.