mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 12:51:31 +08:00
CLI: Added JWT issuance and diagnostics sub commands #5230
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
@@ -50,7 +50,6 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
|
||||
return nil
|
||||
}
|
||||
|
||||
verifier := clusterjwt.NewVerifier(conf)
|
||||
requiredScopes := []string{"cluster"}
|
||||
if resource == acl.ResourceVision {
|
||||
requiredScopes = []string{"vision"}
|
||||
@@ -77,7 +76,7 @@ func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resourc
|
||||
|
||||
for _, issuer := range issuers {
|
||||
expected.Issuer = issuer
|
||||
claims, err = verifier.VerifyToken(ctx, authToken, expected)
|
||||
claims, err = get.VerifyJWT(ctx, authToken, expected)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
@@ -25,12 +25,12 @@ func TestEcho(t *testing.T) {
|
||||
t.Logf("Response Body: %s", r.Body.String())
|
||||
|
||||
body := r.Body.String()
|
||||
url := gjson.Get(body, "url").String()
|
||||
bodyUrl := gjson.Get(body, "url").String()
|
||||
method := gjson.Get(body, "method").String()
|
||||
request := gjson.Get(body, "headers.request")
|
||||
response := gjson.Get(body, "headers.response")
|
||||
|
||||
assert.Equal(t, "/api/v1/echo", url)
|
||||
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||
assert.Equal(t, "GET", method)
|
||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||
@@ -49,12 +49,12 @@ func TestEcho(t *testing.T) {
|
||||
r := AuthenticatedRequest(app, http.MethodPost, "/api/v1/echo", authToken)
|
||||
|
||||
body := r.Body.String()
|
||||
url := gjson.Get(body, "url").String()
|
||||
bodyUrl := gjson.Get(body, "url").String()
|
||||
method := gjson.Get(body, "method").String()
|
||||
request := gjson.Get(body, "headers.request")
|
||||
response := gjson.Get(body, "headers.response")
|
||||
|
||||
assert.Equal(t, "/api/v1/echo", url)
|
||||
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||
assert.Equal(t, "POST", method)
|
||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||
|
@@ -25,6 +25,20 @@ var (
|
||||
errKeyNotFound = errors.New("jwt: key not found")
|
||||
)
|
||||
|
||||
// VerifierStatus captures diagnostic information about a verifier's JWKS cache state.
|
||||
type VerifierStatus struct {
|
||||
CacheURL string `json:"cacheUrl,omitempty"`
|
||||
CacheETag string `json:"cacheEtag,omitempty"`
|
||||
KeyIDs []string `json:"keyIds,omitempty"`
|
||||
KeyCount int `json:"keyCount"`
|
||||
CacheFetchedAt time.Time `json:"cacheFetchedAt,omitempty"`
|
||||
CacheAgeSeconds int64 `json:"cacheAgeSeconds"`
|
||||
CacheTTLSeconds int `json:"cacheTtlSeconds"`
|
||||
CacheStale bool `json:"cacheStale"`
|
||||
CachePath string `json:"cachePath,omitempty"`
|
||||
JWKSURL string `json:"jwksUrl,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
|
||||
jwksFetchMaxRetries = 3
|
||||
@@ -99,15 +113,14 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
if strings.TrimSpace(expected.Audience) == "" {
|
||||
return nil, errors.New("jwt: expected audience required")
|
||||
}
|
||||
if len(expected.Scope) == 0 {
|
||||
return nil, errors.New("jwt: expected scope required")
|
||||
|
||||
jwksUrl := strings.TrimSpace(expected.JWKSURL)
|
||||
|
||||
if jwksUrl == "" && v.conf != nil {
|
||||
jwksUrl = strings.TrimSpace(v.conf.JWKSUrl())
|
||||
}
|
||||
|
||||
url := strings.TrimSpace(expected.JWKSURL)
|
||||
if url == "" && v.conf != nil {
|
||||
url = strings.TrimSpace(v.conf.JWKSUrl())
|
||||
}
|
||||
if url == "" {
|
||||
if jwksUrl == "" {
|
||||
return nil, errors.New("jwt: jwks url not configured")
|
||||
}
|
||||
|
||||
@@ -126,16 +139,111 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
claims := &Claims{}
|
||||
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
|
||||
if kid == "" {
|
||||
return nil, errors.New("jwt: missing kid header")
|
||||
}
|
||||
pk, err := v.publicKeyForKid(ctx, url, kid, false)
|
||||
|
||||
pk, err := v.publicKeyForKid(ctx, jwksUrl, kid, false)
|
||||
|
||||
if errors.Is(err, errKeyNotFound) {
|
||||
pk, err = v.publicKeyForKid(ctx, url, kid, true)
|
||||
pk, err = v.publicKeyForKid(ctx, jwksUrl, kid, true)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
if _, err := parser.ParseWithClaims(tokenString, claims, keyFunc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||
return nil, errors.New("jwt: missing temporal claims")
|
||||
}
|
||||
|
||||
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||
}
|
||||
|
||||
scopeSet := map[string]struct{}{}
|
||||
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyTokenWithKeys verifies a token using the provided JWKS keys without performing HTTP fetches.
|
||||
func VerifyTokenWithKeys(tokenString string, expected ExpectedClaims, keys []PublicJWK, leeway time.Duration) (*Claims, error) {
|
||||
if strings.TrimSpace(tokenString) == "" {
|
||||
return nil, errors.New("jwt: token is empty")
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.New("jwt: no jwks keys provided")
|
||||
}
|
||||
|
||||
if leeway <= 0 {
|
||||
leeway = 60 * time.Second
|
||||
}
|
||||
|
||||
keyMap := make(map[string]ed25519.PublicKey, len(keys))
|
||||
|
||||
for _, jwk := range keys {
|
||||
if jwk.Kid == "" {
|
||||
continue
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(raw) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("jwt: invalid public key length %d", len(raw))
|
||||
}
|
||||
pk := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||
copy(pk, raw)
|
||||
keyMap[jwk.Kid] = pk
|
||||
}
|
||||
|
||||
if len(keyMap) == 0 {
|
||||
return nil, errors.New("jwt: no valid jwks keys provided")
|
||||
}
|
||||
|
||||
options := []gojwt.ParserOption{
|
||||
gojwt.WithLeeway(leeway),
|
||||
gojwt.WithValidMethods([]string{gojwt.SigningMethodEdDSA.Alg()}),
|
||||
}
|
||||
|
||||
if iss := strings.TrimSpace(expected.Issuer); iss != "" {
|
||||
options = append(options, gojwt.WithIssuer(iss))
|
||||
}
|
||||
|
||||
if aud := strings.TrimSpace(expected.Audience); aud != "" {
|
||||
options = append(options, gojwt.WithAudience(aud))
|
||||
}
|
||||
|
||||
parser := gojwt.NewParser(options...)
|
||||
claims := &Claims{}
|
||||
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
if kid == "" {
|
||||
return nil, errors.New("jwt: missing kid header")
|
||||
}
|
||||
pk, ok := keyMap[kid]
|
||||
if !ok {
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
@@ -146,29 +254,70 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||
return nil, errors.New("jwt: missing temporal claims")
|
||||
}
|
||||
|
||||
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||
}
|
||||
|
||||
scopeSet := map[string]struct{}{}
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
if len(expected.Scope) > 0 {
|
||||
scopeSet := map[string]struct{}{}
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Status returns diagnostic information about the verifier's current JWKS cache.
|
||||
func (v *Verifier) Status(ttl time.Duration) VerifierStatus {
|
||||
status := VerifierStatus{}
|
||||
|
||||
if ttl > 0 {
|
||||
status.CacheTTLSeconds = int(ttl / time.Second)
|
||||
}
|
||||
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
||||
status.CacheURL = v.cache.URL
|
||||
status.CacheETag = v.cache.ETag
|
||||
status.JWKSURL = v.cache.URL
|
||||
status.KeyCount = len(v.cache.Keys)
|
||||
status.KeyIDs = make([]string, 0, len(v.cache.Keys))
|
||||
|
||||
for _, key := range v.cache.Keys {
|
||||
status.KeyIDs = append(status.KeyIDs, key.Kid)
|
||||
}
|
||||
|
||||
status.CachePath = v.cachePath
|
||||
|
||||
if v.cache.FetchedAt > 0 {
|
||||
fetched := time.Unix(v.cache.FetchedAt, 0).UTC()
|
||||
status.CacheFetchedAt = fetched
|
||||
age := time.Since(fetched)
|
||||
status.CacheAgeSeconds = int64(age.Seconds())
|
||||
if ttl > 0 && age > ttl {
|
||||
status.CacheStale = true
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
|
||||
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
|
||||
keys, err := v.keysForURL(ctx, url, force)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
@@ -184,12 +333,14 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
|
||||
copy(pk, raw)
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
|
||||
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
|
||||
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
|
||||
ttl := 300 * time.Second
|
||||
|
||||
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
|
||||
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
|
||||
}
|
||||
@@ -250,13 +401,16 @@ func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, f
|
||||
if force || cache.URL != url || len(cache.Keys) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
age := v.now().Unix() - cache.FetchedAt
|
||||
if age < 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Duration(age)*time.Second > ttl {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return append([]PublicJWK(nil), cache.Keys...), true
|
||||
}
|
||||
|
||||
@@ -270,17 +424,21 @@ type jwksFetchResult struct {
|
||||
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
|
||||
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
@@ -331,6 +489,7 @@ func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK
|
||||
Keys: append([]PublicJWK(nil), result.keys...),
|
||||
FetchedAt: result.fetchedAt,
|
||||
}
|
||||
|
||||
_ = v.saveCacheLocked()
|
||||
return append([]PublicJWK(nil), v.cache.Keys...), true
|
||||
}
|
||||
@@ -347,7 +506,7 @@ func (v *Verifier) loadCache() error {
|
||||
}
|
||||
|
||||
var entry cacheEntry
|
||||
if err := json.Unmarshal(b, &entry); err != nil {
|
||||
if err = json.Unmarshal(b, &entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -360,13 +519,17 @@ func (v *Verifier) saveCacheLocked() error {
|
||||
if v.cachePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := fs.MkdirAll(filepath.Dir(v.cachePath)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(v.cache)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
|
||||
}
|
||||
|
||||
@@ -377,11 +540,13 @@ func backoffDuration(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
base := jwksFetchBaseDelay << (attempt - 1)
|
||||
|
||||
if base > jwksFetchMaxDelay {
|
||||
base = jwksFetchMaxDelay
|
||||
}
|
||||
|
||||
jitterRange := base / 2
|
||||
|
||||
if jitterRange > 0 {
|
||||
base += time.Duration(randInt63n(int64(jitterRange) + 1))
|
||||
}
|
||||
|
@@ -103,6 +103,56 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyTokenWithKeys(t *testing.T) {
|
||||
portalCfg := newTestConfig(t)
|
||||
clusterUUID := rnd.UUIDv7()
|
||||
portalCfg.Options().ClusterUUID = clusterUUID
|
||||
|
||||
mgr, err := NewManager(portalCfg)
|
||||
require.NoError(t, err)
|
||||
mgr.now = func() time.Time { return time.Date(2025, 9, 24, 10, 30, 0, 0, time.UTC) }
|
||||
_, err = mgr.EnsureActiveKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
issuer := NewIssuer(mgr)
|
||||
issuer.now = func() time.Time { return time.Now().UTC() }
|
||||
|
||||
spec := ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: "portal:client-test",
|
||||
Audience: "node:1234",
|
||||
Scope: []string{"cluster"},
|
||||
}
|
||||
|
||||
token, err := issuer.Issue(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys := mgr.JWKS().Keys
|
||||
claims, err := VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
Scope: []string{"cluster"},
|
||||
}, keys, 60*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, spec.Subject, claims.Subject)
|
||||
|
||||
// Ensure scope filtering is honored when expected scope is empty.
|
||||
claims, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
}, keys, 60*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, spec.Subject, claims.Subject)
|
||||
|
||||
// Missing scope should fail when explicitly required.
|
||||
_, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
Scope: []string{"vision"},
|
||||
}, keys, 60*time.Second)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIssuerClampTTL(t *testing.T) {
|
||||
portalCfg := newTestConfig(t)
|
||||
mgr, err := NewManager(portalCfg)
|
||||
|
@@ -15,6 +15,7 @@ var AuthCommands = &cli.Command{
|
||||
AuthShowCommand,
|
||||
AuthRemoveCommand,
|
||||
AuthResetCommand,
|
||||
AuthJWTCommands,
|
||||
},
|
||||
}
|
||||
|
||||
|
16
internal/commands/auth_jwt.go
Normal file
16
internal/commands/auth_jwt.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
import "github.com/urfave/cli/v2"
|
||||
|
||||
// AuthJWTCommands groups JWT-related auth helpers under photoprism auth jwt.
|
||||
var AuthJWTCommands = &cli.Command{
|
||||
Name: "jwt",
|
||||
Usage: "JWT issuance and diagnostics",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Subcommands: []*cli.Command{
|
||||
AuthJWTIssueCommand,
|
||||
AuthJWTInspectCommand,
|
||||
AuthJWTKeysCommand,
|
||||
AuthJWTStatusCommand,
|
||||
},
|
||||
}
|
154
internal/commands/auth_jwt_inspect.go
Normal file
154
internal/commands/auth_jwt_inspect.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
// AuthJWTInspectCommand inspects and verifies portal-issued JWTs.
|
||||
var AuthJWTInspectCommand = &cli.Command{
|
||||
Name: "inspect",
|
||||
Usage: "Decodes and verifies a portal JWT",
|
||||
ArgsUsage: "<token>",
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{Name: "file", Aliases: []string{"f"}, Usage: "read token from file"},
|
||||
&cli.StringFlag{Name: "expect-audience", Usage: "expected audience (e.g., node:<uuid>)"},
|
||||
&cli.StringSliceFlag{Name: "require-scope", Usage: "require specific scope(s)"},
|
||||
&cli.BoolFlag{Name: "skip-verify", Usage: "decode without signature verification"},
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTInspectAction,
|
||||
}
|
||||
|
||||
// authJWTInspectAction decodes and optionally verifies a portal-issued JWT.
|
||||
func authJWTInspectAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := readTokenInput(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, claims, err := decodeJWTClaims(token)
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
var verified bool
|
||||
tokenScopes := clean.Scopes(claims.Scope)
|
||||
|
||||
if !ctx.Bool("skip-verify") {
|
||||
expected := clusterjwt.ExpectedClaims{}
|
||||
if clusterUUID := strings.TrimSpace(conf.ClusterUUID()); clusterUUID != "" {
|
||||
expected.Issuer = fmt.Sprintf("portal:%s", clusterUUID)
|
||||
} else if portal := strings.TrimSpace(conf.PortalUrl()); portal != "" {
|
||||
expected.Issuer = strings.TrimRight(portal, "/")
|
||||
}
|
||||
|
||||
if expectAud := strings.TrimSpace(ctx.String("expect-audience")); expectAud != "" {
|
||||
expected.Audience = expectAud
|
||||
} else if len(claims.Audience) > 0 {
|
||||
expected.Audience = claims.Audience[0]
|
||||
}
|
||||
|
||||
if required := ctx.StringSlice("require-scope"); len(required) > 0 {
|
||||
scopes, scopeErr := normalizeScopes(required)
|
||||
if scopeErr != nil {
|
||||
return scopeErr
|
||||
}
|
||||
expected.Scope = scopes
|
||||
} else {
|
||||
expected.Scope = tokenScopes
|
||||
}
|
||||
|
||||
if _, err := verifyPortalToken(conf, token, expected); err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
verified = true
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"token": token,
|
||||
"verified": verified,
|
||||
"header": header,
|
||||
"claims": claims,
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("JWT header:")
|
||||
for k, v := range header {
|
||||
fmt.Printf(" %s: %v\n", k, v)
|
||||
}
|
||||
|
||||
fmt.Println("\nJWT claims:")
|
||||
fmt.Printf(" issuer: %s\n", claims.Issuer)
|
||||
fmt.Printf(" subject: %s\n", claims.Subject)
|
||||
fmt.Printf(" audience: %s\n", strings.Join(claims.Audience, " "))
|
||||
fmt.Printf(" scope: %s\n", strings.Join(tokenScopes, " "))
|
||||
if claims.IssuedAt != nil {
|
||||
fmt.Printf(" issuedAt: %s\n", claims.IssuedAt.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
fmt.Printf(" expiresAt: %s\n", claims.ExpiresAt.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.NotBefore != nil {
|
||||
fmt.Printf(" notBefore: %s\n", claims.NotBefore.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.ID != "" {
|
||||
fmt.Printf(" jti: %s\n", claims.ID)
|
||||
}
|
||||
|
||||
if verified {
|
||||
fmt.Println("\nSignature: verified")
|
||||
} else {
|
||||
fmt.Println("\nSignature: not verified (skipped)")
|
||||
}
|
||||
|
||||
fmt.Printf("\nToken:\n%s\n\n", token)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// readTokenInput loads the token from CLI args, file, or STDIN.
|
||||
func readTokenInput(ctx *cli.Context) (string, error) {
|
||||
if file := strings.TrimSpace(ctx.String("file")); file != "" {
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return "", cli.Exit(err, 1)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
if ctx.Args().Len() == 0 {
|
||||
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(ctx.Args().First())
|
||||
if token == "-" {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return "", cli.Exit(err, 1)
|
||||
}
|
||||
token = strings.TrimSpace(string(data))
|
||||
}
|
||||
if token == "" {
|
||||
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||
}
|
||||
return token, nil
|
||||
}
|
117
internal/commands/auth_jwt_issue.go
Normal file
117
internal/commands/auth_jwt_issue.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTIssueCommand issues portal-signed JWTs for cluster nodes.
|
||||
var AuthJWTIssueCommand = &cli.Command{
|
||||
Name: "issue",
|
||||
Usage: "Issues a portal-signed JWT for a node",
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{Name: "node", Aliases: []string{"n"}, Usage: "target node uuid, client id, or DNS label", Required: true},
|
||||
&cli.StringSliceFlag{Name: "scope", Aliases: []string{"s"}, Usage: "token scope", Value: cli.NewStringSlice("cluster")},
|
||||
&cli.DurationFlag{Name: "ttl", Usage: "token lifetime", Value: clusterjwt.TokenTTL},
|
||||
&cli.StringFlag{Name: "subject", Usage: "token subject (default portal:<clusterUUID>)"},
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTIssueAction,
|
||||
}
|
||||
|
||||
// authJWTIssueAction handles CLI issuance of portal-signed JWTs for nodes.
|
||||
func authJWTIssueAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node, err := resolveNode(conf, ctx.String("node"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scopes, err := normalizeScopes(ctx.StringSlice("scope"), "cluster")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ttl := ctx.Duration("ttl")
|
||||
if ttl <= 0 {
|
||||
ttl = clusterjwt.TokenTTL
|
||||
}
|
||||
|
||||
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||
if clusterUUID == "" {
|
||||
return cli.Exit(fmt.Errorf("cluster uuid not configured"), 1)
|
||||
}
|
||||
|
||||
subject := strings.TrimSpace(ctx.String("subject"))
|
||||
if subject == "" {
|
||||
subject = fmt.Sprintf("portal:%s", clusterUUID)
|
||||
}
|
||||
|
||||
var token string
|
||||
if subject == fmt.Sprintf("portal:%s", clusterUUID) {
|
||||
token, err = get.IssuePortalJWTForNode(node.UUID, scopes, ttl)
|
||||
} else {
|
||||
spec := clusterjwt.ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: subject,
|
||||
Audience: fmt.Sprintf("node:%s", node.UUID),
|
||||
Scope: scopes,
|
||||
TTL: ttl,
|
||||
}
|
||||
token, err = get.IssuePortalJWT(spec)
|
||||
}
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
header, claims, err := decodeJWTClaims(token)
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"token": token,
|
||||
"header": header,
|
||||
"claims": claims,
|
||||
"node": map[string]string{
|
||||
"uuid": node.UUID,
|
||||
"clientId": node.ClientID,
|
||||
"name": node.Name,
|
||||
"role": string(node.Role),
|
||||
},
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
expires := "unknown"
|
||||
if claims.ExpiresAt != nil {
|
||||
expires = claims.ExpiresAt.Time.UTC().Format(time.RFC3339)
|
||||
}
|
||||
audience := strings.Join(claims.Audience, " ")
|
||||
if audience == "" {
|
||||
audience = "(none)"
|
||||
}
|
||||
|
||||
fmt.Printf("\nIssued JWT for node %s (%s)\n", node.Name, node.UUID)
|
||||
fmt.Printf("Scopes: %s\n", strings.Join(scopes, " "))
|
||||
fmt.Printf("Expires: %s\n", expires)
|
||||
fmt.Printf("Audience: %s\n", audience)
|
||||
fmt.Printf("Subject: %s\n", claims.Subject)
|
||||
fmt.Printf("Key ID: %v\n", header["kid"])
|
||||
fmt.Printf("\n%s\n", token)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
107
internal/commands/auth_jwt_keys.go
Normal file
107
internal/commands/auth_jwt_keys.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTKeysCommand groups JWT key management helpers.
|
||||
var AuthJWTKeysCommand = &cli.Command{
|
||||
Name: "keys",
|
||||
Usage: "JWT signing key helpers",
|
||||
Subcommands: []*cli.Command{
|
||||
AuthJWTKeysListCommand,
|
||||
},
|
||||
}
|
||||
|
||||
// AuthJWTKeysListCommand lists JWT signing keys.
|
||||
var AuthJWTKeysListCommand = &cli.Command{
|
||||
Name: "ls",
|
||||
Usage: "Lists JWT signing keys",
|
||||
Aliases: []string{"list"},
|
||||
ArgsUsage: "",
|
||||
Flags: []cli.Flag{
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTKeysListAction,
|
||||
}
|
||||
|
||||
// authJWTKeysListAction lists portal signing keys with metadata.
|
||||
func authJWTKeysListAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager := get.JWTManager()
|
||||
if manager == nil {
|
||||
return cli.Exit(errors.New("jwt manager not available"), 1)
|
||||
}
|
||||
|
||||
keys := manager.AllKeys()
|
||||
active, _ := manager.ActiveKey()
|
||||
activeKid := ""
|
||||
if active != nil {
|
||||
activeKid = active.Kid
|
||||
}
|
||||
|
||||
type keyInfo struct {
|
||||
Kid string `json:"kid"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
rows := make([]keyInfo, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
info := keyInfo{Kid: k.Kid, Active: k.Kid == activeKid}
|
||||
if k.CreatedAt > 0 {
|
||||
info.CreatedAt = time.Unix(k.CreatedAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
if k.NotAfter > 0 {
|
||||
info.NotAfter = time.Unix(k.NotAfter, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
rows = append(rows, info)
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"keys": rows,
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
fmt.Println()
|
||||
fmt.Println("No signing keys found.")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("JWT signing keys:")
|
||||
for _, row := range rows {
|
||||
status := ""
|
||||
if row.Active {
|
||||
status = " (active)"
|
||||
}
|
||||
parts := []string{fmt.Sprintf("KID: %s%s", row.Kid, status)}
|
||||
if row.CreatedAt != "" {
|
||||
parts = append(parts, fmt.Sprintf("created %s", row.CreatedAt))
|
||||
}
|
||||
if row.NotAfter != "" {
|
||||
parts = append(parts, fmt.Sprintf("expires %s", row.NotAfter))
|
||||
}
|
||||
fmt.Printf("- %s\n", strings.Join(parts, ", "))
|
||||
}
|
||||
fmt.Println()
|
||||
return nil
|
||||
})
|
||||
}
|
67
internal/commands/auth_jwt_status.go
Normal file
67
internal/commands/auth_jwt_status.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTStatusCommand reports verifier cache diagnostics.
|
||||
var AuthJWTStatusCommand = &cli.Command{
|
||||
Name: "status",
|
||||
Usage: "Shows JWT verifier cache status",
|
||||
Flags: []cli.Flag{
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTStatusAction,
|
||||
}
|
||||
|
||||
// authJWTStatusAction prints JWKS cache diagnostics for the current node.
|
||||
func authJWTStatusAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
verifier := get.JWTVerifier()
|
||||
if verifier == nil {
|
||||
return cli.Exit(errors.New("jwt verifier not available"), 1)
|
||||
}
|
||||
|
||||
ttl := time.Duration(conf.JWKSCacheTTL()) * time.Second
|
||||
status := verifier.Status(ttl)
|
||||
status.JWKSURL = strings.TrimSpace(conf.JWKSUrl())
|
||||
|
||||
if ctx.Bool("json") {
|
||||
return printJSON(status)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Printf("JWKS URL: %s\n", status.JWKSURL)
|
||||
fmt.Printf("Cache Path: %s\n", status.CachePath)
|
||||
fmt.Printf("Cache URL: %s\n", status.CacheURL)
|
||||
fmt.Printf("Cache ETag: %s\n", status.CacheETag)
|
||||
fmt.Printf("Cached Keys: %d\n", status.KeyCount)
|
||||
if len(status.KeyIDs) > 0 {
|
||||
fmt.Printf("Key IDs: %s\n", strings.Join(status.KeyIDs, ", "))
|
||||
}
|
||||
if !status.CacheFetchedAt.IsZero() {
|
||||
fmt.Printf("Last Fetch: %s\n", status.CacheFetchedAt.Format(time.RFC3339))
|
||||
} else {
|
||||
fmt.Println("Last Fetch: never")
|
||||
}
|
||||
fmt.Printf("Cache Age: %ds\n", status.CacheAgeSeconds)
|
||||
if status.CacheTTLSeconds > 0 {
|
||||
fmt.Printf("Cache TTL: %ds\n", status.CacheTTLSeconds)
|
||||
}
|
||||
if status.CacheStale {
|
||||
fmt.Println("Cache Status: STALE")
|
||||
} else {
|
||||
fmt.Println("Cache Status: fresh")
|
||||
}
|
||||
fmt.Println()
|
||||
return nil
|
||||
})
|
||||
}
|
94
internal/commands/auth_jwt_test.go
Normal file
94
internal/commands/auth_jwt_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
func TestAuthJWTCommands(t *testing.T) {
|
||||
conf := get.Config()
|
||||
|
||||
origEdition := conf.Options().Edition
|
||||
origRole := conf.Options().NodeRole
|
||||
origUUID := conf.Options().ClusterUUID
|
||||
origPortal := conf.Options().PortalUrl
|
||||
origJWKS := conf.JWKSUrl()
|
||||
|
||||
conf.Options().Edition = config.Portal
|
||||
conf.Options().NodeRole = string(cluster.RolePortal)
|
||||
conf.Options().ClusterUUID = "11111111-1111-4111-8111-111111111111"
|
||||
conf.Options().PortalUrl = "https://portal.test"
|
||||
conf.SetJWKSUrl("https://portal.test/.well-known/jwks.json")
|
||||
|
||||
get.SetConfig(conf)
|
||||
conf.RegisterDb()
|
||||
|
||||
require.True(t, conf.IsPortal())
|
||||
|
||||
manager := get.JWTManager()
|
||||
require.NotNil(t, manager)
|
||||
_, err := manager.EnsureActiveKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeUUID := rnd.UUID()
|
||||
node := ®.Node{}
|
||||
node.UUID = nodeUUID
|
||||
node.Name = "pp-node-01"
|
||||
node.Role = string(cluster.RoleInstance)
|
||||
require.NoError(t, registry.Put(node))
|
||||
t.Cleanup(func() {
|
||||
conf.Options().Edition = origEdition
|
||||
conf.Options().NodeRole = origRole
|
||||
conf.Options().ClusterUUID = origUUID
|
||||
conf.Options().PortalUrl = origPortal
|
||||
conf.SetJWKSUrl(origJWKS)
|
||||
get.SetConfig(conf)
|
||||
conf.RegisterDb()
|
||||
})
|
||||
|
||||
output, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, output, "Issued JWT")
|
||||
|
||||
jsonOut, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--json"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonOut), &payload))
|
||||
require.NotEmpty(t, payload.Token)
|
||||
|
||||
inspectOut, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", payload.Token})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, inspectOut, "\"verified\": true")
|
||||
|
||||
inspectStrict, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", "--expect-audience", "node:" + nodeUUID, "--require-scope", "cluster", payload.Token})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, inspectStrict, "\"verified\": true")
|
||||
|
||||
keysOut, err := RunWithTestContext(AuthJWTKeysListCommand, []string{"ls", "--json"})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, keysOut, "\"keys\"")
|
||||
|
||||
statusOut, err := RunWithTestContext(AuthJWTStatusCommand, []string{"status"})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, statusOut, "JWKS URL")
|
||||
assert.Contains(t, statusOut, "Cached Keys")
|
||||
|
||||
// invalid scope should fail
|
||||
_, err = RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--scope", "unknown"})
|
||||
require.Error(t, err)
|
||||
}
|
@@ -19,8 +19,9 @@ type healthResponse struct {
|
||||
// ClusterHealthCommand prints a minimal health response (Portal-only).
|
||||
var ClusterHealthCommand = &cli.Command{
|
||||
Name: "health",
|
||||
Usage: "Shows cluster health (Portal-only)",
|
||||
Usage: "Shows cluster health status",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterHealthAction,
|
||||
}
|
||||
|
||||
|
@@ -14,8 +14,9 @@ import (
|
||||
|
||||
// ClusterNodesCommands groups node subcommands.
|
||||
var ClusterNodesCommands = &cli.Command{
|
||||
Name: "nodes",
|
||||
Usage: "Node registry subcommands",
|
||||
Name: "nodes",
|
||||
Usage: "Node registry subcommands",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Subcommands: []*cli.Command{
|
||||
ClusterNodesListCommand,
|
||||
ClusterNodesShowCommand,
|
||||
@@ -28,9 +29,10 @@ var ClusterNodesCommands = &cli.Command{
|
||||
// ClusterNodesListCommand lists registered nodes.
|
||||
var ClusterNodesListCommand = &cli.Command{
|
||||
Name: "ls",
|
||||
Usage: "Lists registered cluster nodes (Portal-only)",
|
||||
Usage: "Lists registered cluster nodes",
|
||||
Flags: append(report.CliFlags, CountFlag, OffsetFlag),
|
||||
ArgsUsage: "",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesListAction,
|
||||
}
|
||||
|
||||
|
@@ -22,9 +22,10 @@ var (
|
||||
// ClusterNodesModCommand updates node fields.
|
||||
var ClusterNodesModCommand = &cli.Command{
|
||||
Name: "mod",
|
||||
Usage: "Updates node properties (Portal-only)",
|
||||
Usage: "Updates node properties",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: []cli.Flag{nodesModRoleFlag, nodesModInternal, nodesModLabel, &cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"}},
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesModAction,
|
||||
}
|
||||
|
||||
|
@@ -14,12 +14,13 @@ import (
|
||||
// ClusterNodesRemoveCommand deletes a node from the registry.
|
||||
var ClusterNodesRemoveCommand = &cli.Command{
|
||||
Name: "rm",
|
||||
Usage: "Deletes a node from the registry (Portal-only)",
|
||||
Usage: "Deletes a node from the registry",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"},
|
||||
&cli.BoolFlag{Name: "all-ids", Usage: "delete all records that share the same UUID (admin cleanup)"},
|
||||
},
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesRemoveAction,
|
||||
}
|
||||
|
||||
|
@@ -2,6 +2,7 @@ package commands
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -106,11 +107,13 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
|
||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
|
||||
var resp cluster.RegisterResponse
|
||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
||||
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||
// Map common HTTP errors similarly to register command
|
||||
if he, ok := err.(*httpError); ok {
|
||||
var he *httpError
|
||||
if errors.As(err, &he) {
|
||||
switch he.Status {
|
||||
case 401, 403:
|
||||
return cli.Exit(fmt.Errorf("%s", he.Error()), 4)
|
||||
@@ -151,6 +154,7 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
||||
fmt.Printf("DSN: %s\n", resp.Database.DSN)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@@ -15,9 +15,10 @@ import (
|
||||
// ClusterNodesShowCommand shows node details.
|
||||
var ClusterNodesShowCommand = &cli.Command{
|
||||
Name: "show",
|
||||
Usage: "Shows node details (Portal-only)",
|
||||
Usage: "Shows node details",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesShowAction,
|
||||
}
|
||||
|
||||
|
@@ -24,7 +24,7 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||
)
|
||||
|
||||
// flags for register
|
||||
// Supported cluster node register flags.
|
||||
var (
|
||||
regNameFlag = &cli.StringFlag{Name: "name", Usage: "node `NAME` (lowercase letters, digits, hyphens)"}
|
||||
regRoleFlag = &cli.StringFlag{Name: "role", Usage: "node `ROLE` (instance, service)", Value: "instance"}
|
||||
@@ -42,7 +42,7 @@ var (
|
||||
// ClusterRegisterCommand registers a node with the Portal via HTTP.
|
||||
var ClusterRegisterCommand = &cli.Command{
|
||||
Name: "register",
|
||||
Usage: "Registers/rotates a node via Portal (HTTP)",
|
||||
Usage: "Registers a node or updates its credentials within a cluster",
|
||||
Flags: append(append([]cli.Flag{regNameFlag, regRoleFlag, regIntUrlFlag, regLabelFlag, regRotateDatabase, regRotateSec, regPortalURL, regPortalTok, regWriteConf, regForceFlag, regDryRun}, report.CliFlags...)),
|
||||
Action: clusterRegisterAction,
|
||||
}
|
||||
@@ -52,15 +52,18 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
// Resolve inputs
|
||||
name := clean.DNSLabel(ctx.String("name"))
|
||||
derivedName := false
|
||||
|
||||
if name == "" { // default from config if set
|
||||
name = clean.DNSLabel(conf.NodeName())
|
||||
if name != "" {
|
||||
derivedName = true
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return cli.Exit(fmt.Errorf("node name is required (use --name or set node-name)"), 2)
|
||||
}
|
||||
|
||||
nodeRole := clean.TypeLowerDash(ctx.String("role"))
|
||||
switch nodeRole {
|
||||
case "instance", "service":
|
||||
@@ -76,7 +79,6 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
derivedPortal = true
|
||||
}
|
||||
}
|
||||
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
||||
|
||||
// Derive advertise/site URLs when omitted.
|
||||
advertise := ctx.String("advertise-url")
|
||||
@@ -93,17 +95,20 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
RotateDatabase: ctx.Bool("rotate"),
|
||||
RotateSecret: ctx.Bool("rotate-secret"),
|
||||
}
|
||||
|
||||
// If we already have client credentials (e.g., re-register), include them so the
|
||||
// portal can verify and authorize UUID/name moves or metadata updates.
|
||||
if id, secret := strings.TrimSpace(conf.NodeClientID()), strings.TrimSpace(conf.NodeClientSecret()); id != "" && secret != "" {
|
||||
payload.ClientID = id
|
||||
payload.ClientSecret = secret
|
||||
}
|
||||
|
||||
if site != "" && site != advertise {
|
||||
payload.SiteUrl = site
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
|
||||
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
||||
if ctx.Bool("dry-run") {
|
||||
if ctx.Bool("json") {
|
||||
out := map[string]any{"portalUrl": portalURL, "payload": payload}
|
||||
@@ -140,18 +145,22 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
if portalURL == "" {
|
||||
return cli.Exit(fmt.Errorf("portal URL is required (use --portal-url or set portal-url)"), 2)
|
||||
}
|
||||
|
||||
token := ctx.String("join-token")
|
||||
|
||||
if token == "" {
|
||||
token = conf.JoinToken()
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return cli.Exit(fmt.Errorf("portal token is required (use --join-token or set join-token)"), 2)
|
||||
}
|
||||
|
||||
// POST with bounded backoff on 429
|
||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
|
||||
var resp cluster.RegisterResponse
|
||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
||||
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||
var httpErr *httpError
|
||||
if errors.As(err, &httpErr) && httpErr.Status == http.StatusTooManyRequests {
|
||||
return cli.Exit(fmt.Errorf("portal rate-limited registration attempts"), 6)
|
||||
@@ -179,13 +188,17 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
} else {
|
||||
// Human-readable: node row and credentials if present (UUID first as primary identifier)
|
||||
cols := []string{"UUID", "ClientID", "Name", "Role", "DB Driver", "DB Name", "DB User", "Host", "Port"}
|
||||
|
||||
var dbName, dbUser string
|
||||
|
||||
if resp.Database.Name != "" {
|
||||
dbName = resp.Database.Name
|
||||
}
|
||||
|
||||
if resp.Database.User != "" {
|
||||
dbUser = resp.Database.User
|
||||
}
|
||||
|
||||
rows := [][]string{{resp.Node.UUID, resp.Node.ClientID, resp.Node.Name, resp.Node.Role, resp.Database.Driver, dbName, dbUser, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port)}}
|
||||
out, _ := report.RenderFormat(rows, cols, report.CliFormat(ctx))
|
||||
fmt.Printf("\n%s\n", out)
|
||||
|
@@ -13,11 +13,12 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||
)
|
||||
|
||||
// ClusterSummaryCommand prints a minimal cluster summary (Portal-only).
|
||||
// ClusterSummaryCommand prints a minimal cluster summary.
|
||||
var ClusterSummaryCommand = &cli.Command{
|
||||
Name: "summary",
|
||||
Usage: "Shows cluster summary (Portal-only)",
|
||||
Usage: "Shows cluster summary",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterSummaryAction,
|
||||
}
|
||||
|
||||
|
@@ -97,7 +97,6 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
||||
|
||||
// Ensure DB connection is open for each command run (some commands call Shutdown).
|
||||
if c := get.Config(); c != nil {
|
||||
_ = c.Init() // safe to call; re-opens DB if needed
|
||||
c.RegisterDb() // (re)register provider
|
||||
}
|
||||
|
||||
@@ -110,5 +109,11 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
||||
err = cmd.Run(ctx, args...)
|
||||
})
|
||||
|
||||
// Re-open the database after the command completed so follow-up checks
|
||||
// (potentially issued by the test itself) have an active connection.
|
||||
if c := get.Config(); c != nil {
|
||||
c.RegisterDb()
|
||||
}
|
||||
|
||||
return output, err
|
||||
}
|
||||
|
@@ -81,9 +81,10 @@ func TestDownloadImpl_FileMethod_AutoSkipsRemux(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
|
||||
// Ensure DB is initialized and registered (bypassing CLI InitConfig)
|
||||
_ = conf.Init()
|
||||
conf.RegisterDb()
|
||||
|
||||
// Override yt-dlp after config init (config may set dl.YtDlpBin)
|
||||
dl.YtDlpBin = fake
|
||||
t.Logf("using yt-dlp binary: %s", dl.YtDlpBin)
|
||||
@@ -125,7 +126,6 @@ func TestDownloadImpl_FileMethod_Skip_NoRemux(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
_ = conf.Init()
|
||||
conf.RegisterDb()
|
||||
dl.YtDlpBin = fake
|
||||
|
||||
@@ -196,8 +196,9 @@ func TestDownloadImpl_FileMethod_Always_RemuxFails(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
_ = conf.Init()
|
||||
|
||||
conf.RegisterDb()
|
||||
|
||||
dl.YtDlpBin = fake
|
||||
|
||||
err := runDownload(conf, DownloadOpts{
|
||||
|
8
internal/commands/flags.go
Normal file
8
internal/commands/flags.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package commands
|
||||
|
||||
import "github.com/urfave/cli/v2"
|
||||
|
||||
// JsonFlag returns the shared CLI flag definition for JSON output across commands.
|
||||
func JsonFlag() *cli.BoolFlag {
|
||||
return &cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"}
|
||||
}
|
173
internal/commands/jwt_helpers.go
Normal file
173
internal/commands/jwt_helpers.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
var allowedJWTScope = func() map[string]struct{} {
|
||||
out := make(map[string]struct{}, len(acl.ResourceNames))
|
||||
for _, res := range acl.ResourceNames {
|
||||
out[res.String()] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
// requirePortal returns a CLI error when the active configuration is not a portal node.
|
||||
func requirePortal(conf *config.Config) error {
|
||||
if conf == nil || !conf.IsPortal() {
|
||||
return cli.Exit(errors.New("command requires a Portal node"), 2)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveNode finds a node by UUID, client ID, or DNS label using the portal registry.
|
||||
func resolveNode(conf *config.Config, identifier string) (*reg.Node, error) {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := strings.TrimSpace(identifier)
|
||||
if key == "" {
|
||||
return nil, cli.Exit(errors.New("node identifier required"), 2)
|
||||
}
|
||||
|
||||
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||
if err != nil {
|
||||
return nil, cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
if node, err := registry.FindByNodeUUID(key); err == nil && node != nil {
|
||||
return node, nil
|
||||
}
|
||||
if node, err := registry.FindByClientID(key); err == nil && node != nil {
|
||||
return node, nil
|
||||
}
|
||||
|
||||
name := clean.DNSLabel(key)
|
||||
if name == "" {
|
||||
return nil, cli.Exit(errors.New("invalid node identifier"), 2)
|
||||
}
|
||||
|
||||
node, err := registry.FindByName(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, reg.ErrNotFound) {
|
||||
return nil, cli.Exit(fmt.Errorf("node %q not found", identifier), 3)
|
||||
}
|
||||
return nil, cli.Exit(err, 1)
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// decodeJWTClaims decodes the compact JWT and returns header and claims without verifying the signature.
|
||||
func decodeJWTClaims(token string) (map[string]any, *clusterjwt.Claims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, nil, errors.New("jwt: token must contain three segments")
|
||||
}
|
||||
|
||||
decode := func(segment string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(segment)
|
||||
}
|
||||
|
||||
headerBytes, err := decode(parts[0])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
payloadBytes, err := decode(parts[1])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var header map[string]any
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
claims := &clusterjwt.Claims{}
|
||||
if err := json.Unmarshal(payloadBytes, claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return header, claims, nil
|
||||
}
|
||||
|
||||
// verifyPortalToken verifies a JWT using the portal's in-memory key manager.
|
||||
func verifyPortalToken(conf *config.Config, token string, expected clusterjwt.ExpectedClaims) (*clusterjwt.Claims, error) {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager := get.JWTManager()
|
||||
if manager == nil {
|
||||
return nil, cli.Exit(errors.New("jwt issuer not available"), 1)
|
||||
}
|
||||
|
||||
jwks := manager.JWKS()
|
||||
if jwks == nil || len(jwks.Keys) == 0 {
|
||||
return nil, cli.Exit(errors.New("jwks key set is empty"), 1)
|
||||
}
|
||||
|
||||
leeway := time.Duration(conf.JWTLeeway()) * time.Second
|
||||
if leeway <= 0 {
|
||||
leeway = 60 * time.Second
|
||||
}
|
||||
|
||||
claims, err := clusterjwt.VerifyTokenWithKeys(token, expected, jwks.Keys, leeway)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// normalizeScopes trims and de-duplicates scope values, falling back to defaults when necessary.
|
||||
func normalizeScopes(values []string, defaults ...string) ([]string, error) {
|
||||
src := values
|
||||
if len(src) == 0 {
|
||||
src = defaults
|
||||
}
|
||||
out := make([]string, 0, len(src))
|
||||
seen := make(map[string]struct{}, len(src))
|
||||
for _, raw := range src {
|
||||
for _, parsed := range clean.Scopes(raw) {
|
||||
scope := clean.Scope(parsed)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[scope]; exists {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowedJWTScope[scope]; !ok {
|
||||
return nil, cli.Exit(fmt.Errorf("unsupported scope %q", scope), 2)
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, cli.Exit(errors.New("at least one scope is required"), 2)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// printJSON pretty-prints the payload as JSON.
|
||||
func printJSON(payload any) error {
|
||||
data, err := json.MarshalIndent(payload, "", " ")
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
fmt.Printf("%s\n", data)
|
||||
return nil
|
||||
}
|
@@ -16,7 +16,7 @@ var ShowCommandsCommand = &cli.Command{
|
||||
Name: "commands",
|
||||
Usage: "Displays a structured catalog of CLI commands",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"},
|
||||
JsonFlag(),
|
||||
&cli.BoolFlag{Name: "all", Usage: "include hidden commands and flags"},
|
||||
&cli.BoolFlag{Name: "short", Usage: "omit flags in Markdown output"},
|
||||
&cli.IntFlag{Name: "base-heading", Value: 2, Usage: "base Markdown heading level"},
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -43,9 +44,9 @@ func statusAction(ctx *cli.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
||||
endpointUrl := buildStatusEndpoint(conf)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, endpointUrl, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -53,12 +54,12 @@ func statusAction(ctx *cli.Context) error {
|
||||
|
||||
var status string
|
||||
|
||||
if resp, err := client.Do(req); err != nil {
|
||||
if resp, reqErr := client.Do(req); reqErr != nil {
|
||||
return fmt.Errorf("cannot connect to %s:%d", conf.HttpHost(), conf.HttpPort())
|
||||
} else if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("server running at %s:%d, bad status %d\n", conf.HttpHost(), conf.HttpPort(), resp.StatusCode)
|
||||
} else if body, err := io.ReadAll(resp.Body); err != nil {
|
||||
return err
|
||||
} else if body, readErr := io.ReadAll(resp.Body); readErr != nil {
|
||||
return readErr
|
||||
} else {
|
||||
status = string(body)
|
||||
}
|
||||
@@ -73,3 +74,21 @@ func statusAction(ctx *cli.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildStatusEndpoint returns the status endpoint URL, preferring the public
|
||||
// SiteUrl (which carries the correct scheme) and falling back to the local
|
||||
// HTTP host/port. When a Unix socket is configured, an http+unix style URL is
|
||||
// used so the custom transport can dial the socket.
|
||||
func buildStatusEndpoint(conf *config.Config) string {
|
||||
if socket := conf.HttpSocket(); socket != nil {
|
||||
return fmt.Sprintf("%s://%s/api/v1/status", socket.Scheme, strings.TrimPrefix(socket.Path, "/"))
|
||||
}
|
||||
|
||||
siteUrl := strings.TrimRight(conf.SiteUrl(), "/")
|
||||
|
||||
if siteUrl != "" {
|
||||
return siteUrl + "/api/v1/status"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
||||
}
|
||||
|
@@ -2,6 +2,8 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
urlpkg "net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -222,7 +224,36 @@ func (c *Config) SetJWKSUrl(url string) {
|
||||
if c == nil || c.options == nil {
|
||||
return
|
||||
}
|
||||
c.options.JWKSUrl = strings.TrimSpace(url)
|
||||
|
||||
trimmed := strings.TrimSpace(url)
|
||||
if trimmed == "" {
|
||||
c.options.JWKSUrl = ""
|
||||
return
|
||||
}
|
||||
|
||||
parsed, err := urlpkg.Parse(trimmed)
|
||||
if err != nil || parsed == nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
log.Warnf("config: ignoring JWKS URL %q (%v)", trimmed, err)
|
||||
return
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
host := parsed.Hostname()
|
||||
|
||||
switch scheme {
|
||||
case "https":
|
||||
// Always allowed.
|
||||
case "http":
|
||||
if !isLoopbackHost(host) {
|
||||
log.Warnf("config: rejecting JWKS URL %q (http only allowed for localhost/loopback)", trimmed)
|
||||
return
|
||||
}
|
||||
default:
|
||||
log.Warnf("config: rejecting JWKS URL %q (unsupported scheme)", trimmed)
|
||||
return
|
||||
}
|
||||
|
||||
c.options.JWKSUrl = trimmed
|
||||
}
|
||||
|
||||
// JWKSCacheTTL returns the JWKS cache lifetime in seconds (default 300, max 3600).
|
||||
@@ -261,6 +292,23 @@ func (c *Config) AdvertiseUrl() string {
|
||||
return c.SiteUrl()
|
||||
}
|
||||
|
||||
// isLoopbackHost returns true when host represents localhost or a loopback IP.
|
||||
func isLoopbackHost(host string) bool {
|
||||
if host == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if strings.EqualFold(host, "localhost") {
|
||||
return true
|
||||
}
|
||||
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
return ip.IsLoopback()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// SaveClusterUUID writes or updates the ClusterUUID key in options.yml without
|
||||
// touching unrelated keys. Creates the file and directories if needed.
|
||||
func (c *Config) SaveClusterUUID(uuid string) error {
|
||||
|
@@ -73,15 +73,77 @@ func TestConfig_Cluster(t *testing.T) {
|
||||
c.Options().NodeRole = ""
|
||||
})
|
||||
t.Run("JWKSUrlSetter", func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
c.options.JWKSUrl = ""
|
||||
assert.Equal(t, "", c.JWKSUrl())
|
||||
const existing = "https://existing.example/.well-known/jwks.json"
|
||||
tests := []struct {
|
||||
name string
|
||||
prev string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "TrimHTTPS",
|
||||
prev: "",
|
||||
input: " https://portal.example/.well-known/jwks.json ",
|
||||
expect: "https://portal.example/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "CaseInsensitiveScheme",
|
||||
prev: "",
|
||||
input: "HTTPS://portal.example/.well-known/jwks.json",
|
||||
expect: "HTTPS://portal.example/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLocalhost",
|
||||
prev: "",
|
||||
input: "http://localhost:2342/.well-known/jwks.json",
|
||||
expect: "http://localhost:2342/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLoopbackIPv4",
|
||||
prev: "",
|
||||
input: "http://127.0.0.1/.well-known/jwks.json",
|
||||
expect: "http://127.0.0.1/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLoopbackIPv6",
|
||||
prev: "",
|
||||
input: "http://[::1]/.well-known/jwks.json",
|
||||
expect: "http://[::1]/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "RejectHTTPNonLoopback",
|
||||
prev: existing,
|
||||
input: "http://portal.example/.well-known/jwks.json",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "RejectUnsupportedScheme",
|
||||
prev: existing,
|
||||
input: "ftp://portal.example/.well-known/jwks.json",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "RejectMalformedURL",
|
||||
prev: existing,
|
||||
input: "://not-a-url",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "ClearValue",
|
||||
prev: existing,
|
||||
input: "",
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
c.SetJWKSUrl(" https://portal.example/.well-known/jwks.json ")
|
||||
assert.Equal(t, "https://portal.example/.well-known/jwks.json", c.JWKSUrl())
|
||||
|
||||
c.SetJWKSUrl("")
|
||||
assert.Equal(t, "", c.JWKSUrl())
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
c.options.JWKSUrl = tc.prev
|
||||
c.SetJWKSUrl(tc.input)
|
||||
assert.Equal(t, tc.expect, c.JWKSUrl())
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("Paths", func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
@@ -342,12 +342,18 @@ func (c *Config) SetDbOptions() {
|
||||
case Postgres:
|
||||
// Ignore for now.
|
||||
case SQLite3:
|
||||
// Not required as unicode is default.
|
||||
// Not required as Unicode is default.
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterDb sets the database options and connection provider.
|
||||
// RegisterDb opens a database connection if needed,
|
||||
// sets the database options and connection provider.
|
||||
func (c *Config) RegisterDb() {
|
||||
if err := c.connectDb(); err != nil {
|
||||
log.Errorf("config: %s (register db)")
|
||||
return
|
||||
}
|
||||
|
||||
c.SetDbOptions()
|
||||
entity.SetDbProvider(c)
|
||||
}
|
||||
@@ -456,6 +462,11 @@ func (c *Config) connectDb() error {
|
||||
mutex.Db.Lock()
|
||||
defer mutex.Db.Unlock()
|
||||
|
||||
// Database connection already exists.
|
||||
if c.db != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get database driver and data source name.
|
||||
dbDriver := c.DatabaseDriver()
|
||||
dbDsn := c.DatabaseDSN()
|
||||
|
@@ -227,13 +227,16 @@ func TestDownloadPlaylistEntry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the same file but with the direct link
|
||||
url := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
||||
dlUrl := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
||||
|
||||
stderrBuf = &bytes.Buffer{}
|
||||
r, err = NewMetadata(context.Background(), url, Options{
|
||||
|
||||
r, err = NewMetadata(context.Background(), dlUrl, Options{
|
||||
StderrFn: func(cmd *exec.Cmd) io.Writer {
|
||||
return stderrBuf
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -1,30 +1,42 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
var (
|
||||
onceJWTManager sync.Once
|
||||
onceJWTIssuer sync.Once
|
||||
onceJWTManager sync.Once
|
||||
onceJWTIssuer sync.Once
|
||||
onceJWTVerifier sync.Once
|
||||
)
|
||||
|
||||
// initJWTManager lazily initializes the shared portal key manager for JWT issuance.
|
||||
func initJWTManager() {
|
||||
conf := Config()
|
||||
if conf == nil || !conf.IsPortal() {
|
||||
if conf == nil {
|
||||
return
|
||||
} else if !conf.IsPortal() {
|
||||
return
|
||||
}
|
||||
|
||||
manager, err := jwt.NewManager(conf)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("jwt: manager init failed (%s)", clean.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := manager.EnsureActiveKey(); err != nil {
|
||||
|
||||
if _, err = manager.EnsureActiveKey(); err != nil {
|
||||
log.Warnf("jwt: ensure signing key failed (%s)", clean.Error(err))
|
||||
}
|
||||
|
||||
services.JWTManager = manager
|
||||
}
|
||||
|
||||
@@ -34,6 +46,7 @@ func JWTManager() *jwt.Manager {
|
||||
return services.JWTManager
|
||||
}
|
||||
|
||||
// initJWTIssuer lazily binds the shared issuer to the active portal key manager.
|
||||
func initJWTIssuer() {
|
||||
manager := JWTManager()
|
||||
if manager == nil {
|
||||
@@ -50,9 +63,98 @@ func JWTIssuer() *jwt.Issuer {
|
||||
|
||||
// JWTVerifier returns a verifier bound to the current config.
|
||||
func JWTVerifier() *jwt.Verifier {
|
||||
conf := Config()
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
return jwt.NewVerifier(conf)
|
||||
onceJWTVerifier.Do(initJWTVerifier)
|
||||
return services.JWTVerifier
|
||||
}
|
||||
|
||||
// VerifyJWT verifies a token using the shared verifier instance.
|
||||
func VerifyJWT(ctx context.Context, token string, expected jwt.ExpectedClaims) (*jwt.Claims, error) {
|
||||
verifier := JWTVerifier()
|
||||
if verifier == nil {
|
||||
return nil, errors.New("jwt: verifier not available")
|
||||
}
|
||||
return verifier.VerifyToken(ctx, token, expected)
|
||||
}
|
||||
|
||||
// initJWTVerifier lazily constructs the shared verifier for the current configuration.
|
||||
func initJWTVerifier() {
|
||||
if conf != nil {
|
||||
services.JWTVerifier = jwt.NewVerifier(conf)
|
||||
}
|
||||
}
|
||||
|
||||
// resetJWTVerifier clears the cached verifier so it can be rebuilt for a new configuration.
|
||||
func resetJWTVerifier() {
|
||||
services.JWTVerifier = nil
|
||||
onceJWTVerifier = sync.Once{}
|
||||
}
|
||||
|
||||
// resetJWTIssuer clears the cached issuer so it can be recreated for a new configuration.
|
||||
func resetJWTIssuer() {
|
||||
services.JWTIssuer = nil
|
||||
onceJWTIssuer = sync.Once{}
|
||||
}
|
||||
|
||||
// resetJWTManager clears the cached key manager so subsequent calls reload keys for the active configuration.
|
||||
func resetJWTManager() {
|
||||
services.JWTManager = nil
|
||||
onceJWTManager = sync.Once{}
|
||||
}
|
||||
|
||||
// resetJWT clears all cached JWT helpers.
|
||||
func resetJWT() {
|
||||
resetJWTVerifier()
|
||||
resetJWTIssuer()
|
||||
resetJWTManager()
|
||||
}
|
||||
|
||||
// IssuePortalJWT signs a token using the shared portal issuer with the provided claims.
|
||||
func IssuePortalJWT(spec jwt.ClaimsSpec) (string, error) {
|
||||
if issuer := JWTIssuer(); issuer == nil {
|
||||
return "", errors.New("jwt: issuer not available")
|
||||
} else {
|
||||
return issuer.Issue(spec)
|
||||
}
|
||||
}
|
||||
|
||||
// IssuePortalJWTForNode issues a portal-signed JWT targeting the specified node UUID.
|
||||
func IssuePortalJWTForNode(nodeUUID string, scopes []string, ttl time.Duration) (string, error) {
|
||||
if conf == nil {
|
||||
return "", errors.New("jwt: missing config")
|
||||
} else if !conf.IsPortal() {
|
||||
return "", errors.New("jwt: not supported on nodes")
|
||||
}
|
||||
|
||||
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||
if clusterUUID == "" {
|
||||
return "", errors.New("jwt: cluster uuid not configured")
|
||||
}
|
||||
|
||||
nodeUUID = strings.TrimSpace(nodeUUID)
|
||||
if nodeUUID == "" {
|
||||
return "", errors.New("jwt: node uuid required")
|
||||
}
|
||||
if len(scopes) == 0 {
|
||||
return "", errors.New("jwt: at least one scope is required")
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(scopes))
|
||||
for _, s := range scopes {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return "", errors.New("jwt: at least one scope is required")
|
||||
}
|
||||
|
||||
spec := jwt.ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Audience: fmt.Sprintf("node:%s", nodeUUID),
|
||||
Scope: normalized,
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
return IssuePortalJWT(spec)
|
||||
}
|
||||
|
39
internal/photoprism/get/jwt_test.go
Normal file
39
internal/photoprism/get/jwt_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
)
|
||||
|
||||
func TestJWTVerifierReuse(t *testing.T) {
|
||||
verifier1 := JWTVerifier()
|
||||
require.NotNil(t, verifier1)
|
||||
|
||||
verifier2 := JWTVerifier()
|
||||
require.NotNil(t, verifier2)
|
||||
|
||||
assert.Same(t, verifier1, verifier2)
|
||||
}
|
||||
|
||||
func TestJWTVerifierResetOnConfigChange(t *testing.T) {
|
||||
orig := Config()
|
||||
verifier1 := JWTVerifier()
|
||||
require.NotNil(t, verifier1)
|
||||
|
||||
tempConf := config.NewMinimalTestConfigWithDb("jwt-reset", t.TempDir())
|
||||
SetConfig(tempConf)
|
||||
t.Cleanup(func() {
|
||||
SetConfig(orig)
|
||||
tempConf.CloseDb()
|
||||
orig.RegisterDb()
|
||||
})
|
||||
|
||||
verifier2 := JWTVerifier()
|
||||
require.NotNil(t, verifier2)
|
||||
|
||||
assert.NotSame(t, verifier1, verifier2)
|
||||
}
|
@@ -57,13 +57,17 @@ var services struct {
|
||||
OIDC *oidc.Client
|
||||
JWTManager *clusterjwt.Manager
|
||||
JWTIssuer *clusterjwt.Issuer
|
||||
JWTVerifier *clusterjwt.Verifier
|
||||
}
|
||||
|
||||
func SetConfig(c *config.Config) {
|
||||
if c == nil {
|
||||
log.Panic("panic: argument is nil in get.SetConfig(c *config.Config)")
|
||||
return
|
||||
}
|
||||
|
||||
resetJWT()
|
||||
|
||||
conf = c
|
||||
|
||||
photoprism.SetConfig(c)
|
||||
@@ -72,6 +76,7 @@ func SetConfig(c *config.Config) {
|
||||
func Config() *config.Config {
|
||||
if conf == nil {
|
||||
log.Panic("panic: conf is nil in get.Config()")
|
||||
return nil
|
||||
}
|
||||
|
||||
return conf
|
||||
|
@@ -229,9 +229,9 @@ func persistRegistration(c *config.Config, r *cluster.RegisterResponse, wantRota
|
||||
updates["NodeClientSecret"] = r.Secrets.ClientSecret
|
||||
}
|
||||
|
||||
if url := strings.TrimSpace(r.JWKSUrl); url != "" {
|
||||
updates["JWKSUrl"] = url
|
||||
c.SetJWKSUrl(url)
|
||||
if jwksUrl := strings.TrimSpace(r.JWKSUrl); jwksUrl != "" {
|
||||
updates["JWKSUrl"] = jwksUrl
|
||||
c.SetJWKSUrl(jwksUrl)
|
||||
}
|
||||
|
||||
// Persist NodeUUID from portal response if provided and not set locally.
|
||||
|
@@ -221,13 +221,13 @@ func (c *Config) ReSync(token string) (err error) {
|
||||
// interrupt reading of the Response.Body.
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
|
||||
url := ServiceURL
|
||||
endpointUrl := ServiceURL
|
||||
method := http.MethodPost
|
||||
|
||||
var req *http.Request
|
||||
|
||||
if c.Key != "" {
|
||||
url = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
||||
endpointUrl = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
||||
method = http.MethodPut
|
||||
log.Tracef("config: requesting updated keys for maps and places")
|
||||
} else {
|
||||
@@ -239,7 +239,7 @@ func (c *Config) ReSync(token string) (err error) {
|
||||
|
||||
if j, err = json.Marshal(NewRequest(c.Version, c.Serial, c.Env, c.PartnerID, token)); err != nil {
|
||||
return err
|
||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
||||
} else if req, err = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -67,17 +67,17 @@ func (c *Config) SendFeedback(frm form.Feedback) (err error) {
|
||||
// interrupt reading of the Response.Body.
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
|
||||
url := fmt.Sprintf(FeedbackURL, c.Key)
|
||||
endpointUrl := fmt.Sprintf(FeedbackURL, c.Key)
|
||||
method := http.MethodPost
|
||||
|
||||
var req *http.Request
|
||||
|
||||
log.Debugf("sending feedback to %s", ApiHost())
|
||||
|
||||
if j, err := json.Marshal(feedback); err != nil {
|
||||
return err
|
||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
||||
return err
|
||||
if j, reqErr := json.Marshal(feedback); reqErr != nil {
|
||||
return reqErr
|
||||
} else if req, reqErr = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); reqErr != nil {
|
||||
return reqErr
|
||||
}
|
||||
|
||||
// Set user agent.
|
||||
|
Reference in New Issue
Block a user