mirror of
https://github.com/photoprism/photoprism.git
synced 2025-09-26 21:01:58 +08:00
Compare commits
10 Commits
7e419f7419
...
9f119a8cfa
Author | SHA1 | Date | |
---|---|---|---|
![]() |
9f119a8cfa | ||
![]() |
66e2027c10 | ||
![]() |
bd66110c18 | ||
![]() |
07658dac69 | ||
![]() |
108b2c2df4 | ||
![]() |
48a965a7cc | ||
![]() |
32c054da7a | ||
![]() |
566eed05e0 | ||
![]() |
660c0a89db | ||
![]() |
ebb0410b20 |
19
AGENTS.md
19
AGENTS.md
@@ -163,6 +163,15 @@ Note: Across our public documentation, official images, and in production, the c
|
||||
- JS/Vue: use the lint/format scripts in `frontend/package.json` (ESLint + Prettier)
|
||||
- All added code and tests **must** be formatted according to our standards.
|
||||
|
||||
> Remember to update the `**Last Updated:**` line at the top whenever you edit these guidelines or other files containing a timestamp.
|
||||
|
||||
## Safety & Data
|
||||
|
||||
- Never commit secrets, local configurations, or cache files. Use environment variables or a local `.env`.
|
||||
- Ensure `.env` and `.local` are ignored in `.gitignore` and `.dockerignore`.
|
||||
- Prefer using existing caches, workers, and batching strategies referenced in code and `Makefile`. Consider memory/CPU impact; suggest benchmarks or profiling only when justified.
|
||||
- Do not run destructive commands against production data. Prefer ephemeral volumes and test fixtures when running acceptance tests.
|
||||
|
||||
### Filesystem Permissions & io/fs Aliasing (Go)
|
||||
|
||||
- Always use our shared permission variables from `pkg/fs` when creating files/directories:
|
||||
@@ -176,13 +185,7 @@ Note: Across our public documentation, official images, and in production, the c
|
||||
- Our package is `github.com/photoprism/photoprism/pkg/fs` and provides the only approved permission constants for `os.MkdirAll`, `os.WriteFile`, `os.OpenFile`, and `os.Chmod`.
|
||||
- Prefer `filepath.Join` for filesystem paths; reserve `path.Join` for URL paths.
|
||||
|
||||
## Safety & Data
|
||||
|
||||
- Never commit secrets, local configurations, or cache files. Use environment variables or a local `.env`.
|
||||
- Ensure `.env` and `.local` are ignored in `.gitignore` and `.dockerignore`.
|
||||
- Prefer using existing caches, workers, and batching strategies referenced in code and `Makefile`. Consider memory/CPU impact; suggest benchmarks or profiling only when justified.
|
||||
- Do not run destructive commands against production data. Prefer ephemeral volumes and test fixtures when running acceptance tests.
|
||||
- ### File I/O — Overwrite Policy (force semantics)
|
||||
### File I/O — Overwrite Policy (force semantics)
|
||||
|
||||
- Default is safety-first: callers must not overwrite non-empty destination files unless they opt-in with a `force` flag.
|
||||
- Replacing empty destination files is allowed without `force=true` (useful for placeholder files).
|
||||
@@ -240,12 +243,14 @@ If anything in this file conflicts with the `Makefile` or the Developer Guide, t
|
||||
- In `internal/photoprism` tests, rely on `photoprism.Config()` for runtime-accurate behavior; only build a new config if you replace it via `photoprism.SetConfig`.
|
||||
- Generate identifiers with `rnd.GenerateUID(entity.ClientUID)` for OAuth client IDs and `rnd.UUIDv7()` for node UUIDs; treat `node.uuid` as required in responses.
|
||||
- Shared fixtures live under `storage/testdata`; `NewTestConfig("<pkg>")` already calls `InitializeTestData()`, but call `c.InitializeTestData()` (and optionally `c.AssertTestData(t)`) when you construct custom configs so originals/import/cache/temp exist. `InitializeTestData()` clears old data, downloads fixtures if needed, then calls `CreateDirectories()`.
|
||||
- For slimmer tests that only need config objects, prefer the new helpers in `internal/config/test.go`: `NewMinimalTestConfig(t.TempDir())` when no database is needed, or `NewMinimalTestConfigWithDb("<pkg>", t.TempDir())` to spin up an isolated SQLite schema without seeding all fixtures.
|
||||
|
||||
### Roles & ACL
|
||||
- Map roles via the shared tables: users through `acl.ParseRole(s)` / `acl.UserRoles[...]`, clients through `acl.ClientRoles[...]`.
|
||||
- Treat `RoleAliasNone` ("none") and an empty string as `RoleNone`; no caller-specific overrides.
|
||||
- Default unknown client roles to `RoleClient`; `acl.ParseRole` already handles `0/false/nil` as none for users.
|
||||
- Build CLI role help from `Roles.CliUsageString()` (e.g., `acl.ClientRoles.CliUsageString()`); never hand-maintain role lists.
|
||||
- When checking JWT/client scopes, use the shared helpers (`acl.ScopePermits` / `acl.ScopeAttrPermits`) instead of hand-written parsing.
|
||||
|
||||
### Import/Index
|
||||
|
||||
|
@@ -80,7 +80,7 @@ Database & Migrations
|
||||
|
||||
AuthN/Z & Sessions
|
||||
- Session model and cache: `internal/entity/auth_session*` and `internal/auth/session/*` (cleanup worker).
|
||||
- ACL: `internal/auth/acl/*` — roles, grants, scopes; use constants; avoid logging secrets, compare tokens constant‑time.
|
||||
- ACL: `internal/auth/acl/*` — roles, grants, scopes; use constants; avoid logging secrets, compare tokens constant‑time; for scope checks use `acl.ScopePermits` / `ScopeAttrPermits` instead of rolling your own parsing.
|
||||
- OIDC: `internal/auth/oidc/*`.
|
||||
|
||||
Media Processing
|
||||
|
18
Makefile
18
Makefile
@@ -72,15 +72,15 @@ watch: watch-js
|
||||
build-all: build-go build-js
|
||||
pull: docker-pull
|
||||
test: test-js test-go
|
||||
test-go: reset-sqlite run-test-go
|
||||
test-pkg: reset-sqlite run-test-pkg
|
||||
test-ai: reset-sqlite run-test-ai
|
||||
test-api: reset-sqlite run-test-api
|
||||
test-video: reset-sqlite run-test-video
|
||||
test-entity: reset-sqlite run-test-entity
|
||||
test-commands: reset-sqlite run-test-commands
|
||||
test-photoprism: reset-sqlite run-test-photoprism
|
||||
test-short: reset-sqlite run-test-short
|
||||
test-go: run-test-go
|
||||
test-pkg: run-test-pkg
|
||||
test-ai: run-test-ai
|
||||
test-api: run-test-api
|
||||
test-video: run-test-video
|
||||
test-entity: run-test-entity
|
||||
test-commands: run-test-commands
|
||||
test-photoprism: run-test-photoprism
|
||||
test-short: run-test-short
|
||||
test-mariadb: reset-acceptance run-test-mariadb
|
||||
acceptance-run-chromium: storage/acceptance acceptance-auth-sqlite-restart wait acceptance-auth acceptance-auth-sqlite-stop acceptance-sqlite-restart wait-2 acceptance acceptance-sqlite-stop
|
||||
acceptance-run-chromium-short: storage/acceptance acceptance-auth-sqlite-restart wait acceptance-auth-short acceptance-auth-sqlite-stop acceptance-sqlite-restart wait-2 acceptance-short acceptance-sqlite-stop
|
||||
|
@@ -1,7 +1,9 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@@ -16,79 +18,132 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
// authAnyJWT attempts to authenticate a Portal-issued JWT when a cluster
|
||||
// node receives a request without an existing session. It verifies the token
|
||||
// against the node's cached JWKS, ensures the issuer/audience/scope match the
|
||||
// expected portal values, and, if valid, returns a client session mirroring the
|
||||
// JWT claims. It returns nil on any validation failure so the caller can fall
|
||||
// back to existing auth flows. Currently cluster and vision resources are
|
||||
// eligible for JWT-based authorization; vision access requires the `vision`
|
||||
// scope whereas cluster access requires the `cluster` scope.
|
||||
// authAnyJWT attempts to authenticate a Portal-issued JWT when a cluster node
|
||||
// receives a request without an existing session. It verifies the token against
|
||||
// the node's cached JWKS, ensures the issuer/audience/scope match the expected
|
||||
// portal values, and, if valid, returns a client session mirroring the JWT
|
||||
// claims. It returns nil on any validation failure so the caller can fall back
|
||||
// to existing auth flows. By default, only cluster and vision resources are
|
||||
// eligible, but nodes may opt in to additional scopes via PHOTOPRISM_JWT_SCOPE.
|
||||
func authAnyJWT(c *gin.Context, clientIP, authToken string, resource acl.Resource, perms acl.Permissions) *entity.Session {
|
||||
if c == nil || authToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = perms
|
||||
|
||||
if resource != acl.ResourceCluster && resource != acl.ResourceVision {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Basic sanity check for JWT structure.
|
||||
if strings.Count(authToken, ".") != 2 {
|
||||
// Check if token may be a JWT.
|
||||
if !shouldAttemptJWT(c, authToken) {
|
||||
return nil
|
||||
}
|
||||
|
||||
conf := get.Config()
|
||||
|
||||
if conf == nil || conf.IsPortal() {
|
||||
// Determine whether JWT authentication is possible
|
||||
// based on the local config and client IP address.
|
||||
if !shouldAllowJWT(conf, clientIP) {
|
||||
return nil
|
||||
}
|
||||
|
||||
requiredScope := resource.String()
|
||||
expected := expectedClaimsFor(conf, requiredScope)
|
||||
|
||||
// verifyTokenFromPortal handles cryptographic validation (signature, issuer,
|
||||
// audience, temporal claims) and enforces that the token includes any scopes
|
||||
// listed in expected.Scope. Local authorization still happens below so nodes
|
||||
// can apply their own allow-list semantics.
|
||||
claims := verifyTokenFromPortal(c.Request.Context(), authToken, expected, jwtIssuerCandidates(conf))
|
||||
|
||||
if claims == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if config allows resource access to be authorized with JWT.
|
||||
allowedScopes := conf.JWTAllowedScopes()
|
||||
if !acl.ScopeAttrPermits(allowedScopes, resource, perms) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if token allows access to specified resource.
|
||||
tokenScopes := acl.ScopeAttr(claims.Scope)
|
||||
if !acl.ScopeAttrPermits(tokenScopes, resource, perms) {
|
||||
return nil
|
||||
}
|
||||
|
||||
claims.Scope = tokenScopes.String()
|
||||
|
||||
return sessionFromJWTClaims(claims, clientIP)
|
||||
}
|
||||
|
||||
// shouldAttemptJWT reports whether JWT verification should run for the supplied
|
||||
// request context and token.
|
||||
func shouldAttemptJWT(c *gin.Context, token string) bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if token == "" || strings.Count(token, ".") != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldAllowJWT reports whether the current node configuration permits JWT
|
||||
// authentication for the request originating from clientIP.
|
||||
func shouldAllowJWT(conf *config.Config, clientIP string) bool {
|
||||
if conf == nil || conf.IsPortal() {
|
||||
return false
|
||||
}
|
||||
|
||||
if conf.JWKSUrl() == "" {
|
||||
return nil
|
||||
return false
|
||||
}
|
||||
|
||||
verifier := clusterjwt.NewVerifier(conf)
|
||||
requiredScopes := []string{"cluster"}
|
||||
if resource == acl.ResourceVision {
|
||||
requiredScopes = []string{"vision"}
|
||||
cidr := strings.TrimSpace(conf.ClusterCIDR())
|
||||
if cidr == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := net.ParseIP(clientIP)
|
||||
_, block, err := net.ParseCIDR(cidr)
|
||||
if err != nil || ip == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return block.Contains(ip)
|
||||
}
|
||||
|
||||
// expectedClaimsFor builds the ExpectedClaims used to validate JWTs for the
|
||||
// current node and required scope.
|
||||
func expectedClaimsFor(conf *config.Config, requiredScope string) clusterjwt.ExpectedClaims {
|
||||
expected := clusterjwt.ExpectedClaims{
|
||||
Audience: fmt.Sprintf("node:%s", conf.NodeUUID()),
|
||||
Scope: requiredScopes,
|
||||
JWKSURL: conf.JWKSUrl(),
|
||||
}
|
||||
|
||||
issuers := jwtIssuerCandidates(conf)
|
||||
if requiredScope != "" {
|
||||
expected.Scope = []string{requiredScope}
|
||||
}
|
||||
|
||||
return expected
|
||||
}
|
||||
|
||||
// verifyTokenFromPortal checks the token against each candidate issuer and
|
||||
// returns the verified claims on success.
|
||||
func verifyTokenFromPortal(ctx context.Context, token string, expected clusterjwt.ExpectedClaims, issuers []string) *clusterjwt.Claims {
|
||||
if len(issuers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
claims *clusterjwt.Claims
|
||||
err error
|
||||
)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
for _, issuer := range issuers {
|
||||
expected.Issuer = issuer
|
||||
claims, err = verifier.VerifyToken(ctx, authToken, expected)
|
||||
claims, err := get.VerifyJWT(ctx, token, expected)
|
||||
if err == nil {
|
||||
break
|
||||
return claims
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
} else if claims == nil {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// sessionFromJWTClaims constructs a Session populated with fields derived from
|
||||
// the verified JWT claims.
|
||||
func sessionFromJWTClaims(claims *clusterjwt.Claims, clientIP string) *entity.Session {
|
||||
sess := &entity.Session{
|
||||
Status: http.StatusOK,
|
||||
ClientUID: claims.Subject,
|
||||
|
@@ -1,17 +1,22 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
func TestAuthAnyJWT(t *testing.T) {
|
||||
@@ -35,7 +40,149 @@ func TestAuthAnyJWT(t *testing.T) {
|
||||
assert.Contains(t, session.AuthScope, "cluster")
|
||||
assert.Equal(t, spec.Issuer, session.AuthIssuer)
|
||||
})
|
||||
t.Run("ClusterCIDRAllowed", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-cidr-allow")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
token := fx.issue(t, spec)
|
||||
|
||||
origCIDR := fx.nodeConf.Options().ClusterCIDR
|
||||
fx.nodeConf.Options().ClusterCIDR = "192.0.2.0/24"
|
||||
get.SetConfig(fx.nodeConf)
|
||||
t.Cleanup(func() {
|
||||
fx.nodeConf.Options().ClusterCIDR = origCIDR
|
||||
get.SetConfig(fx.nodeConf)
|
||||
})
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/cluster/theme", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "192.0.2.10:2222"
|
||||
c.Request = req
|
||||
|
||||
session := authAnyJWT(c, "192.0.2.10", token, acl.ResourceCluster, nil)
|
||||
require.NotNil(t, session)
|
||||
assert.Equal(t, spec.Subject, session.ClientUID)
|
||||
})
|
||||
t.Run("ClusterCIDRBlocked", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-cidr-block")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
token := fx.issue(t, spec)
|
||||
|
||||
origCIDR := fx.nodeConf.Options().ClusterCIDR
|
||||
fx.nodeConf.Options().ClusterCIDR = "192.0.2.0/24"
|
||||
get.SetConfig(fx.nodeConf)
|
||||
t.Cleanup(func() {
|
||||
fx.nodeConf.Options().ClusterCIDR = origCIDR
|
||||
get.SetConfig(fx.nodeConf)
|
||||
})
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/cluster/theme", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "203.0.113.10:2222"
|
||||
c.Request = req
|
||||
|
||||
assert.Nil(t, authAnyJWT(c, "203.0.113.10", token, acl.ResourceCluster, nil))
|
||||
})
|
||||
t.Run("JWTScopeDefaultRejectsOtherResources", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-scope-default-reject")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
spec.Scope = []string{"photos"}
|
||||
token := fx.issue(t, spec)
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/photos", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "192.0.2.60:1001"
|
||||
c.Request = req
|
||||
|
||||
assert.Nil(t, authAnyJWT(c, "192.0.2.60", token, acl.ResourcePhotos, nil))
|
||||
})
|
||||
t.Run("JWTScopeAllowed", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-scope-allow")
|
||||
token := fx.issue(t, fx.defaultClaimsSpec())
|
||||
|
||||
orig := fx.nodeConf.Options().JWTScope
|
||||
fx.nodeConf.Options().JWTScope = "cluster vision"
|
||||
get.SetConfig(fx.nodeConf)
|
||||
t.Cleanup(func() {
|
||||
fx.nodeConf.Options().JWTScope = orig
|
||||
get.SetConfig(fx.nodeConf)
|
||||
})
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/cluster/theme", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "192.0.2.30:1001"
|
||||
c.Request = req
|
||||
|
||||
sess := authAnyJWT(c, "192.0.2.30", token, acl.ResourceCluster, nil)
|
||||
require.NotNil(t, sess)
|
||||
})
|
||||
t.Run("JWTScopeAllowsSuperset", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-scope-reject")
|
||||
token := fx.issue(t, fx.defaultClaimsSpec())
|
||||
|
||||
orig := fx.nodeConf.Options().JWTScope
|
||||
fx.nodeConf.Options().JWTScope = "cluster"
|
||||
get.SetConfig(fx.nodeConf)
|
||||
t.Cleanup(func() {
|
||||
fx.nodeConf.Options().JWTScope = orig
|
||||
get.SetConfig(fx.nodeConf)
|
||||
})
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/cluster/theme", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "192.0.2.40:1001"
|
||||
c.Request = req
|
||||
|
||||
sess := authAnyJWT(c, "192.0.2.40", token, acl.ResourceCluster, nil)
|
||||
require.NotNil(t, sess)
|
||||
})
|
||||
t.Run("JWTScopeCustomResource", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-scope-custom")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
spec.Scope = []string{"photos"}
|
||||
token := fx.issue(t, spec)
|
||||
|
||||
origScope := fx.nodeConf.Options().JWTScope
|
||||
fx.nodeConf.Options().JWTScope = "photos"
|
||||
get.SetConfig(fx.nodeConf)
|
||||
t.Cleanup(func() {
|
||||
fx.nodeConf.Options().JWTScope = origScope
|
||||
get.SetConfig(fx.nodeConf)
|
||||
})
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/v1/photos", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.RemoteAddr = "192.0.2.50:2001"
|
||||
c.Request = req
|
||||
|
||||
_, verifyErr := get.VerifyJWT(c.Request.Context(), token, clusterjwt.ExpectedClaims{
|
||||
Issuer: fmt.Sprintf("portal:%s", fx.clusterUUID),
|
||||
Audience: fmt.Sprintf("node:%s", fx.nodeUUID),
|
||||
Scope: []string{"photos"},
|
||||
JWKSURL: fx.nodeConf.JWKSUrl(),
|
||||
})
|
||||
require.NoError(t, verifyErr)
|
||||
|
||||
sess := authAnyJWT(c, "192.0.2.50", token, acl.ResourcePhotos, nil)
|
||||
require.NotNil(t, sess)
|
||||
})
|
||||
t.Run("VisionScope", func(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "cluster-jwt-vision")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
@@ -146,3 +293,83 @@ func TestJwtIssuerCandidates(t *testing.T) {
|
||||
assert.Equal(t, []string{"http://localhost:2342"}, jwtIssuerCandidates(conf))
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldAttemptJWT(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/ping", nil)
|
||||
c.Request = req
|
||||
|
||||
assert.True(t, shouldAttemptJWT(c, "a.b.c"))
|
||||
assert.False(t, shouldAttemptJWT(nil, "a.b.c"))
|
||||
assert.False(t, shouldAttemptJWT(c, "invalidtoken"))
|
||||
assert.False(t, shouldAttemptJWT(c, ""))
|
||||
}
|
||||
|
||||
func TestNodeAllowsJWT(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "node-allows")
|
||||
conf := fx.nodeConf
|
||||
|
||||
assert.True(t, shouldAllowJWT(conf, "192.0.2.9"))
|
||||
|
||||
origCIDR := conf.Options().ClusterCIDR
|
||||
conf.Options().ClusterCIDR = "192.0.2.0/24"
|
||||
assert.True(t, shouldAllowJWT(conf, "192.0.2.25"))
|
||||
assert.False(t, shouldAllowJWT(conf, "203.0.113.1"))
|
||||
conf.Options().ClusterCIDR = origCIDR
|
||||
|
||||
origJWKS := conf.JWKSUrl()
|
||||
conf.SetJWKSUrl("")
|
||||
assert.False(t, shouldAllowJWT(conf, "192.0.2.25"))
|
||||
conf.SetJWKSUrl(origJWKS)
|
||||
|
||||
assert.False(t, shouldAllowJWT(nil, "192.0.2.25"))
|
||||
}
|
||||
|
||||
func TestExpectedClaimsFor(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "expected-claims")
|
||||
|
||||
claims := expectedClaimsFor(fx.nodeConf, "cluster")
|
||||
assert.Equal(t, fmt.Sprintf("node:%s", fx.nodeUUID), claims.Audience)
|
||||
assert.Equal(t, []string{"cluster"}, claims.Scope)
|
||||
assert.Equal(t, fx.nodeConf.JWKSUrl(), claims.JWKSURL)
|
||||
|
||||
noScope := expectedClaimsFor(fx.nodeConf, "")
|
||||
assert.Nil(t, noScope.Scope)
|
||||
}
|
||||
|
||||
func TestVerifyTokenFromPortal(t *testing.T) {
|
||||
fx := newPortalJWTFixture(t, "verify-token")
|
||||
spec := fx.defaultClaimsSpec()
|
||||
token := fx.issue(t, spec)
|
||||
|
||||
expected := expectedClaimsFor(fx.nodeConf, clean.Scope("cluster"))
|
||||
claims := verifyTokenFromPortal(context.Background(), token, expected, []string{"wrong", spec.Issuer})
|
||||
require.NotNil(t, claims)
|
||||
assert.Equal(t, spec.Issuer, claims.Issuer)
|
||||
assert.Equal(t, spec.Subject, claims.Subject)
|
||||
|
||||
nilClaims := verifyTokenFromPortal(context.Background(), token, expected, []string{"wrong"})
|
||||
assert.Nil(t, nilClaims)
|
||||
}
|
||||
|
||||
func TestSessionFromJWTClaims(t *testing.T) {
|
||||
claims := &clusterjwt.Claims{
|
||||
Scope: "cluster vision",
|
||||
RegisteredClaims: gojwt.RegisteredClaims{
|
||||
Issuer: "portal:test",
|
||||
Subject: "portal:client",
|
||||
ID: "token-id",
|
||||
},
|
||||
}
|
||||
|
||||
sess := sessionFromJWTClaims(claims, "192.0.2.100")
|
||||
require.NotNil(t, sess)
|
||||
assert.Equal(t, http.StatusOK, sess.HttpStatus())
|
||||
assert.Equal(t, "portal:client", sess.ClientUID)
|
||||
assert.Equal(t, clean.Scope("cluster vision"), sess.AuthScope)
|
||||
assert.Equal(t, "portal:test", sess.AuthIssuer)
|
||||
assert.Equal(t, "token-id", sess.AuthID)
|
||||
assert.Equal(t, "192.0.2.100", sess.ClientIP)
|
||||
}
|
||||
|
@@ -253,11 +253,12 @@ type portalJWTFixture struct {
|
||||
|
||||
func newPortalJWTFixture(t *testing.T, suffix string) portalJWTFixture {
|
||||
t.Helper()
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
|
||||
origConf := get.Config()
|
||||
t.Cleanup(func() { get.SetConfig(origConf) })
|
||||
|
||||
nodeConf := config.NewTestConfig("auth-any-portal-jwt-" + suffix)
|
||||
nodeConf := config.NewMinimalTestConfigWithDb("auth-any-portal-jwt-"+suffix, t.TempDir())
|
||||
|
||||
nodeConf.Options().NodeRole = cluster.RoleInstance
|
||||
nodeConf.Options().Public = false
|
||||
clusterUUID := rnd.UUID()
|
||||
@@ -265,7 +266,8 @@ func newPortalJWTFixture(t *testing.T, suffix string) portalJWTFixture {
|
||||
nodeUUID := nodeConf.NodeUUID()
|
||||
nodeConf.Options().PortalUrl = "https://portal.example.test"
|
||||
|
||||
portalConf := config.NewTestConfig("auth-any-portal-jwt-issuer-" + suffix)
|
||||
portalConf := config.NewMinimalTestConfigWithDb("auth-any-portal-jwt-issuer-"+suffix, t.TempDir())
|
||||
|
||||
portalConf.Options().NodeRole = cluster.RolePortal
|
||||
portalConf.Options().ClusterUUID = clusterUUID
|
||||
|
||||
|
@@ -229,6 +229,7 @@ func ClusterNodesRegister(router *gin.RouterGroup) {
|
||||
|
||||
resp := cluster.RegisterResponse{
|
||||
UUID: conf.ClusterUUID(),
|
||||
ClusterCIDR: conf.ClusterCIDR(),
|
||||
Node: reg.BuildClusterNode(*n, opts),
|
||||
Database: cluster.RegisterDatabase{Host: conf.DatabaseHost(), Port: conf.DatabasePort(), Name: dbInfo.Name, User: dbInfo.User, Driver: provisioner.DatabaseDriver},
|
||||
Secrets: respSecret,
|
||||
@@ -299,6 +300,8 @@ func ClusterNodesRegister(router *gin.RouterGroup) {
|
||||
}
|
||||
|
||||
resp := cluster.RegisterResponse{
|
||||
UUID: conf.ClusterUUID(),
|
||||
ClusterCIDR: conf.ClusterCIDR(),
|
||||
Node: reg.BuildClusterNode(*n, reg.NodeOptsForSession(nil)),
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: n.ClientSecret, RotatedAt: n.RotatedAt},
|
||||
Database: cluster.RegisterDatabase{Host: conf.DatabaseHost(), Port: conf.DatabasePort(), Name: creds.Name, User: creds.User, Driver: provisioner.DatabaseDriver, Password: creds.Password, DSN: creds.DSN, RotatedAt: creds.RotatedAt},
|
||||
|
@@ -46,10 +46,11 @@ func ClusterSummary(router *gin.RouterGroup) {
|
||||
nodes, _ := regy.List()
|
||||
|
||||
c.JSON(http.StatusOK, cluster.SummaryResponse{
|
||||
UUID: conf.ClusterUUID(),
|
||||
Nodes: len(nodes),
|
||||
Database: cluster.DatabaseInfo{Driver: conf.DatabaseDriverName(), Host: conf.DatabaseHost(), Port: conf.DatabasePort()},
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
UUID: conf.ClusterUUID(),
|
||||
ClusterCIDR: conf.ClusterCIDR(),
|
||||
Nodes: len(nodes),
|
||||
Database: cluster.DatabaseInfo{Driver: conf.DatabaseDriverName(), Host: conf.DatabaseHost(), Port: conf.DatabasePort()},
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
@@ -25,12 +25,12 @@ func TestEcho(t *testing.T) {
|
||||
t.Logf("Response Body: %s", r.Body.String())
|
||||
|
||||
body := r.Body.String()
|
||||
url := gjson.Get(body, "url").String()
|
||||
bodyUrl := gjson.Get(body, "url").String()
|
||||
method := gjson.Get(body, "method").String()
|
||||
request := gjson.Get(body, "headers.request")
|
||||
response := gjson.Get(body, "headers.response")
|
||||
|
||||
assert.Equal(t, "/api/v1/echo", url)
|
||||
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||
assert.Equal(t, "GET", method)
|
||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||
@@ -49,12 +49,12 @@ func TestEcho(t *testing.T) {
|
||||
r := AuthenticatedRequest(app, http.MethodPost, "/api/v1/echo", authToken)
|
||||
|
||||
body := r.Body.String()
|
||||
url := gjson.Get(body, "url").String()
|
||||
bodyUrl := gjson.Get(body, "url").String()
|
||||
method := gjson.Get(body, "method").String()
|
||||
request := gjson.Get(body, "headers.request")
|
||||
response := gjson.Get(body, "headers.response")
|
||||
|
||||
assert.Equal(t, "/api/v1/echo", url)
|
||||
assert.Equal(t, "/api/v1/echo", bodyUrl)
|
||||
assert.Equal(t, "POST", method)
|
||||
assert.Equal(t, "Bearer "+authToken, request.Get("Authorization.0").String())
|
||||
assert.Equal(t, "application/json; charset=utf-8", response.Get("Content-Type.0").String())
|
||||
|
@@ -419,6 +419,9 @@
|
||||
"alreadyRegistered": {
|
||||
"type": "boolean"
|
||||
},
|
||||
"clusterCidr": {
|
||||
"type": "string"
|
||||
},
|
||||
"database": {
|
||||
"$ref": "#/definitions/cluster.RegisterDatabase"
|
||||
},
|
||||
@@ -459,6 +462,9 @@
|
||||
},
|
||||
"cluster.SummaryResponse": {
|
||||
"properties": {
|
||||
"clusterCidr": {
|
||||
"type": "string"
|
||||
},
|
||||
"database": {
|
||||
"$ref": "#/definitions/cluster.DatabaseInfo"
|
||||
},
|
||||
|
@@ -1,5 +1,11 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
)
|
||||
|
||||
// Permission scopes to Grant multiple Permissions for a Resource.
|
||||
const (
|
||||
ScopeRead Permission = "read"
|
||||
@@ -35,3 +41,63 @@ var (
|
||||
ActionManageOwn: true,
|
||||
}
|
||||
)
|
||||
|
||||
// ScopeAttr parses an auth scope string and returns a normalized Attr
|
||||
// with duplicate and invalid entries removed.
|
||||
func ScopeAttr(s string) list.Attr {
|
||||
if s == "" {
|
||||
return list.Attr{}
|
||||
}
|
||||
|
||||
return list.ParseAttr(strings.ToLower(s))
|
||||
}
|
||||
|
||||
// ScopePermits sanitizes the raw scope string and then calls ScopeAttrPermits for
|
||||
// the actual authorization check.
|
||||
func ScopePermits(scope string, resource Resource, perms Permissions) bool {
|
||||
if scope == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse scope to check for resources and permissions.
|
||||
return ScopeAttrPermits(ScopeAttr(scope), resource, perms)
|
||||
}
|
||||
|
||||
// ScopeAttrPermits evaluates an already-parsed scope attribute list against a
|
||||
// resource and permission set, enforcing wildcard/read/write semantics.
|
||||
func ScopeAttrPermits(attr list.Attr, resource Resource, perms Permissions) bool {
|
||||
if len(attr) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
scope := attr.String()
|
||||
|
||||
// Skip detailed check and allow all if scope is "*".
|
||||
if scope == list.Any {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip resource check if scope includes all read operations.
|
||||
if scope == ScopeRead.String() {
|
||||
return !GrantScopeRead.DenyAny(perms)
|
||||
}
|
||||
|
||||
// Check if resource is within scope.
|
||||
if granted := attr.Contains(resource.String()); !granted {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if permission is within scope.
|
||||
if len(perms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if scope is limited to read or write operations.
|
||||
if a := attr.Find(ScopeRead.String()); a.Value == list.True && GrantScopeRead.DenyAny(perms) {
|
||||
return false
|
||||
} else if a = attr.Find(ScopeWrite.String()); a.Value == list.True && GrantScopeWrite.DenyAny(perms) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
@@ -35,3 +35,136 @@ func TestGrantScopeWrite(t *testing.T) {
|
||||
assert.False(t, GrantScopeWrite.DenyAny(Permissions{AccessAll}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopePermits(t *testing.T) {
|
||||
t.Run("AnyScope", func(t *testing.T) {
|
||||
assert.True(t, ScopePermits("*", "", nil))
|
||||
})
|
||||
t.Run("ReadScope", func(t *testing.T) {
|
||||
assert.True(t, ScopePermits("read", "metrics", nil))
|
||||
assert.True(t, ScopePermits("read", "sessions", nil))
|
||||
assert.True(t, ScopePermits("read", "metrics", Permissions{ActionView, AccessAll}))
|
||||
assert.False(t, ScopePermits("read", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read", "settings", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read", "settings", Permissions{ActionCreate}))
|
||||
assert.False(t, ScopePermits("read", "sessions", Permissions{ActionDelete}))
|
||||
})
|
||||
t.Run("ReadAny", func(t *testing.T) {
|
||||
assert.True(t, ScopePermits("read *", "metrics", nil))
|
||||
assert.True(t, ScopePermits("read *", "sessions", nil))
|
||||
assert.True(t, ScopePermits("read *", "metrics", Permissions{ActionView, AccessAll}))
|
||||
assert.False(t, ScopePermits("read *", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read *", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read *", "settings", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read *", "settings", Permissions{ActionCreate}))
|
||||
assert.False(t, ScopePermits("read *", "sessions", Permissions{ActionDelete}))
|
||||
})
|
||||
t.Run("ReadSettings", func(t *testing.T) {
|
||||
assert.True(t, ScopePermits("read settings", "settings", Permissions{ActionView}))
|
||||
assert.False(t, ScopePermits("read settings", "metrics", nil))
|
||||
assert.False(t, ScopePermits("read settings", "sessions", nil))
|
||||
assert.False(t, ScopePermits("read settings", "metrics", Permissions{ActionView, AccessAll}))
|
||||
assert.False(t, ScopePermits("read settings", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read settings", "metrics", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read settings", "settings", Permissions{ActionUpdate}))
|
||||
assert.False(t, ScopePermits("read settings", "sessions", Permissions{ActionDelete}))
|
||||
assert.False(t, ScopePermits("read settings", "sessions", Permissions{ActionDelete}))
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopeAttr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []string
|
||||
}{
|
||||
{name: "Empty", input: "", expected: nil},
|
||||
{name: "Lowercase", input: "read metrics", expected: []string{"metrics", "read"}},
|
||||
{name: "Uppercase", input: "READ SETTINGS", expected: []string{"read", "settings"}},
|
||||
{name: "WithNoise", input: " Read\tSessions\nmetrics", expected: []string{"metrics", "read", "sessions"}},
|
||||
{name: "Deduplicates", input: "metrics metrics", expected: []string{"metrics"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
attr := ScopeAttr(tc.input)
|
||||
if len(tc.expected) == 0 {
|
||||
assert.Len(t, attr, 0)
|
||||
return
|
||||
}
|
||||
assert.ElementsMatch(t, tc.expected, attr.Strings())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopePermitsEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope string
|
||||
resource Resource
|
||||
perms Permissions
|
||||
want bool
|
||||
}{
|
||||
{name: "EmptyScope", scope: "", resource: "metrics", perms: nil, want: false},
|
||||
{name: "OnlyInvalidChars", scope: "()", resource: "metrics", perms: nil, want: false},
|
||||
{name: "WildcardMixedOrder", scope: "* read metrics", resource: "metrics", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "WildcardOverridesReadRestrictions", scope: "read metrics *", resource: "metrics", perms: Permissions{ActionDelete}, want: false},
|
||||
{name: "WildcardWithFalseValueIgnored", scope: "*:false read", resource: "metrics", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "ExplicitFalseResource", scope: "metrics:false", resource: "metrics", perms: nil, want: false},
|
||||
{name: "ExplicitTrueResource", scope: "metrics:true", resource: "metrics", perms: nil, want: true},
|
||||
{name: "CaseInsensitiveScopeAndResource", scope: "READ SETTINGS", resource: Resource("Settings"), perms: Permissions{ActionView}, want: true},
|
||||
{name: "WhitespaceAndTabs", scope: "\tread\tsettings\n", resource: "settings", perms: Permissions{ActionView}, want: true},
|
||||
{name: "DefaultResourceRead", scope: "read default", resource: "", perms: Permissions{ActionView}, want: true},
|
||||
{name: "DefaultResourceUpdateDenied", scope: "read default", resource: "", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "WriteAllowsMutation", scope: "write settings", resource: "settings", perms: Permissions{ActionUpdate}, want: true},
|
||||
{name: "WriteBlocksReadOnly", scope: "write settings", resource: "settings", perms: Permissions{ActionView}, want: false},
|
||||
{name: "ReadGrantAllowsAccessAll", scope: "read", resource: "metrics", perms: Permissions{AccessAll}, want: true},
|
||||
{name: "ReadGrantDeniesManage", scope: "read metrics", resource: "metrics", perms: Permissions{ActionManage}, want: false},
|
||||
{name: "WriteGrantAllowsManage", scope: "write metrics", resource: "metrics", perms: Permissions{ActionManage}, want: true},
|
||||
{name: "ResourceWildcard", scope: "metrics:*", resource: "metrics", perms: Permissions{ActionDelete}, want: true},
|
||||
{name: "GlobalWildcardWithoutRead", scope: "* metrics", resource: "metrics", perms: Permissions{ActionDelete}, want: true},
|
||||
{name: "ResourceWildcardWithRead", scope: "read metrics:*", resource: "metrics", perms: Permissions{ActionView}, want: true},
|
||||
{name: "ResourceWildcardWriteDenied", scope: "read metrics:*", resource: "metrics", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "DuplicateAndNoise", scope: " read metrics metrics ", resource: "metrics", perms: nil, want: true},
|
||||
{name: "FalseOverridesTrue", scope: "metrics metrics:false", resource: "metrics", perms: nil, want: false},
|
||||
{name: "CaseInsensitiveResourceLookup", scope: "read metrics", resource: Resource("METRICS"), perms: Permissions{ActionView}, want: true},
|
||||
{name: "MixedReadWriteConflict", scope: "read write settings", resource: "settings", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "PermissionsEmptySlice", scope: "read metrics", resource: "metrics", perms: Permissions{}, want: true},
|
||||
{name: "SimpleNonReadScopeAllows", scope: "cluster vision", resource: "cluster", perms: nil, want: true},
|
||||
{name: "SimpleNonReadScopeRejectsMissing", scope: "cluster vision", resource: "portal", perms: nil, want: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := ScopePermits(tc.scope, tc.resource, tc.perms)
|
||||
assert.Equalf(t, tc.want, got, "scope %q resource %q perms %v", tc.scope, tc.resource, tc.perms)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeAttrPermits(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scope string
|
||||
resource Resource
|
||||
perms Permissions
|
||||
want bool
|
||||
}{
|
||||
{name: "EmptyAttr", scope: "", resource: "metrics", perms: nil, want: false},
|
||||
{name: "Wildcard", scope: "*", resource: "metrics", perms: Permissions{ActionUpdate}, want: true},
|
||||
{name: "ReadAllowsView", scope: "read", resource: "settings", perms: Permissions{ActionView}, want: true},
|
||||
{name: "ReadBlocksUpdate", scope: "read", resource: "settings", perms: Permissions{ActionUpdate}, want: false},
|
||||
{name: "ResourceMismatch", scope: "read metrics", resource: "settings", perms: nil, want: false},
|
||||
{name: "WriteAllowsManage", scope: "write metrics", resource: "metrics", perms: Permissions{ActionManage}, want: true},
|
||||
{name: "WriteBlocksView", scope: "write metrics", resource: "metrics", perms: Permissions{ActionView}, want: false},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
attr := ScopeAttr(tc.scope)
|
||||
got := ScopeAttrPermits(attr, tc.resource, tc.perms)
|
||||
assert.Equalf(t, tc.want, got, "scope %q resource %q perms %v", tc.scope, tc.resource, tc.perms)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
cfg "github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/event"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
@@ -17,9 +17,6 @@ func TestMain(m *testing.M) {
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
event.AuditLog = log
|
||||
|
||||
c := config.TestConfig()
|
||||
defer c.CloseDb()
|
||||
|
||||
// Run unit tests.
|
||||
code := m.Run()
|
||||
|
||||
@@ -28,3 +25,7 @@ func TestMain(m *testing.M) {
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func newTestConfig(t *testing.T) *cfg.Config {
|
||||
return cfg.NewMinimalTestConfig(t.TempDir())
|
||||
}
|
||||
|
@@ -9,12 +9,11 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
cfg "github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestManagerEnsureActiveKey(t *testing.T) {
|
||||
c := cfg.TestConfig()
|
||||
c := newTestConfig(t)
|
||||
m, err := NewManager(c)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, m)
|
||||
@@ -53,7 +52,7 @@ func TestManagerEnsureActiveKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestManagerGenerateSecondKey(t *testing.T) {
|
||||
c := cfg.TestConfig()
|
||||
c := newTestConfig(t)
|
||||
m, err := NewManager(c)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
@@ -25,6 +25,20 @@ var (
|
||||
errKeyNotFound = errors.New("jwt: key not found")
|
||||
)
|
||||
|
||||
// VerifierStatus captures diagnostic information about a verifier's JWKS cache state.
|
||||
type VerifierStatus struct {
|
||||
CacheURL string `json:"cacheUrl,omitempty"`
|
||||
CacheETag string `json:"cacheEtag,omitempty"`
|
||||
KeyIDs []string `json:"keyIds,omitempty"`
|
||||
KeyCount int `json:"keyCount"`
|
||||
CacheFetchedAt time.Time `json:"cacheFetchedAt,omitempty"`
|
||||
CacheAgeSeconds int64 `json:"cacheAgeSeconds"`
|
||||
CacheTTLSeconds int `json:"cacheTtlSeconds"`
|
||||
CacheStale bool `json:"cacheStale"`
|
||||
CachePath string `json:"cachePath,omitempty"`
|
||||
JWKSURL string `json:"jwksUrl,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
|
||||
jwksFetchMaxRetries = 3
|
||||
@@ -99,15 +113,14 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
if strings.TrimSpace(expected.Audience) == "" {
|
||||
return nil, errors.New("jwt: expected audience required")
|
||||
}
|
||||
if len(expected.Scope) == 0 {
|
||||
return nil, errors.New("jwt: expected scope required")
|
||||
|
||||
jwksUrl := strings.TrimSpace(expected.JWKSURL)
|
||||
|
||||
if jwksUrl == "" && v.conf != nil {
|
||||
jwksUrl = strings.TrimSpace(v.conf.JWKSUrl())
|
||||
}
|
||||
|
||||
url := strings.TrimSpace(expected.JWKSURL)
|
||||
if url == "" && v.conf != nil {
|
||||
url = strings.TrimSpace(v.conf.JWKSUrl())
|
||||
}
|
||||
if url == "" {
|
||||
if jwksUrl == "" {
|
||||
return nil, errors.New("jwt: jwks url not configured")
|
||||
}
|
||||
|
||||
@@ -126,16 +139,111 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
claims := &Claims{}
|
||||
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
|
||||
if kid == "" {
|
||||
return nil, errors.New("jwt: missing kid header")
|
||||
}
|
||||
pk, err := v.publicKeyForKid(ctx, url, kid, false)
|
||||
|
||||
pk, err := v.publicKeyForKid(ctx, jwksUrl, kid, false)
|
||||
|
||||
if errors.Is(err, errKeyNotFound) {
|
||||
pk, err = v.publicKeyForKid(ctx, url, kid, true)
|
||||
pk, err = v.publicKeyForKid(ctx, jwksUrl, kid, true)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
if _, err := parser.ParseWithClaims(tokenString, claims, keyFunc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||
return nil, errors.New("jwt: missing temporal claims")
|
||||
}
|
||||
|
||||
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||
}
|
||||
|
||||
scopeSet := map[string]struct{}{}
|
||||
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyTokenWithKeys verifies a token using the provided JWKS keys without performing HTTP fetches.
|
||||
func VerifyTokenWithKeys(tokenString string, expected ExpectedClaims, keys []PublicJWK, leeway time.Duration) (*Claims, error) {
|
||||
if strings.TrimSpace(tokenString) == "" {
|
||||
return nil, errors.New("jwt: token is empty")
|
||||
}
|
||||
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.New("jwt: no jwks keys provided")
|
||||
}
|
||||
|
||||
if leeway <= 0 {
|
||||
leeway = 60 * time.Second
|
||||
}
|
||||
|
||||
keyMap := make(map[string]ed25519.PublicKey, len(keys))
|
||||
|
||||
for _, jwk := range keys {
|
||||
if jwk.Kid == "" {
|
||||
continue
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(jwk.X)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(raw) != ed25519.PublicKeySize {
|
||||
return nil, fmt.Errorf("jwt: invalid public key length %d", len(raw))
|
||||
}
|
||||
pk := make(ed25519.PublicKey, ed25519.PublicKeySize)
|
||||
copy(pk, raw)
|
||||
keyMap[jwk.Kid] = pk
|
||||
}
|
||||
|
||||
if len(keyMap) == 0 {
|
||||
return nil, errors.New("jwt: no valid jwks keys provided")
|
||||
}
|
||||
|
||||
options := []gojwt.ParserOption{
|
||||
gojwt.WithLeeway(leeway),
|
||||
gojwt.WithValidMethods([]string{gojwt.SigningMethodEdDSA.Alg()}),
|
||||
}
|
||||
|
||||
if iss := strings.TrimSpace(expected.Issuer); iss != "" {
|
||||
options = append(options, gojwt.WithIssuer(iss))
|
||||
}
|
||||
|
||||
if aud := strings.TrimSpace(expected.Audience); aud != "" {
|
||||
options = append(options, gojwt.WithAudience(aud))
|
||||
}
|
||||
|
||||
parser := gojwt.NewParser(options...)
|
||||
claims := &Claims{}
|
||||
keyFunc := func(token *gojwt.Token) (interface{}, error) {
|
||||
kid, _ := token.Header["kid"].(string)
|
||||
if kid == "" {
|
||||
return nil, errors.New("jwt: missing kid header")
|
||||
}
|
||||
pk, ok := keyMap[kid]
|
||||
if !ok {
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
@@ -146,29 +254,70 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
|
||||
if claims.IssuedAt == nil || claims.ExpiresAt == nil {
|
||||
return nil, errors.New("jwt: missing temporal claims")
|
||||
}
|
||||
|
||||
if ttl := claims.ExpiresAt.Time.Sub(claims.IssuedAt.Time); ttl > MaxTokenTTL {
|
||||
return nil, errors.New("jwt: token ttl exceeds maximum")
|
||||
}
|
||||
|
||||
scopeSet := map[string]struct{}{}
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
if len(expected.Scope) > 0 {
|
||||
scopeSet := map[string]struct{}{}
|
||||
for _, s := range strings.Fields(claims.Scope) {
|
||||
scopeSet[s] = struct{}{}
|
||||
}
|
||||
for _, req := range expected.Scope {
|
||||
if _, ok := scopeSet[req]; !ok {
|
||||
return nil, fmt.Errorf("jwt: missing scope %s", req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Status returns diagnostic information about the verifier's current JWKS cache.
|
||||
func (v *Verifier) Status(ttl time.Duration) VerifierStatus {
|
||||
status := VerifierStatus{}
|
||||
|
||||
if ttl > 0 {
|
||||
status.CacheTTLSeconds = int(ttl / time.Second)
|
||||
}
|
||||
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
||||
status.CacheURL = v.cache.URL
|
||||
status.CacheETag = v.cache.ETag
|
||||
status.JWKSURL = v.cache.URL
|
||||
status.KeyCount = len(v.cache.Keys)
|
||||
status.KeyIDs = make([]string, 0, len(v.cache.Keys))
|
||||
|
||||
for _, key := range v.cache.Keys {
|
||||
status.KeyIDs = append(status.KeyIDs, key.Kid)
|
||||
}
|
||||
|
||||
status.CachePath = v.cachePath
|
||||
|
||||
if v.cache.FetchedAt > 0 {
|
||||
fetched := time.Unix(v.cache.FetchedAt, 0).UTC()
|
||||
status.CacheFetchedAt = fetched
|
||||
age := time.Since(fetched)
|
||||
status.CacheAgeSeconds = int64(age.Seconds())
|
||||
if ttl > 0 && age > ttl {
|
||||
status.CacheStale = true
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
|
||||
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
|
||||
keys, err := v.keysForURL(ctx, url, force)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, k := range keys {
|
||||
if k.Kid != kid {
|
||||
continue
|
||||
@@ -184,12 +333,14 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
|
||||
copy(pk, raw)
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
return nil, errKeyNotFound
|
||||
}
|
||||
|
||||
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
|
||||
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
|
||||
ttl := 300 * time.Second
|
||||
|
||||
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
|
||||
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
|
||||
}
|
||||
@@ -250,13 +401,16 @@ func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, f
|
||||
if force || cache.URL != url || len(cache.Keys) == 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
age := v.now().Unix() - cache.FetchedAt
|
||||
if age < 0 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if time.Duration(age)*time.Second > ttl {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return append([]PublicJWK(nil), cache.Keys...), true
|
||||
}
|
||||
|
||||
@@ -270,17 +424,21 @@ type jwksFetchResult struct {
|
||||
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
|
||||
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if etag != "" {
|
||||
req.Header.Set("If-None-Match", etag)
|
||||
}
|
||||
|
||||
resp, err := v.httpClient.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
switch resp.StatusCode {
|
||||
@@ -331,6 +489,7 @@ func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK
|
||||
Keys: append([]PublicJWK(nil), result.keys...),
|
||||
FetchedAt: result.fetchedAt,
|
||||
}
|
||||
|
||||
_ = v.saveCacheLocked()
|
||||
return append([]PublicJWK(nil), v.cache.Keys...), true
|
||||
}
|
||||
@@ -347,7 +506,7 @@ func (v *Verifier) loadCache() error {
|
||||
}
|
||||
|
||||
var entry cacheEntry
|
||||
if err := json.Unmarshal(b, &entry); err != nil {
|
||||
if err = json.Unmarshal(b, &entry); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -360,13 +519,17 @@ func (v *Verifier) saveCacheLocked() error {
|
||||
if v.cachePath == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := fs.MkdirAll(filepath.Dir(v.cachePath)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
data, err := json.Marshal(v.cache)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
|
||||
}
|
||||
|
||||
@@ -377,11 +540,13 @@ func backoffDuration(attempt int) time.Duration {
|
||||
}
|
||||
|
||||
base := jwksFetchBaseDelay << (attempt - 1)
|
||||
|
||||
if base > jwksFetchMaxDelay {
|
||||
base = jwksFetchMaxDelay
|
||||
}
|
||||
|
||||
jitterRange := base / 2
|
||||
|
||||
if jitterRange > 0 {
|
||||
base += time.Duration(randInt63n(int64(jitterRange) + 1))
|
||||
}
|
||||
|
@@ -13,12 +13,11 @@ import (
|
||||
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
|
||||
cfg "github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
func TestVerifierPrimeAndVerify(t *testing.T) {
|
||||
portalCfg := cfg.TestConfig()
|
||||
portalCfg := newTestConfig(t)
|
||||
clusterUUID := rnd.UUIDv7()
|
||||
portalCfg.Options().ClusterUUID = clusterUUID
|
||||
|
||||
@@ -47,7 +46,7 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
nodeCfg := cfg.NewTestConfig("jwt-verifier-node")
|
||||
nodeCfg := newTestConfig(t)
|
||||
nodeCfg.SetJWKSUrl(server.URL + "/.well-known/jwks.json")
|
||||
nodeCfg.Options().ClusterUUID = clusterUUID
|
||||
nodeUUID := nodeCfg.NodeUUID()
|
||||
@@ -104,8 +103,58 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestVerifyTokenWithKeys(t *testing.T) {
|
||||
portalCfg := newTestConfig(t)
|
||||
clusterUUID := rnd.UUIDv7()
|
||||
portalCfg.Options().ClusterUUID = clusterUUID
|
||||
|
||||
mgr, err := NewManager(portalCfg)
|
||||
require.NoError(t, err)
|
||||
mgr.now = func() time.Time { return time.Date(2025, 9, 24, 10, 30, 0, 0, time.UTC) }
|
||||
_, err = mgr.EnsureActiveKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
issuer := NewIssuer(mgr)
|
||||
issuer.now = func() time.Time { return time.Now().UTC() }
|
||||
|
||||
spec := ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: "portal:client-test",
|
||||
Audience: "node:1234",
|
||||
Scope: []string{"cluster"},
|
||||
}
|
||||
|
||||
token, err := issuer.Issue(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
keys := mgr.JWKS().Keys
|
||||
claims, err := VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
Scope: []string{"cluster"},
|
||||
}, keys, 60*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, spec.Subject, claims.Subject)
|
||||
|
||||
// Ensure scope filtering is honored when expected scope is empty.
|
||||
claims, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
}, keys, 60*time.Second)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, spec.Subject, claims.Subject)
|
||||
|
||||
// Missing scope should fail when explicitly required.
|
||||
_, err = VerifyTokenWithKeys(token, ExpectedClaims{
|
||||
Issuer: spec.Issuer,
|
||||
Audience: spec.Audience,
|
||||
Scope: []string{"vision"},
|
||||
}, keys, 60*time.Second)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestIssuerClampTTL(t *testing.T) {
|
||||
portalCfg := cfg.TestConfig()
|
||||
portalCfg := newTestConfig(t)
|
||||
mgr, err := NewManager(portalCfg)
|
||||
require.NoError(t, err)
|
||||
mgr.now = func() time.Time { return time.Unix(0, 0) }
|
||||
|
@@ -15,6 +15,7 @@ var AuthCommands = &cli.Command{
|
||||
AuthShowCommand,
|
||||
AuthRemoveCommand,
|
||||
AuthResetCommand,
|
||||
AuthJWTCommands,
|
||||
},
|
||||
}
|
||||
|
||||
|
16
internal/commands/auth_jwt.go
Normal file
16
internal/commands/auth_jwt.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package commands
|
||||
|
||||
import "github.com/urfave/cli/v2"
|
||||
|
||||
// AuthJWTCommands groups JWT-related auth helpers under photoprism auth jwt.
|
||||
var AuthJWTCommands = &cli.Command{
|
||||
Name: "jwt",
|
||||
Usage: "JWT issuance and diagnostics",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Subcommands: []*cli.Command{
|
||||
AuthJWTIssueCommand,
|
||||
AuthJWTInspectCommand,
|
||||
AuthJWTKeysCommand,
|
||||
AuthJWTStatusCommand,
|
||||
},
|
||||
}
|
154
internal/commands/auth_jwt_inspect.go
Normal file
154
internal/commands/auth_jwt_inspect.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
// AuthJWTInspectCommand inspects and verifies portal-issued JWTs.
|
||||
var AuthJWTInspectCommand = &cli.Command{
|
||||
Name: "inspect",
|
||||
Usage: "Decodes and verifies a portal JWT",
|
||||
ArgsUsage: "<token>",
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{Name: "file", Aliases: []string{"f"}, Usage: "read token from file"},
|
||||
&cli.StringFlag{Name: "expect-audience", Usage: "expected audience (e.g., node:<uuid>)"},
|
||||
&cli.StringSliceFlag{Name: "require-scope", Usage: "require specific scope(s)"},
|
||||
&cli.BoolFlag{Name: "skip-verify", Usage: "decode without signature verification"},
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTInspectAction,
|
||||
}
|
||||
|
||||
// authJWTInspectAction decodes and optionally verifies a portal-issued JWT.
|
||||
func authJWTInspectAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
token, err := readTokenInput(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, claims, err := decodeJWTClaims(token)
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
var verified bool
|
||||
tokenScopes := clean.Scopes(claims.Scope)
|
||||
|
||||
if !ctx.Bool("skip-verify") {
|
||||
expected := clusterjwt.ExpectedClaims{}
|
||||
if clusterUUID := strings.TrimSpace(conf.ClusterUUID()); clusterUUID != "" {
|
||||
expected.Issuer = fmt.Sprintf("portal:%s", clusterUUID)
|
||||
} else if portal := strings.TrimSpace(conf.PortalUrl()); portal != "" {
|
||||
expected.Issuer = strings.TrimRight(portal, "/")
|
||||
}
|
||||
|
||||
if expectAud := strings.TrimSpace(ctx.String("expect-audience")); expectAud != "" {
|
||||
expected.Audience = expectAud
|
||||
} else if len(claims.Audience) > 0 {
|
||||
expected.Audience = claims.Audience[0]
|
||||
}
|
||||
|
||||
if required := ctx.StringSlice("require-scope"); len(required) > 0 {
|
||||
scopes, scopeErr := normalizeScopes(required)
|
||||
if scopeErr != nil {
|
||||
return scopeErr
|
||||
}
|
||||
expected.Scope = scopes
|
||||
} else {
|
||||
expected.Scope = tokenScopes
|
||||
}
|
||||
|
||||
if _, err := verifyPortalToken(conf, token, expected); err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
verified = true
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"token": token,
|
||||
"verified": verified,
|
||||
"header": header,
|
||||
"claims": claims,
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("JWT header:")
|
||||
for k, v := range header {
|
||||
fmt.Printf(" %s: %v\n", k, v)
|
||||
}
|
||||
|
||||
fmt.Println("\nJWT claims:")
|
||||
fmt.Printf(" issuer: %s\n", claims.Issuer)
|
||||
fmt.Printf(" subject: %s\n", claims.Subject)
|
||||
fmt.Printf(" audience: %s\n", strings.Join(claims.Audience, " "))
|
||||
fmt.Printf(" scope: %s\n", strings.Join(tokenScopes, " "))
|
||||
if claims.IssuedAt != nil {
|
||||
fmt.Printf(" issuedAt: %s\n", claims.IssuedAt.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.ExpiresAt != nil {
|
||||
fmt.Printf(" expiresAt: %s\n", claims.ExpiresAt.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.NotBefore != nil {
|
||||
fmt.Printf(" notBefore: %s\n", claims.NotBefore.Time.UTC().Format(time.RFC3339))
|
||||
}
|
||||
if claims.ID != "" {
|
||||
fmt.Printf(" jti: %s\n", claims.ID)
|
||||
}
|
||||
|
||||
if verified {
|
||||
fmt.Println("\nSignature: verified")
|
||||
} else {
|
||||
fmt.Println("\nSignature: not verified (skipped)")
|
||||
}
|
||||
|
||||
fmt.Printf("\nToken:\n%s\n\n", token)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// readTokenInput loads the token from CLI args, file, or STDIN.
|
||||
func readTokenInput(ctx *cli.Context) (string, error) {
|
||||
if file := strings.TrimSpace(ctx.String("file")); file != "" {
|
||||
data, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
return "", cli.Exit(err, 1)
|
||||
}
|
||||
return strings.TrimSpace(string(data)), nil
|
||||
}
|
||||
|
||||
if ctx.Args().Len() == 0 {
|
||||
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(ctx.Args().First())
|
||||
if token == "-" {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return "", cli.Exit(err, 1)
|
||||
}
|
||||
token = strings.TrimSpace(string(data))
|
||||
}
|
||||
if token == "" {
|
||||
return "", cli.Exit(errors.New("token argument required"), 2)
|
||||
}
|
||||
return token, nil
|
||||
}
|
117
internal/commands/auth_jwt_issue.go
Normal file
117
internal/commands/auth_jwt_issue.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTIssueCommand issues portal-signed JWTs for cluster nodes.
|
||||
var AuthJWTIssueCommand = &cli.Command{
|
||||
Name: "issue",
|
||||
Usage: "Issues a portal-signed JWT for a node",
|
||||
Flags: []cli.Flag{
|
||||
&cli.StringFlag{Name: "node", Aliases: []string{"n"}, Usage: "target node uuid, client id, or DNS label", Required: true},
|
||||
&cli.StringSliceFlag{Name: "scope", Aliases: []string{"s"}, Usage: "token scope", Value: cli.NewStringSlice("cluster")},
|
||||
&cli.DurationFlag{Name: "ttl", Usage: "token lifetime", Value: clusterjwt.TokenTTL},
|
||||
&cli.StringFlag{Name: "subject", Usage: "token subject (default portal:<clusterUUID>)"},
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTIssueAction,
|
||||
}
|
||||
|
||||
// authJWTIssueAction handles CLI issuance of portal-signed JWTs for nodes.
|
||||
func authJWTIssueAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
node, err := resolveNode(conf, ctx.String("node"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
scopes, err := normalizeScopes(ctx.StringSlice("scope"), "cluster")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ttl := ctx.Duration("ttl")
|
||||
if ttl <= 0 {
|
||||
ttl = clusterjwt.TokenTTL
|
||||
}
|
||||
|
||||
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||
if clusterUUID == "" {
|
||||
return cli.Exit(fmt.Errorf("cluster uuid not configured"), 1)
|
||||
}
|
||||
|
||||
subject := strings.TrimSpace(ctx.String("subject"))
|
||||
if subject == "" {
|
||||
subject = fmt.Sprintf("portal:%s", clusterUUID)
|
||||
}
|
||||
|
||||
var token string
|
||||
if subject == fmt.Sprintf("portal:%s", clusterUUID) {
|
||||
token, err = get.IssuePortalJWTForNode(node.UUID, scopes, ttl)
|
||||
} else {
|
||||
spec := clusterjwt.ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: subject,
|
||||
Audience: fmt.Sprintf("node:%s", node.UUID),
|
||||
Scope: scopes,
|
||||
TTL: ttl,
|
||||
}
|
||||
token, err = get.IssuePortalJWT(spec)
|
||||
}
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
header, claims, err := decodeJWTClaims(token)
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"token": token,
|
||||
"header": header,
|
||||
"claims": claims,
|
||||
"node": map[string]string{
|
||||
"uuid": node.UUID,
|
||||
"clientId": node.ClientID,
|
||||
"name": node.Name,
|
||||
"role": string(node.Role),
|
||||
},
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
expires := "unknown"
|
||||
if claims.ExpiresAt != nil {
|
||||
expires = claims.ExpiresAt.Time.UTC().Format(time.RFC3339)
|
||||
}
|
||||
audience := strings.Join(claims.Audience, " ")
|
||||
if audience == "" {
|
||||
audience = "(none)"
|
||||
}
|
||||
|
||||
fmt.Printf("\nIssued JWT for node %s (%s)\n", node.Name, node.UUID)
|
||||
fmt.Printf("Scopes: %s\n", strings.Join(scopes, " "))
|
||||
fmt.Printf("Expires: %s\n", expires)
|
||||
fmt.Printf("Audience: %s\n", audience)
|
||||
fmt.Printf("Subject: %s\n", claims.Subject)
|
||||
fmt.Printf("Key ID: %v\n", header["kid"])
|
||||
fmt.Printf("\n%s\n", token)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
107
internal/commands/auth_jwt_keys.go
Normal file
107
internal/commands/auth_jwt_keys.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTKeysCommand groups JWT key management helpers.
|
||||
var AuthJWTKeysCommand = &cli.Command{
|
||||
Name: "keys",
|
||||
Usage: "JWT signing key helpers",
|
||||
Subcommands: []*cli.Command{
|
||||
AuthJWTKeysListCommand,
|
||||
},
|
||||
}
|
||||
|
||||
// AuthJWTKeysListCommand lists JWT signing keys.
|
||||
var AuthJWTKeysListCommand = &cli.Command{
|
||||
Name: "ls",
|
||||
Usage: "Lists JWT signing keys",
|
||||
Aliases: []string{"list"},
|
||||
ArgsUsage: "",
|
||||
Flags: []cli.Flag{
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTKeysListAction,
|
||||
}
|
||||
|
||||
// authJWTKeysListAction lists portal signing keys with metadata.
|
||||
func authJWTKeysListAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
manager := get.JWTManager()
|
||||
if manager == nil {
|
||||
return cli.Exit(errors.New("jwt manager not available"), 1)
|
||||
}
|
||||
|
||||
keys := manager.AllKeys()
|
||||
active, _ := manager.ActiveKey()
|
||||
activeKid := ""
|
||||
if active != nil {
|
||||
activeKid = active.Kid
|
||||
}
|
||||
|
||||
type keyInfo struct {
|
||||
Kid string `json:"kid"`
|
||||
CreatedAt string `json:"createdAt"`
|
||||
NotAfter string `json:"notAfter,omitempty"`
|
||||
Active bool `json:"active"`
|
||||
}
|
||||
|
||||
rows := make([]keyInfo, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
info := keyInfo{Kid: k.Kid, Active: k.Kid == activeKid}
|
||||
if k.CreatedAt > 0 {
|
||||
info.CreatedAt = time.Unix(k.CreatedAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
if k.NotAfter > 0 {
|
||||
info.NotAfter = time.Unix(k.NotAfter, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
rows = append(rows, info)
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
payload := map[string]any{
|
||||
"keys": rows,
|
||||
}
|
||||
return printJSON(payload)
|
||||
}
|
||||
|
||||
if len(rows) == 0 {
|
||||
fmt.Println()
|
||||
fmt.Println("No signing keys found.")
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Println("JWT signing keys:")
|
||||
for _, row := range rows {
|
||||
status := ""
|
||||
if row.Active {
|
||||
status = " (active)"
|
||||
}
|
||||
parts := []string{fmt.Sprintf("KID: %s%s", row.Kid, status)}
|
||||
if row.CreatedAt != "" {
|
||||
parts = append(parts, fmt.Sprintf("created %s", row.CreatedAt))
|
||||
}
|
||||
if row.NotAfter != "" {
|
||||
parts = append(parts, fmt.Sprintf("expires %s", row.NotAfter))
|
||||
}
|
||||
fmt.Printf("- %s\n", strings.Join(parts, ", "))
|
||||
}
|
||||
fmt.Println()
|
||||
return nil
|
||||
})
|
||||
}
|
67
internal/commands/auth_jwt_status.go
Normal file
67
internal/commands/auth_jwt_status.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
)
|
||||
|
||||
// AuthJWTStatusCommand reports verifier cache diagnostics.
|
||||
var AuthJWTStatusCommand = &cli.Command{
|
||||
Name: "status",
|
||||
Usage: "Shows JWT verifier cache status",
|
||||
Flags: []cli.Flag{
|
||||
JsonFlag(),
|
||||
},
|
||||
Action: authJWTStatusAction,
|
||||
}
|
||||
|
||||
// authJWTStatusAction prints JWKS cache diagnostics for the current node.
|
||||
func authJWTStatusAction(ctx *cli.Context) error {
|
||||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
verifier := get.JWTVerifier()
|
||||
if verifier == nil {
|
||||
return cli.Exit(errors.New("jwt verifier not available"), 1)
|
||||
}
|
||||
|
||||
ttl := time.Duration(conf.JWKSCacheTTL()) * time.Second
|
||||
status := verifier.Status(ttl)
|
||||
status.JWKSURL = strings.TrimSpace(conf.JWKSUrl())
|
||||
|
||||
if ctx.Bool("json") {
|
||||
return printJSON(status)
|
||||
}
|
||||
|
||||
fmt.Println()
|
||||
fmt.Printf("JWKS URL: %s\n", status.JWKSURL)
|
||||
fmt.Printf("Cache Path: %s\n", status.CachePath)
|
||||
fmt.Printf("Cache URL: %s\n", status.CacheURL)
|
||||
fmt.Printf("Cache ETag: %s\n", status.CacheETag)
|
||||
fmt.Printf("Cached Keys: %d\n", status.KeyCount)
|
||||
if len(status.KeyIDs) > 0 {
|
||||
fmt.Printf("Key IDs: %s\n", strings.Join(status.KeyIDs, ", "))
|
||||
}
|
||||
if !status.CacheFetchedAt.IsZero() {
|
||||
fmt.Printf("Last Fetch: %s\n", status.CacheFetchedAt.Format(time.RFC3339))
|
||||
} else {
|
||||
fmt.Println("Last Fetch: never")
|
||||
}
|
||||
fmt.Printf("Cache Age: %ds\n", status.CacheAgeSeconds)
|
||||
if status.CacheTTLSeconds > 0 {
|
||||
fmt.Printf("Cache TTL: %ds\n", status.CacheTTLSeconds)
|
||||
}
|
||||
if status.CacheStale {
|
||||
fmt.Println("Cache Status: STALE")
|
||||
} else {
|
||||
fmt.Println("Cache Status: fresh")
|
||||
}
|
||||
fmt.Println()
|
||||
return nil
|
||||
})
|
||||
}
|
94
internal/commands/auth_jwt_test.go
Normal file
94
internal/commands/auth_jwt_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
func TestAuthJWTCommands(t *testing.T) {
|
||||
conf := get.Config()
|
||||
|
||||
origEdition := conf.Options().Edition
|
||||
origRole := conf.Options().NodeRole
|
||||
origUUID := conf.Options().ClusterUUID
|
||||
origPortal := conf.Options().PortalUrl
|
||||
origJWKS := conf.JWKSUrl()
|
||||
|
||||
conf.Options().Edition = config.Portal
|
||||
conf.Options().NodeRole = string(cluster.RolePortal)
|
||||
conf.Options().ClusterUUID = "11111111-1111-4111-8111-111111111111"
|
||||
conf.Options().PortalUrl = "https://portal.test"
|
||||
conf.SetJWKSUrl("https://portal.test/.well-known/jwks.json")
|
||||
|
||||
get.SetConfig(conf)
|
||||
conf.RegisterDb()
|
||||
|
||||
require.True(t, conf.IsPortal())
|
||||
|
||||
manager := get.JWTManager()
|
||||
require.NotNil(t, manager)
|
||||
_, err := manager.EnsureActiveKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodeUUID := rnd.UUID()
|
||||
node := ®.Node{}
|
||||
node.UUID = nodeUUID
|
||||
node.Name = "pp-node-01"
|
||||
node.Role = string(cluster.RoleInstance)
|
||||
require.NoError(t, registry.Put(node))
|
||||
t.Cleanup(func() {
|
||||
conf.Options().Edition = origEdition
|
||||
conf.Options().NodeRole = origRole
|
||||
conf.Options().ClusterUUID = origUUID
|
||||
conf.Options().PortalUrl = origPortal
|
||||
conf.SetJWKSUrl(origJWKS)
|
||||
get.SetConfig(conf)
|
||||
conf.RegisterDb()
|
||||
})
|
||||
|
||||
output, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, output, "Issued JWT")
|
||||
|
||||
jsonOut, err := RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--json"})
|
||||
require.NoError(t, err)
|
||||
|
||||
var payload struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal([]byte(jsonOut), &payload))
|
||||
require.NotEmpty(t, payload.Token)
|
||||
|
||||
inspectOut, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", payload.Token})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, inspectOut, "\"verified\": true")
|
||||
|
||||
inspectStrict, err := RunWithTestContext(AuthJWTInspectCommand, []string{"inspect", "--json", "--expect-audience", "node:" + nodeUUID, "--require-scope", "cluster", payload.Token})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, inspectStrict, "\"verified\": true")
|
||||
|
||||
keysOut, err := RunWithTestContext(AuthJWTKeysListCommand, []string{"ls", "--json"})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, keysOut, "\"keys\"")
|
||||
|
||||
statusOut, err := RunWithTestContext(AuthJWTStatusCommand, []string{"status"})
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, statusOut, "JWKS URL")
|
||||
assert.Contains(t, statusOut, "Cached Keys")
|
||||
|
||||
// invalid scope should fail
|
||||
_, err = RunWithTestContext(AuthJWTIssueCommand, []string{"issue", "--node", nodeUUID, "--scope", "unknown"})
|
||||
require.Error(t, err)
|
||||
}
|
@@ -71,7 +71,7 @@ func clientsAddAction(ctx *cli.Context) error {
|
||||
|
||||
// Set a default client name if no specific name has been provided.
|
||||
if frm.AuthScope == "" {
|
||||
frm.AuthScope = list.All
|
||||
frm.AuthScope = list.Any
|
||||
}
|
||||
|
||||
client, addErr := entity.AddClient(frm)
|
||||
|
@@ -19,8 +19,9 @@ type healthResponse struct {
|
||||
// ClusterHealthCommand prints a minimal health response (Portal-only).
|
||||
var ClusterHealthCommand = &cli.Command{
|
||||
Name: "health",
|
||||
Usage: "Shows cluster health (Portal-only)",
|
||||
Usage: "Shows cluster health status",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterHealthAction,
|
||||
}
|
||||
|
||||
|
@@ -14,8 +14,9 @@ import (
|
||||
|
||||
// ClusterNodesCommands groups node subcommands.
|
||||
var ClusterNodesCommands = &cli.Command{
|
||||
Name: "nodes",
|
||||
Usage: "Node registry subcommands",
|
||||
Name: "nodes",
|
||||
Usage: "Node registry subcommands",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Subcommands: []*cli.Command{
|
||||
ClusterNodesListCommand,
|
||||
ClusterNodesShowCommand,
|
||||
@@ -28,9 +29,10 @@ var ClusterNodesCommands = &cli.Command{
|
||||
// ClusterNodesListCommand lists registered nodes.
|
||||
var ClusterNodesListCommand = &cli.Command{
|
||||
Name: "ls",
|
||||
Usage: "Lists registered cluster nodes (Portal-only)",
|
||||
Usage: "Lists registered cluster nodes",
|
||||
Flags: append(report.CliFlags, CountFlag, OffsetFlag),
|
||||
ArgsUsage: "",
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesListAction,
|
||||
}
|
||||
|
||||
|
@@ -22,9 +22,10 @@ var (
|
||||
// ClusterNodesModCommand updates node fields.
|
||||
var ClusterNodesModCommand = &cli.Command{
|
||||
Name: "mod",
|
||||
Usage: "Updates node properties (Portal-only)",
|
||||
Usage: "Updates node properties",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: []cli.Flag{nodesModRoleFlag, nodesModInternal, nodesModLabel, &cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"}},
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesModAction,
|
||||
}
|
||||
|
||||
|
@@ -14,12 +14,13 @@ import (
|
||||
// ClusterNodesRemoveCommand deletes a node from the registry.
|
||||
var ClusterNodesRemoveCommand = &cli.Command{
|
||||
Name: "rm",
|
||||
Usage: "Deletes a node from the registry (Portal-only)",
|
||||
Usage: "Deletes a node from the registry",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{Name: "yes", Aliases: []string{"y"}, Usage: "runs the command non-interactively"},
|
||||
&cli.BoolFlag{Name: "all-ids", Usage: "delete all records that share the same UUID (admin cleanup)"},
|
||||
},
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesRemoveAction,
|
||||
}
|
||||
|
||||
|
@@ -2,6 +2,7 @@ package commands
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -106,11 +107,13 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
|
||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
|
||||
var resp cluster.RegisterResponse
|
||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
||||
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||
// Map common HTTP errors similarly to register command
|
||||
if he, ok := err.(*httpError); ok {
|
||||
var he *httpError
|
||||
if errors.As(err, &he) {
|
||||
switch he.Status {
|
||||
case 401, 403:
|
||||
return cli.Exit(fmt.Errorf("%s", he.Error()), 4)
|
||||
@@ -151,6 +154,7 @@ func clusterNodesRotateAction(ctx *cli.Context) error {
|
||||
fmt.Printf("DSN: %s\n", resp.Database.DSN)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
@@ -15,9 +15,10 @@ import (
|
||||
// ClusterNodesShowCommand shows node details.
|
||||
var ClusterNodesShowCommand = &cli.Command{
|
||||
Name: "show",
|
||||
Usage: "Shows node details (Portal-only)",
|
||||
Usage: "Shows node details",
|
||||
ArgsUsage: "<id|name>",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterNodesShowAction,
|
||||
}
|
||||
|
||||
|
@@ -20,11 +20,12 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||
)
|
||||
|
||||
// flags for register
|
||||
// Supported cluster node register flags.
|
||||
var (
|
||||
regNameFlag = &cli.StringFlag{Name: "name", Usage: "node `NAME` (lowercase letters, digits, hyphens)"}
|
||||
regRoleFlag = &cli.StringFlag{Name: "role", Usage: "node `ROLE` (instance, service)", Value: "instance"}
|
||||
@@ -42,7 +43,7 @@ var (
|
||||
// ClusterRegisterCommand registers a node with the Portal via HTTP.
|
||||
var ClusterRegisterCommand = &cli.Command{
|
||||
Name: "register",
|
||||
Usage: "Registers/rotates a node via Portal (HTTP)",
|
||||
Usage: "Registers a node or updates its credentials within a cluster",
|
||||
Flags: append(append([]cli.Flag{regNameFlag, regRoleFlag, regIntUrlFlag, regLabelFlag, regRotateDatabase, regRotateSec, regPortalURL, regPortalTok, regWriteConf, regForceFlag, regDryRun}, report.CliFlags...)),
|
||||
Action: clusterRegisterAction,
|
||||
}
|
||||
@@ -52,15 +53,18 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
// Resolve inputs
|
||||
name := clean.DNSLabel(ctx.String("name"))
|
||||
derivedName := false
|
||||
|
||||
if name == "" { // default from config if set
|
||||
name = clean.DNSLabel(conf.NodeName())
|
||||
if name != "" {
|
||||
derivedName = true
|
||||
}
|
||||
}
|
||||
|
||||
if name == "" {
|
||||
return cli.Exit(fmt.Errorf("node name is required (use --name or set node-name)"), 2)
|
||||
}
|
||||
|
||||
nodeRole := clean.TypeLowerDash(ctx.String("role"))
|
||||
switch nodeRole {
|
||||
case "instance", "service":
|
||||
@@ -76,7 +80,6 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
derivedPortal = true
|
||||
}
|
||||
}
|
||||
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
||||
|
||||
// Derive advertise/site URLs when omitted.
|
||||
advertise := ctx.String("advertise-url")
|
||||
@@ -93,17 +96,20 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
RotateDatabase: ctx.Bool("rotate"),
|
||||
RotateSecret: ctx.Bool("rotate-secret"),
|
||||
}
|
||||
|
||||
// If we already have client credentials (e.g., re-register), include them so the
|
||||
// portal can verify and authorize UUID/name moves or metadata updates.
|
||||
if id, secret := strings.TrimSpace(conf.NodeClientID()), strings.TrimSpace(conf.NodeClientSecret()); id != "" && secret != "" {
|
||||
payload.ClientID = id
|
||||
payload.ClientSecret = secret
|
||||
}
|
||||
|
||||
if site != "" && site != advertise {
|
||||
payload.SiteUrl = site
|
||||
}
|
||||
b, _ := json.Marshal(payload)
|
||||
|
||||
// In dry-run, we allow empty portalURL (will print derived/empty values).
|
||||
if ctx.Bool("dry-run") {
|
||||
if ctx.Bool("json") {
|
||||
out := map[string]any{"portalUrl": portalURL, "payload": payload}
|
||||
@@ -140,18 +146,22 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
if portalURL == "" {
|
||||
return cli.Exit(fmt.Errorf("portal URL is required (use --portal-url or set portal-url)"), 2)
|
||||
}
|
||||
|
||||
token := ctx.String("join-token")
|
||||
|
||||
if token == "" {
|
||||
token = conf.JoinToken()
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return cli.Exit(fmt.Errorf("portal token is required (use --join-token or set join-token)"), 2)
|
||||
}
|
||||
|
||||
// POST with bounded backoff on 429
|
||||
url := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
endpointUrl := stringsTrimRightSlash(portalURL) + "/api/v1/cluster/nodes/register"
|
||||
|
||||
var resp cluster.RegisterResponse
|
||||
if err := postWithBackoff(url, token, b, &resp); err != nil {
|
||||
if err := postWithBackoff(endpointUrl, token, b, &resp); err != nil {
|
||||
var httpErr *httpError
|
||||
if errors.As(err, &httpErr) && httpErr.Status == http.StatusTooManyRequests {
|
||||
return cli.Exit(fmt.Errorf("portal rate-limited registration attempts"), 6)
|
||||
@@ -179,13 +189,17 @@ func clusterRegisterAction(ctx *cli.Context) error {
|
||||
} else {
|
||||
// Human-readable: node row and credentials if present (UUID first as primary identifier)
|
||||
cols := []string{"UUID", "ClientID", "Name", "Role", "DB Driver", "DB Name", "DB User", "Host", "Port"}
|
||||
|
||||
var dbName, dbUser string
|
||||
|
||||
if resp.Database.Name != "" {
|
||||
dbName = resp.Database.Name
|
||||
}
|
||||
|
||||
if resp.Database.User != "" {
|
||||
dbUser = resp.Database.User
|
||||
}
|
||||
|
||||
rows := [][]string{{resp.Node.UUID, resp.Node.ClientID, resp.Node.Name, resp.Node.Role, resp.Database.Driver, dbName, dbUser, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port)}}
|
||||
out, _ := report.RenderFormat(rows, cols, report.CliFormat(ctx))
|
||||
fmt.Printf("\n%s\n", out)
|
||||
@@ -317,6 +331,16 @@ func parseLabelSlice(labels []string) map[string]string {
|
||||
|
||||
// Persistence helpers for --write-config
|
||||
func persistRegisterResponse(conf *config.Config, resp *cluster.RegisterResponse) error {
|
||||
updates := map[string]any{}
|
||||
|
||||
if rnd.IsUUID(resp.UUID) {
|
||||
updates["ClusterUUID"] = resp.UUID
|
||||
}
|
||||
|
||||
if cidr := strings.TrimSpace(resp.ClusterCIDR); cidr != "" {
|
||||
updates["ClusterCIDR"] = cidr
|
||||
}
|
||||
|
||||
// Node client secret file
|
||||
if resp.Secrets != nil && resp.Secrets.ClientSecret != "" {
|
||||
// Prefer PHOTOPRISM_NODE_CLIENT_SECRET_FILE; otherwise config cluster path
|
||||
@@ -335,16 +359,18 @@ func persistRegisterResponse(conf *config.Config, resp *cluster.RegisterResponse
|
||||
|
||||
// DB settings (MySQL/MariaDB only)
|
||||
if resp.Database.Name != "" && resp.Database.User != "" {
|
||||
if err := mergeOptionsYaml(conf, map[string]any{
|
||||
"DatabaseDriver": config.MySQL,
|
||||
"DatabaseName": resp.Database.Name,
|
||||
"DatabaseServer": fmt.Sprintf("%s:%d", resp.Database.Host, resp.Database.Port),
|
||||
"DatabaseUser": resp.Database.User,
|
||||
"DatabasePassword": resp.Database.Password,
|
||||
}); err != nil {
|
||||
updates["DatabaseDriver"] = config.MySQL
|
||||
updates["DatabaseName"] = resp.Database.Name
|
||||
updates["DatabaseServer"] = fmt.Sprintf("%s:%d", resp.Database.Host, resp.Database.Port)
|
||||
updates["DatabaseUser"] = resp.Database.User
|
||||
updates["DatabasePassword"] = resp.Database.Password
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := mergeOptionsYaml(conf, updates); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Infof("updated options.yml with database settings for node %s", clean.LogQuote(resp.Node.Name))
|
||||
log.Infof("updated options.yml with cluster registration settings for node %s", clean.LogQuote(resp.Node.Name))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@@ -13,11 +13,12 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/txt/report"
|
||||
)
|
||||
|
||||
// ClusterSummaryCommand prints a minimal cluster summary (Portal-only).
|
||||
// ClusterSummaryCommand prints a minimal cluster summary.
|
||||
var ClusterSummaryCommand = &cli.Command{
|
||||
Name: "summary",
|
||||
Usage: "Shows cluster summary (Portal-only)",
|
||||
Usage: "Shows cluster summary",
|
||||
Flags: report.CliFlags,
|
||||
Hidden: true, // Required for cluster-management only.
|
||||
Action: clusterSummaryAction,
|
||||
}
|
||||
|
||||
@@ -35,10 +36,11 @@ func clusterSummaryAction(ctx *cli.Context) error {
|
||||
nodes, _ := r.List()
|
||||
|
||||
resp := cluster.SummaryResponse{
|
||||
UUID: conf.ClusterUUID(),
|
||||
Nodes: len(nodes),
|
||||
Database: cluster.DatabaseInfo{Driver: conf.DatabaseDriverName(), Host: conf.DatabaseHost(), Port: conf.DatabasePort()},
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
UUID: conf.ClusterUUID(),
|
||||
ClusterCIDR: conf.ClusterCIDR(),
|
||||
Nodes: len(nodes),
|
||||
Database: cluster.DatabaseInfo{Driver: conf.DatabaseDriverName(), Host: conf.DatabaseHost(), Port: conf.DatabasePort()},
|
||||
Time: time.Now().UTC().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if ctx.Bool("json") {
|
||||
@@ -47,8 +49,8 @@ func clusterSummaryAction(ctx *cli.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
cols := []string{"Portal UUID", "Nodes", "DB Driver", "DB Host", "DB Port", "Time"}
|
||||
rows := [][]string{{resp.UUID, fmt.Sprintf("%d", resp.Nodes), resp.Database.Driver, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port), resp.Time}}
|
||||
cols := []string{"Portal UUID", "Cluster CIDR", "Nodes", "DB Driver", "DB Host", "DB Port", "Time"}
|
||||
rows := [][]string{{resp.UUID, resp.ClusterCIDR, fmt.Sprintf("%d", resp.Nodes), resp.Database.Driver, resp.Database.Host, fmt.Sprintf("%d", resp.Database.Port), resp.Time}}
|
||||
out, err := report.RenderFormat(rows, cols, report.CliFormat(ctx))
|
||||
fmt.Printf("\n%s\n", out)
|
||||
return err
|
||||
|
@@ -95,9 +95,10 @@ func TestClusterThemePull_JoinTokenToOAuth(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return NodeClientID and a fresh secret
|
||||
_ = json.NewEncoder(w).Encode(cluster.RegisterResponse{
|
||||
UUID: rnd.UUID(),
|
||||
Node: cluster.Node{ClientID: "cs5gfen1bgxz7s9i", Name: "pp-node-01"},
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "s3cr3t"},
|
||||
UUID: rnd.UUID(),
|
||||
ClusterCIDR: "203.0.113.0/24",
|
||||
Node: cluster.Node{ClientID: "cs5gfen1bgxz7s9i", Name: "pp-node-01"},
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "s3cr3t"},
|
||||
})
|
||||
case "/api/v1/oauth/token":
|
||||
// Expect Basic for the returned creds
|
||||
|
@@ -28,7 +28,13 @@ func TestMain(m *testing.M) {
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
event.AuditLog = log
|
||||
|
||||
c := config.NewTestConfig("commands")
|
||||
tempDir, err := os.MkdirTemp("", "commands-test")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
c := config.NewMinimalTestConfigWithDb("commands", tempDir)
|
||||
get.SetConfig(c)
|
||||
|
||||
// Keep DB connection open for the duration of this package's tests to
|
||||
@@ -91,7 +97,6 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
||||
|
||||
// Ensure DB connection is open for each command run (some commands call Shutdown).
|
||||
if c := get.Config(); c != nil {
|
||||
_ = c.Init() // safe to call; re-opens DB if needed
|
||||
c.RegisterDb() // (re)register provider
|
||||
}
|
||||
|
||||
@@ -104,5 +109,11 @@ func RunWithTestContext(cmd *cli.Command, args []string) (output string, err err
|
||||
err = cmd.Run(ctx, args...)
|
||||
})
|
||||
|
||||
// Re-open the database after the command completed so follow-up checks
|
||||
// (potentially issued by the test itself) have an active connection.
|
||||
if c := get.Config(); c != nil {
|
||||
c.RegisterDb()
|
||||
}
|
||||
|
||||
return output, err
|
||||
}
|
||||
|
@@ -81,9 +81,10 @@ func TestDownloadImpl_FileMethod_AutoSkipsRemux(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
|
||||
// Ensure DB is initialized and registered (bypassing CLI InitConfig)
|
||||
_ = conf.Init()
|
||||
conf.RegisterDb()
|
||||
|
||||
// Override yt-dlp after config init (config may set dl.YtDlpBin)
|
||||
dl.YtDlpBin = fake
|
||||
t.Logf("using yt-dlp binary: %s", dl.YtDlpBin)
|
||||
@@ -125,7 +126,6 @@ func TestDownloadImpl_FileMethod_Skip_NoRemux(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
_ = conf.Init()
|
||||
conf.RegisterDb()
|
||||
dl.YtDlpBin = fake
|
||||
|
||||
@@ -196,8 +196,9 @@ func TestDownloadImpl_FileMethod_Always_RemuxFails(t *testing.T) {
|
||||
if conf == nil {
|
||||
t.Fatalf("missing test config")
|
||||
}
|
||||
_ = conf.Init()
|
||||
|
||||
conf.RegisterDb()
|
||||
|
||||
dl.YtDlpBin = fake
|
||||
|
||||
err := runDownload(conf, DownloadOpts{
|
||||
|
8
internal/commands/flags.go
Normal file
8
internal/commands/flags.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package commands
|
||||
|
||||
import "github.com/urfave/cli/v2"
|
||||
|
||||
// JsonFlag returns the shared CLI flag definition for JSON output across commands.
|
||||
func JsonFlag() *cli.BoolFlag {
|
||||
return &cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"}
|
||||
}
|
173
internal/commands/jwt_helpers.go
Normal file
173
internal/commands/jwt_helpers.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package commands
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/auth/acl"
|
||||
clusterjwt "github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
reg "github.com/photoprism/photoprism/internal/service/cluster/registry"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
var allowedJWTScope = func() map[string]struct{} {
|
||||
out := make(map[string]struct{}, len(acl.ResourceNames))
|
||||
for _, res := range acl.ResourceNames {
|
||||
out[res.String()] = struct{}{}
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
// requirePortal returns a CLI error when the active configuration is not a portal node.
|
||||
func requirePortal(conf *config.Config) error {
|
||||
if conf == nil || !conf.IsPortal() {
|
||||
return cli.Exit(errors.New("command requires a Portal node"), 2)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveNode finds a node by UUID, client ID, or DNS label using the portal registry.
|
||||
func resolveNode(conf *config.Config, identifier string) (*reg.Node, error) {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := strings.TrimSpace(identifier)
|
||||
if key == "" {
|
||||
return nil, cli.Exit(errors.New("node identifier required"), 2)
|
||||
}
|
||||
|
||||
registry, err := reg.NewClientRegistryWithConfig(conf)
|
||||
if err != nil {
|
||||
return nil, cli.Exit(err, 1)
|
||||
}
|
||||
|
||||
if node, err := registry.FindByNodeUUID(key); err == nil && node != nil {
|
||||
return node, nil
|
||||
}
|
||||
if node, err := registry.FindByClientID(key); err == nil && node != nil {
|
||||
return node, nil
|
||||
}
|
||||
|
||||
name := clean.DNSLabel(key)
|
||||
if name == "" {
|
||||
return nil, cli.Exit(errors.New("invalid node identifier"), 2)
|
||||
}
|
||||
|
||||
node, err := registry.FindByName(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, reg.ErrNotFound) {
|
||||
return nil, cli.Exit(fmt.Errorf("node %q not found", identifier), 3)
|
||||
}
|
||||
return nil, cli.Exit(err, 1)
|
||||
}
|
||||
return node, nil
|
||||
}
|
||||
|
||||
// decodeJWTClaims decodes the compact JWT and returns header and claims without verifying the signature.
|
||||
func decodeJWTClaims(token string) (map[string]any, *clusterjwt.Claims, error) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, nil, errors.New("jwt: token must contain three segments")
|
||||
}
|
||||
|
||||
decode := func(segment string) ([]byte, error) {
|
||||
return base64.RawURLEncoding.DecodeString(segment)
|
||||
}
|
||||
|
||||
headerBytes, err := decode(parts[0])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
payloadBytes, err := decode(parts[1])
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var header map[string]any
|
||||
if err := json.Unmarshal(headerBytes, &header); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
claims := &clusterjwt.Claims{}
|
||||
if err := json.Unmarshal(payloadBytes, claims); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return header, claims, nil
|
||||
}
|
||||
|
||||
// verifyPortalToken verifies a JWT using the portal's in-memory key manager.
|
||||
func verifyPortalToken(conf *config.Config, token string, expected clusterjwt.ExpectedClaims) (*clusterjwt.Claims, error) {
|
||||
if err := requirePortal(conf); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager := get.JWTManager()
|
||||
if manager == nil {
|
||||
return nil, cli.Exit(errors.New("jwt issuer not available"), 1)
|
||||
}
|
||||
|
||||
jwks := manager.JWKS()
|
||||
if jwks == nil || len(jwks.Keys) == 0 {
|
||||
return nil, cli.Exit(errors.New("jwks key set is empty"), 1)
|
||||
}
|
||||
|
||||
leeway := time.Duration(conf.JWTLeeway()) * time.Second
|
||||
if leeway <= 0 {
|
||||
leeway = 60 * time.Second
|
||||
}
|
||||
|
||||
claims, err := clusterjwt.VerifyTokenWithKeys(token, expected, jwks.Keys, leeway)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// normalizeScopes trims and de-duplicates scope values, falling back to defaults when necessary.
|
||||
func normalizeScopes(values []string, defaults ...string) ([]string, error) {
|
||||
src := values
|
||||
if len(src) == 0 {
|
||||
src = defaults
|
||||
}
|
||||
out := make([]string, 0, len(src))
|
||||
seen := make(map[string]struct{}, len(src))
|
||||
for _, raw := range src {
|
||||
for _, parsed := range clean.Scopes(raw) {
|
||||
scope := clean.Scope(parsed)
|
||||
if scope == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[scope]; exists {
|
||||
continue
|
||||
}
|
||||
if _, ok := allowedJWTScope[scope]; !ok {
|
||||
return nil, cli.Exit(fmt.Errorf("unsupported scope %q", scope), 2)
|
||||
}
|
||||
seen[scope] = struct{}{}
|
||||
out = append(out, scope)
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil, cli.Exit(errors.New("at least one scope is required"), 2)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// printJSON pretty-prints the payload as JSON.
|
||||
func printJSON(payload any) error {
|
||||
data, err := json.MarshalIndent(payload, "", " ")
|
||||
if err != nil {
|
||||
return cli.Exit(err, 1)
|
||||
}
|
||||
fmt.Printf("%s\n", data)
|
||||
return nil
|
||||
}
|
@@ -16,7 +16,7 @@ var ShowCommandsCommand = &cli.Command{
|
||||
Name: "commands",
|
||||
Usage: "Displays a structured catalog of CLI commands",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"},
|
||||
JsonFlag(),
|
||||
&cli.BoolFlag{Name: "all", Usage: "include hidden commands and flags"},
|
||||
&cli.BoolFlag{Name: "short", Usage: "omit flags in Markdown output"},
|
||||
&cli.IntFlag{Name: "base-heading", Value: 2, Usage: "base Markdown heading level"},
|
||||
|
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -43,9 +44,9 @@ func statusAction(ctx *cli.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
||||
endpointUrl := buildStatusEndpoint(conf)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
req, err := http.NewRequest(http.MethodGet, endpointUrl, nil)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -53,12 +54,12 @@ func statusAction(ctx *cli.Context) error {
|
||||
|
||||
var status string
|
||||
|
||||
if resp, err := client.Do(req); err != nil {
|
||||
if resp, reqErr := client.Do(req); reqErr != nil {
|
||||
return fmt.Errorf("cannot connect to %s:%d", conf.HttpHost(), conf.HttpPort())
|
||||
} else if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("server running at %s:%d, bad status %d\n", conf.HttpHost(), conf.HttpPort(), resp.StatusCode)
|
||||
} else if body, err := io.ReadAll(resp.Body); err != nil {
|
||||
return err
|
||||
} else if body, readErr := io.ReadAll(resp.Body); readErr != nil {
|
||||
return readErr
|
||||
} else {
|
||||
status = string(body)
|
||||
}
|
||||
@@ -73,3 +74,21 @@ func statusAction(ctx *cli.Context) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildStatusEndpoint returns the status endpoint URL, preferring the public
|
||||
// SiteUrl (which carries the correct scheme) and falling back to the local
|
||||
// HTTP host/port. When a Unix socket is configured, an http+unix style URL is
|
||||
// used so the custom transport can dial the socket.
|
||||
func buildStatusEndpoint(conf *config.Config) string {
|
||||
if socket := conf.HttpSocket(); socket != nil {
|
||||
return fmt.Sprintf("%s://%s/api/v1/status", socket.Scheme, strings.TrimPrefix(socket.Path, "/"))
|
||||
}
|
||||
|
||||
siteUrl := strings.TrimRight(conf.SiteUrl(), "/")
|
||||
|
||||
if siteUrl != "" {
|
||||
return siteUrl + "/api/v1/status"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("http://%s:%d/api/v1/status", conf.HttpHost(), conf.HttpPort())
|
||||
}
|
||||
|
@@ -82,7 +82,7 @@ func TestConfig_ClientShareConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_ClientUser(t *testing.T) {
|
||||
c := NewTestConfig("config")
|
||||
c := NewMinimalTestConfigWithDb("client-user", t.TempDir())
|
||||
c.SetAuthMode(AuthModePasswd)
|
||||
|
||||
assert.Equal(t, AuthModePasswd, c.AuthMode())
|
||||
@@ -112,7 +112,7 @@ func TestConfig_ClientUser(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestConfig_ClientRoleConfig(t *testing.T) {
|
||||
c := NewTestConfig("config")
|
||||
c := NewMinimalTestConfigWithDb("client-role", t.TempDir())
|
||||
c.SetAuthMode(AuthModePasswd)
|
||||
|
||||
assert.Equal(t, AuthModePasswd, c.AuthMode())
|
||||
|
@@ -2,6 +2,8 @@ package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
urlpkg "net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -11,6 +13,7 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
)
|
||||
@@ -222,7 +225,36 @@ func (c *Config) SetJWKSUrl(url string) {
|
||||
if c == nil || c.options == nil {
|
||||
return
|
||||
}
|
||||
c.options.JWKSUrl = strings.TrimSpace(url)
|
||||
|
||||
trimmed := strings.TrimSpace(url)
|
||||
if trimmed == "" {
|
||||
c.options.JWKSUrl = ""
|
||||
return
|
||||
}
|
||||
|
||||
parsed, err := urlpkg.Parse(trimmed)
|
||||
if err != nil || parsed == nil || parsed.Scheme == "" || parsed.Host == "" {
|
||||
log.Warnf("config: ignoring JWKS URL %q (%v)", trimmed, err)
|
||||
return
|
||||
}
|
||||
|
||||
scheme := strings.ToLower(parsed.Scheme)
|
||||
host := parsed.Hostname()
|
||||
|
||||
switch scheme {
|
||||
case "https":
|
||||
// Always allowed.
|
||||
case "http":
|
||||
if !isLoopbackHost(host) {
|
||||
log.Warnf("config: rejecting JWKS URL %q (http only allowed for localhost/loopback)", trimmed)
|
||||
return
|
||||
}
|
||||
default:
|
||||
log.Warnf("config: rejecting JWKS URL %q (unsupported scheme)", trimmed)
|
||||
return
|
||||
}
|
||||
|
||||
c.options.JWKSUrl = trimmed
|
||||
}
|
||||
|
||||
// JWKSCacheTTL returns the JWKS cache lifetime in seconds (default 300, max 3600).
|
||||
@@ -247,6 +279,18 @@ func (c *Config) JWTLeeway() int {
|
||||
return c.options.JWTLeeway
|
||||
}
|
||||
|
||||
// JWTAllowedScopes returns an optional allow-list of accepted JWT scopes.
|
||||
func (c *Config) JWTAllowedScopes() list.Attr {
|
||||
if s := strings.TrimSpace(c.options.JWTScope); s != "" {
|
||||
parsed := list.ParseAttr(strings.ToLower(s))
|
||||
if len(parsed) > 0 {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
|
||||
return list.ParseAttr("cluster vision metrics")
|
||||
}
|
||||
|
||||
// AdvertiseUrl returns the advertised node URL for intra-cluster calls (scheme://host[:port]).
|
||||
func (c *Config) AdvertiseUrl() string {
|
||||
if c.options.AdvertiseUrl != "" {
|
||||
@@ -261,6 +305,23 @@ func (c *Config) AdvertiseUrl() string {
|
||||
return c.SiteUrl()
|
||||
}
|
||||
|
||||
// isLoopbackHost returns true when host represents localhost or a loopback IP.
func isLoopbackHost(host string) bool {
	switch {
	case host == "":
		return false
	case strings.EqualFold(host, "localhost"):
		return true
	}

	ip := net.ParseIP(host)

	return ip != nil && ip.IsLoopback()
}
|
||||
|
||||
// SaveClusterUUID writes or updates the ClusterUUID key in options.yml without
|
||||
// touching unrelated keys. Creates the file and directories if needed.
|
||||
func (c *Config) SaveClusterUUID(uuid string) error {
|
||||
|
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/photoprism/photoprism/internal/service/cluster"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
@@ -73,15 +74,84 @@ func TestConfig_Cluster(t *testing.T) {
|
||||
c.Options().NodeRole = ""
|
||||
})
|
||||
t.Run("JWKSUrlSetter", func(t *testing.T) {
|
||||
const existing = "https://existing.example/.well-known/jwks.json"
|
||||
tests := []struct {
|
||||
name string
|
||||
prev string
|
||||
input string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "TrimHTTPS",
|
||||
prev: "",
|
||||
input: " https://portal.example/.well-known/jwks.json ",
|
||||
expect: "https://portal.example/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "CaseInsensitiveScheme",
|
||||
prev: "",
|
||||
input: "HTTPS://portal.example/.well-known/jwks.json",
|
||||
expect: "HTTPS://portal.example/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLocalhost",
|
||||
prev: "",
|
||||
input: "http://localhost:2342/.well-known/jwks.json",
|
||||
expect: "http://localhost:2342/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLoopbackIPv4",
|
||||
prev: "",
|
||||
input: "http://127.0.0.1/.well-known/jwks.json",
|
||||
expect: "http://127.0.0.1/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "AllowHTTPOnLoopbackIPv6",
|
||||
prev: "",
|
||||
input: "http://[::1]/.well-known/jwks.json",
|
||||
expect: "http://[::1]/.well-known/jwks.json",
|
||||
},
|
||||
{
|
||||
name: "RejectHTTPNonLoopback",
|
||||
prev: existing,
|
||||
input: "http://portal.example/.well-known/jwks.json",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "RejectUnsupportedScheme",
|
||||
prev: existing,
|
||||
input: "ftp://portal.example/.well-known/jwks.json",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "RejectMalformedURL",
|
||||
prev: existing,
|
||||
input: "://not-a-url",
|
||||
expect: existing,
|
||||
},
|
||||
{
|
||||
name: "ClearValue",
|
||||
prev: existing,
|
||||
input: "",
|
||||
expect: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
c.options.JWKSUrl = tc.prev
|
||||
c.SetJWKSUrl(tc.input)
|
||||
assert.Equal(t, tc.expect, c.JWKSUrl())
|
||||
})
|
||||
}
|
||||
})
|
||||
t.Run("JWTAllowedScopes", func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
c.options.JWKSUrl = ""
|
||||
assert.Equal(t, "", c.JWKSUrl())
|
||||
|
||||
c.SetJWKSUrl(" https://portal.example/.well-known/jwks.json ")
|
||||
assert.Equal(t, "https://portal.example/.well-known/jwks.json", c.JWKSUrl())
|
||||
|
||||
c.SetJWKSUrl("")
|
||||
assert.Equal(t, "", c.JWKSUrl())
|
||||
c.options.JWTScope = "cluster vision"
|
||||
assert.Equal(t, list.ParseAttr("cluster vision"), c.JWTAllowedScopes())
|
||||
c.options.JWTScope = ""
|
||||
assert.Equal(t, list.ParseAttr("cluster vision metrics"), c.JWTAllowedScopes())
|
||||
})
|
||||
t.Run("Paths", func(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
@@ -342,12 +342,18 @@ func (c *Config) SetDbOptions() {
|
||||
case Postgres:
|
||||
// Ignore for now.
|
||||
case SQLite3:
|
||||
// Not required as unicode is default.
|
||||
// Not required as Unicode is default.
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterDb sets the database options and connection provider.
|
||||
// RegisterDb opens a database connection if needed,
|
||||
// sets the database options and connection provider.
|
||||
func (c *Config) RegisterDb() {
|
||||
if err := c.connectDb(); err != nil {
|
||||
log.Errorf("config: %s (register db)")
|
||||
return
|
||||
}
|
||||
|
||||
c.SetDbOptions()
|
||||
entity.SetDbProvider(c)
|
||||
}
|
||||
@@ -456,6 +462,11 @@ func (c *Config) connectDb() error {
|
||||
mutex.Db.Lock()
|
||||
defer mutex.Db.Unlock()
|
||||
|
||||
// Database connection already exists.
|
||||
if c.db != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get database driver and data source name.
|
||||
dbDriver := c.DatabaseDriver()
|
||||
dbDsn := c.DatabaseDSN()
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
// ApplyScope updates the current settings based on the authorization scope passed.
|
||||
func (s *Settings) ApplyScope(scope string) *Settings {
|
||||
if scope == "" || scope == list.All {
|
||||
if scope == "" || scope == list.Any {
|
||||
return s
|
||||
}
|
||||
|
||||
|
@@ -731,6 +731,11 @@ var Flags = CliFlags{
|
||||
Value: 300,
|
||||
EnvVars: EnvVars("JWKS_CACHE_TTL"),
|
||||
}}, {
|
||||
Flag: &cli.StringFlag{
|
||||
Name: "jwt-scope",
|
||||
Usage: "allowed JWT `SCOPES` (space separated). Leave empty to accept defaults",
|
||||
EnvVars: EnvVars("JWT_SCOPE"),
|
||||
}}, {
|
||||
Flag: &cli.IntFlag{
|
||||
Name: "jwt-leeway",
|
||||
Usage: "JWT clock skew allowance in `SECONDS` (default 60, max 300)",
|
||||
|
@@ -154,6 +154,7 @@ type Options struct {
|
||||
NodeClientSecret string `yaml:"NodeClientSecret" json:"-" flag:"node-client-secret"`
|
||||
JWKSUrl string `yaml:"JWKSUrl" json:"-" flag:"jwks-url"`
|
||||
JWKSCacheTTL int `yaml:"JWKSCacheTTL" json:"-" flag:"jwks-cache-ttl"`
|
||||
JWTScope string `yaml:"JWTScope" json:"-" flag:"jwt-scope"`
|
||||
JWTLeeway int `yaml:"JWTLeeway" json:"-" flag:"jwt-leeway"`
|
||||
AdvertiseUrl string `yaml:"AdvertiseUrl" json:"-" flag:"advertise-url"`
|
||||
HttpsProxy string `yaml:"HttpsProxy" json:"HttpsProxy" flag:"https-proxy"`
|
||||
|
@@ -190,6 +190,7 @@ func (c *Config) Report() (rows [][]string, cols []string) {
|
||||
{"node-client-secret", fmt.Sprintf("%s", strings.Repeat("*", utf8.RuneCountInString(c.NodeClientSecret())))},
|
||||
{"jwks-url", c.JWKSUrl()},
|
||||
{"jwks-cache-ttl", fmt.Sprintf("%d", c.JWKSCacheTTL())},
|
||||
{"jwt-scope", c.JWTAllowedScopes().String()},
|
||||
{"jwt-leeway", fmt.Sprintf("%d", c.JWTLeeway())},
|
||||
{"advertise-url", c.AdvertiseUrl()},
|
||||
|
||||
|
@@ -45,13 +45,7 @@ func testDataPath(assetsPath string) string {
|
||||
var PkgNameRegexp = regexp.MustCompile("[^a-zA-Z\\-_]+")
|
||||
|
||||
// NewTestOptions returns valid config options for tests.
|
||||
func NewTestOptions(pkg string) *Options {
|
||||
// Find assets path.
|
||||
assetsPath := os.Getenv("PHOTOPRISM_ASSETS_PATH")
|
||||
if assetsPath == "" {
|
||||
fs.Abs("../../assets")
|
||||
}
|
||||
|
||||
func NewTestOptions(dbName string) *Options {
|
||||
// Find storage path.
|
||||
storagePath := os.Getenv("PHOTOPRISM_STORAGE_PATH")
|
||||
if storagePath == "" {
|
||||
@@ -60,7 +54,43 @@ func NewTestOptions(pkg string) *Options {
|
||||
|
||||
dataPath := filepath.Join(storagePath, fs.TestdataDir)
|
||||
|
||||
pkg = PkgNameRegexp.ReplaceAllString(pkg, "")
|
||||
return NewTestOptionsForPath(dbName, dataPath)
|
||||
}
|
||||
|
||||
// NewTestOptionsForPath returns new test Options using the specified data path as storage.
|
||||
func NewTestOptionsForPath(dbName, dataPath string) *Options {
|
||||
// Default to storage/testdata is no path was specified.
|
||||
if dataPath == "" {
|
||||
storagePath := os.Getenv("PHOTOPRISM_STORAGE_PATH")
|
||||
|
||||
if storagePath == "" {
|
||||
storagePath = fs.Abs("../../storage")
|
||||
}
|
||||
|
||||
dataPath = filepath.Join(storagePath, fs.TestdataDir)
|
||||
}
|
||||
|
||||
dataPath = fs.Abs(dataPath)
|
||||
|
||||
if err := fs.MkdirAll(dataPath); err != nil {
|
||||
log.Errorf("config: %s (create test data path)", err)
|
||||
return &Options{}
|
||||
}
|
||||
|
||||
configPath := filepath.Join(dataPath, "config")
|
||||
|
||||
if err := fs.MkdirAll(configPath); err != nil {
|
||||
log.Errorf("config: %s (create test config path)", err)
|
||||
return &Options{}
|
||||
}
|
||||
|
||||
// Find assets path.
|
||||
assetsPath := os.Getenv("PHOTOPRISM_ASSETS_PATH")
|
||||
if assetsPath == "" {
|
||||
fs.Abs("../../assets")
|
||||
}
|
||||
|
||||
dbName = PkgNameRegexp.ReplaceAllString(dbName, "")
|
||||
driver := os.Getenv("PHOTOPRISM_TEST_DRIVER")
|
||||
dsn := os.Getenv("PHOTOPRISM_TEST_DSN")
|
||||
|
||||
@@ -75,16 +105,16 @@ func NewTestOptions(pkg string) *Options {
|
||||
|
||||
// Set default database DSN.
|
||||
if driver == SQLite3 {
|
||||
if dsn == "" && pkg != "" {
|
||||
if dsn = fmt.Sprintf(".%s.db", clean.TypeLower(pkg)); !fs.FileExists(dsn) {
|
||||
log.Debugf("sqlite: test database %s does not already exist", clean.Log(dsn))
|
||||
if dsn == "" && dbName != "" {
|
||||
if dsn = fmt.Sprintf(".%s.db", clean.TypeLower(dbName)); !fs.FileExists(dsn) {
|
||||
log.Tracef("sqlite: test database %s does not already exist", clean.Log(dsn))
|
||||
} else if err := os.Remove(dsn); err != nil {
|
||||
log.Errorf("sqlite: failed to remove existing test database %s (%s)", clean.Log(dsn), err)
|
||||
}
|
||||
} else if dsn == "" || dsn == SQLiteTestDB {
|
||||
dsn = SQLiteTestDB
|
||||
if !fs.FileExists(dsn) {
|
||||
log.Debugf("sqlite: test database %s does not already exist", clean.Log(dsn))
|
||||
log.Tracef("sqlite: test database %s does not already exist", clean.Log(dsn))
|
||||
} else if err := os.Remove(dsn); err != nil {
|
||||
log.Errorf("sqlite: failed to remove existing test database %s (%s)", clean.Log(dsn), err)
|
||||
}
|
||||
@@ -92,7 +122,7 @@ func NewTestOptions(pkg string) *Options {
|
||||
}
|
||||
|
||||
// Test config options.
|
||||
c := &Options{
|
||||
opts := &Options{
|
||||
Name: "PhotoPrism",
|
||||
Version: "0.0.0",
|
||||
Copyright: "(c) 2018-2025 PhotoPrism UG. All rights reserved.",
|
||||
@@ -111,12 +141,14 @@ func NewTestOptions(pkg string) *Options {
|
||||
IndexSchedule: DefaultIndexSchedule,
|
||||
AutoImport: 7200,
|
||||
StoragePath: dataPath,
|
||||
CachePath: dataPath + "/cache",
|
||||
OriginalsPath: dataPath + "/originals",
|
||||
ImportPath: dataPath + "/import",
|
||||
ConfigPath: dataPath + "/config",
|
||||
SidecarPath: dataPath + "/sidecar",
|
||||
TempPath: dataPath + "/temp",
|
||||
CachePath: filepath.Join(dataPath, "cache"),
|
||||
OriginalsPath: filepath.Join(dataPath, "originals"),
|
||||
ImportPath: filepath.Join(dataPath, "import"),
|
||||
ConfigPath: configPath,
|
||||
DefaultsYaml: filepath.Join(configPath, "defaults.yml"),
|
||||
OptionsYaml: filepath.Join(configPath, "options.yml"),
|
||||
SidecarPath: filepath.Join(dataPath, "sidecar"),
|
||||
TempPath: filepath.Join(dataPath, "temp"),
|
||||
BackupRetain: DefaultBackupRetain,
|
||||
BackupSchedule: DefaultBackupSchedule,
|
||||
DatabaseDriver: driver,
|
||||
@@ -128,7 +160,7 @@ func NewTestOptions(pkg string) *Options {
|
||||
DetectNSFW: true,
|
||||
}
|
||||
|
||||
return c
|
||||
return opts
|
||||
}
|
||||
|
||||
// NewTestOptionsError returns invalid config options for tests.
|
||||
@@ -162,11 +194,94 @@ func TestConfig() *Config {
|
||||
return testConfig
|
||||
}
|
||||
|
||||
// NewTestConfig returns a valid test config.
|
||||
// NewMinimalTestConfig creates a lightweight test Config (no DB, minimal filesystem).
//
// It delegates to NewIsolatedTestConfig with an empty database name and
// createDirs=false, so no database is opened and no storage directories
// are pre-created.
//
// Not suitable for tests requiring a database or pre-created storage directories.
func NewMinimalTestConfig(dataPath string) *Config {
	return NewIsolatedTestConfig("", dataPath, false)
}
|
||||
|
||||
var testDbCache []byte
|
||||
var testDbMutex sync.Mutex
|
||||
|
||||
// NewMinimalTestConfigWithDb creates a lightweight test Config (minimal filesystem).
|
||||
//
|
||||
// Creates an isolated SQLite DB (cached after first run) without seeding media fixtures.
|
||||
func NewMinimalTestConfigWithDb(dbName, dataPath string) *Config {
|
||||
c := NewIsolatedTestConfig(dbName, dataPath, true)
|
||||
|
||||
cachedDb := false
|
||||
|
||||
// Try to restore test db from cache.
|
||||
if len(testDbCache) > 0 && c.DatabaseDriver() == SQLite3 && !fs.FileExists(c.DatabaseDSN()) {
|
||||
if err := os.WriteFile(c.DatabaseDSN(), testDbCache, fs.ModeFile); err != nil {
|
||||
log.Warnf("config: %s (restore test database)", err)
|
||||
} else {
|
||||
cachedDb = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.Init(); err != nil {
|
||||
log.Fatalf("config: %s (init)", err.Error())
|
||||
}
|
||||
|
||||
c.RegisterDb()
|
||||
|
||||
if cachedDb {
|
||||
return c
|
||||
}
|
||||
|
||||
c.InitTestDb()
|
||||
|
||||
if testDbCache == nil && c.DatabaseDriver() == SQLite3 && fs.FileExistsNotEmpty(c.DatabaseDSN()) {
|
||||
testDbMutex.Lock()
|
||||
defer testDbMutex.Unlock()
|
||||
|
||||
if testDbCache != nil {
|
||||
return c
|
||||
}
|
||||
|
||||
if testDb, readErr := os.ReadFile(c.DatabaseDSN()); readErr != nil {
|
||||
log.Warnf("config: could not cache test database (%s)", readErr)
|
||||
} else {
|
||||
testDbCache = testDb
|
||||
}
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewIsolatedTestConfig constructs a lightweight Config backed by the provided config path.
|
||||
//
|
||||
// It avoids running migrations or loading test fixtures, making it useful for unit tests that
|
||||
// only need basic access to config options (for example, JWT helpers). The caller should provide
|
||||
// an isolated directory (e.g. via testing.T.TempDir) so temporary files are cleaned up automatically.
|
||||
func NewIsolatedTestConfig(dbName, dataPath string, createDirs bool) *Config {
|
||||
if dataPath == "" {
|
||||
dataPath = filepath.Join(os.TempDir(), "photoprism-test-"+rnd.Base36(6))
|
||||
}
|
||||
|
||||
opts := NewTestOptionsForPath(dbName, dataPath)
|
||||
|
||||
c := &Config{
|
||||
options: opts,
|
||||
token: rnd.Base36(8),
|
||||
}
|
||||
|
||||
if !createDirs {
|
||||
return c
|
||||
}
|
||||
|
||||
if err := c.CreateDirectories(); err != nil {
|
||||
log.Errorf("config: %s (create test directories)", err)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// NewTestConfig initializes test data so required directories exist before tests run.
|
||||
// See AGENTS.md (Test Data & Fixtures) and specs/dev/backend-testing.md for guidance.
|
||||
func NewTestConfig(pkg string) *Config {
|
||||
func NewTestConfig(dbName string) *Config {
|
||||
defer log.Debug(capture.Time(time.Now(), "config: new test config created"))
|
||||
|
||||
testConfigMutex.Lock()
|
||||
@@ -174,7 +289,7 @@ func NewTestConfig(pkg string) *Config {
|
||||
|
||||
c := &Config{
|
||||
cliCtx: CliTestContext(),
|
||||
options: NewTestOptions(pkg),
|
||||
options: NewTestOptions(dbName),
|
||||
token: rnd.Base36(8),
|
||||
}
|
||||
|
||||
|
@@ -17,7 +17,6 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/authn"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/i18n"
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
"github.com/photoprism/photoprism/pkg/service/http/header"
|
||||
"github.com/photoprism/photoprism/pkg/time/unix"
|
||||
@@ -492,40 +491,7 @@ func (m *Session) Scope() string {
|
||||
|
||||
// ValidateScope checks if the scope does not exclude access to specified resource.
|
||||
func (m *Session) ValidateScope(resource acl.Resource, perms acl.Permissions) bool {
|
||||
// Get scope string.
|
||||
scope := m.Scope()
|
||||
|
||||
// Skip detailed check and allow all if scope is "*".
|
||||
if scope == list.All {
|
||||
return true
|
||||
}
|
||||
|
||||
// Skip resource check if scope includes all read operations.
|
||||
if scope == acl.ScopeRead.String() {
|
||||
return !acl.GrantScopeRead.DenyAny(perms)
|
||||
}
|
||||
|
||||
// Parse scope to check for resources and permissions.
|
||||
attr := list.ParseAttr(scope)
|
||||
|
||||
// Check if resource is within scope.
|
||||
if granted := attr.Contains(resource.String()); !granted {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if permission is within scope.
|
||||
if len(perms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check if scope is limited to read or write operations.
|
||||
if a := attr.Find(acl.ScopeRead.String()); a.Value == list.True && acl.GrantScopeRead.DenyAny(perms) {
|
||||
return false
|
||||
} else if a = attr.Find(acl.ScopeWrite.String()); a.Value == list.True && acl.GrantScopeWrite.DenyAny(perms) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return acl.ScopePermits(m.AuthScope, resource, perms)
|
||||
}
|
||||
|
||||
// InsufficientScope checks if the scope does not include access to specified resource.
|
||||
|
@@ -227,13 +227,16 @@ func TestDownloadPlaylistEntry(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the same file but with the direct link
|
||||
url := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
||||
dlUrl := "https://soundcloud.com/mattheis/b1-mattheis-ben-m"
|
||||
|
||||
stderrBuf = &bytes.Buffer{}
|
||||
r, err = NewMetadata(context.Background(), url, Options{
|
||||
|
||||
r, err = NewMetadata(context.Background(), dlUrl, Options{
|
||||
StderrFn: func(cmd *exec.Cmd) io.Writer {
|
||||
return stderrBuf
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -5,14 +5,24 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
c := config.NewTestConfig("service")
|
||||
tempDir, err := os.MkdirTemp("", "internal-photoprism-get")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
c := config.NewMinimalTestConfigWithDb("test", tempDir)
|
||||
|
||||
SetConfig(c)
|
||||
defer c.CloseDb()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
// Remove temporary SQLite files after running the tests.
|
||||
fs.PurgeTestDbFiles(".", false)
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
@@ -1,30 +1,42 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/auth/jwt"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
var (
|
||||
onceJWTManager sync.Once
|
||||
onceJWTIssuer sync.Once
|
||||
onceJWTManager sync.Once
|
||||
onceJWTIssuer sync.Once
|
||||
onceJWTVerifier sync.Once
|
||||
)
|
||||
|
||||
// initJWTManager lazily initializes the shared portal key manager for JWT issuance.
|
||||
func initJWTManager() {
|
||||
conf := Config()
|
||||
if conf == nil || !conf.IsPortal() {
|
||||
if conf == nil {
|
||||
return
|
||||
} else if !conf.IsPortal() {
|
||||
return
|
||||
}
|
||||
|
||||
manager, err := jwt.NewManager(conf)
|
||||
|
||||
if err != nil {
|
||||
log.Warnf("jwt: manager init failed (%s)", clean.Error(err))
|
||||
return
|
||||
}
|
||||
if _, err := manager.EnsureActiveKey(); err != nil {
|
||||
|
||||
if _, err = manager.EnsureActiveKey(); err != nil {
|
||||
log.Warnf("jwt: ensure signing key failed (%s)", clean.Error(err))
|
||||
}
|
||||
|
||||
services.JWTManager = manager
|
||||
}
|
||||
|
||||
@@ -34,6 +46,7 @@ func JWTManager() *jwt.Manager {
|
||||
return services.JWTManager
|
||||
}
|
||||
|
||||
// initJWTIssuer lazily binds the shared issuer to the active portal key manager.
|
||||
func initJWTIssuer() {
|
||||
manager := JWTManager()
|
||||
if manager == nil {
|
||||
@@ -50,9 +63,98 @@ func JWTIssuer() *jwt.Issuer {
|
||||
|
||||
// JWTVerifier returns a verifier bound to the current config.
|
||||
func JWTVerifier() *jwt.Verifier {
|
||||
conf := Config()
|
||||
if conf == nil {
|
||||
return nil
|
||||
}
|
||||
return jwt.NewVerifier(conf)
|
||||
onceJWTVerifier.Do(initJWTVerifier)
|
||||
return services.JWTVerifier
|
||||
}
|
||||
|
||||
// VerifyJWT verifies a token using the shared verifier instance.
|
||||
func VerifyJWT(ctx context.Context, token string, expected jwt.ExpectedClaims) (*jwt.Claims, error) {
|
||||
verifier := JWTVerifier()
|
||||
if verifier == nil {
|
||||
return nil, errors.New("jwt: verifier not available")
|
||||
}
|
||||
return verifier.VerifyToken(ctx, token, expected)
|
||||
}
|
||||
|
||||
// initJWTVerifier lazily constructs the shared verifier for the current configuration.
|
||||
func initJWTVerifier() {
|
||||
if conf != nil {
|
||||
services.JWTVerifier = jwt.NewVerifier(conf)
|
||||
}
|
||||
}
|
||||
|
||||
// resetJWTVerifier clears the cached verifier so it can be rebuilt for a new configuration.
//
// The shared pointer is cleared before the sync.Once is replaced, so the next
// JWTVerifier call re-runs initJWTVerifier.
// NOTE(review): callers appear to serialize resets via SetConfig — confirm there
// is no concurrent use, since these two assignments are not atomic together.
func resetJWTVerifier() {
	services.JWTVerifier = nil
	onceJWTVerifier = sync.Once{}
}
|
||||
|
||||
// resetJWTIssuer clears the cached issuer so it can be recreated for a new configuration.
//
// Clearing the pointer and replacing the sync.Once makes the next JWTIssuer
// call re-run its lazy initialization.
func resetJWTIssuer() {
	services.JWTIssuer = nil
	onceJWTIssuer = sync.Once{}
}
|
||||
|
||||
// resetJWTManager clears the cached key manager so subsequent calls reload keys for the active configuration.
//
// Clearing the pointer and replacing the sync.Once makes the next JWTManager
// call re-run its lazy initialization.
func resetJWTManager() {
	services.JWTManager = nil
	onceJWTManager = sync.Once{}
}
|
||||
|
||||
// resetJWT clears all cached JWT helpers (verifier, issuer, and key manager)
// so they are rebuilt lazily for the next configuration.
func resetJWT() {
	resetJWTVerifier()
	resetJWTIssuer()
	resetJWTManager()
}
|
||||
|
||||
// IssuePortalJWT signs a token using the shared portal issuer with the provided claims.
|
||||
func IssuePortalJWT(spec jwt.ClaimsSpec) (string, error) {
|
||||
if issuer := JWTIssuer(); issuer == nil {
|
||||
return "", errors.New("jwt: issuer not available")
|
||||
} else {
|
||||
return issuer.Issue(spec)
|
||||
}
|
||||
}
|
||||
|
||||
// IssuePortalJWTForNode issues a portal-signed JWT targeting the specified node UUID.
|
||||
func IssuePortalJWTForNode(nodeUUID string, scopes []string, ttl time.Duration) (string, error) {
|
||||
if conf == nil {
|
||||
return "", errors.New("jwt: missing config")
|
||||
} else if !conf.IsPortal() {
|
||||
return "", errors.New("jwt: not supported on nodes")
|
||||
}
|
||||
|
||||
clusterUUID := strings.TrimSpace(conf.ClusterUUID())
|
||||
if clusterUUID == "" {
|
||||
return "", errors.New("jwt: cluster uuid not configured")
|
||||
}
|
||||
|
||||
nodeUUID = strings.TrimSpace(nodeUUID)
|
||||
if nodeUUID == "" {
|
||||
return "", errors.New("jwt: node uuid required")
|
||||
}
|
||||
if len(scopes) == 0 {
|
||||
return "", errors.New("jwt: at least one scope is required")
|
||||
}
|
||||
|
||||
normalized := make([]string, 0, len(scopes))
|
||||
for _, s := range scopes {
|
||||
if trimmed := strings.TrimSpace(s); trimmed != "" {
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return "", errors.New("jwt: at least one scope is required")
|
||||
}
|
||||
|
||||
spec := jwt.ClaimsSpec{
|
||||
Issuer: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Subject: fmt.Sprintf("portal:%s", clusterUUID),
|
||||
Audience: fmt.Sprintf("node:%s", nodeUUID),
|
||||
Scope: normalized,
|
||||
TTL: ttl,
|
||||
}
|
||||
|
||||
return IssuePortalJWT(spec)
|
||||
}
|
||||
|
39
internal/photoprism/get/jwt_test.go
Normal file
39
internal/photoprism/get/jwt_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package get
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
)
|
||||
|
||||
// TestJWTVerifierReuse ensures JWTVerifier returns the same cached instance on
// repeated calls instead of constructing a new verifier each time.
func TestJWTVerifierReuse(t *testing.T) {
	verifier1 := JWTVerifier()
	require.NotNil(t, verifier1)

	verifier2 := JWTVerifier()
	require.NotNil(t, verifier2)

	// Both calls must yield the identical pointer (sync.Once-backed cache).
	assert.Same(t, verifier1, verifier2)
}
|
||||
|
||||
// TestJWTVerifierResetOnConfigChange ensures that swapping the package config
// via SetConfig invalidates the cached verifier so a fresh one is created.
func TestJWTVerifierResetOnConfigChange(t *testing.T) {
	orig := Config()
	verifier1 := JWTVerifier()
	require.NotNil(t, verifier1)

	// Install a throwaway config backed by a per-test temp dir; the cleanup
	// restores the original config, closes the temporary database, and
	// re-registers the original one.
	tempConf := config.NewMinimalTestConfigWithDb("jwt-reset", t.TempDir())
	SetConfig(tempConf)
	t.Cleanup(func() {
		SetConfig(orig)
		tempConf.CloseDb()
		orig.RegisterDb()
	})

	verifier2 := JWTVerifier()
	require.NotNil(t, verifier2)

	// A different config must produce a different verifier instance.
	assert.NotSame(t, verifier1, verifier2)
}
|
@@ -57,13 +57,17 @@ var services struct {
|
||||
OIDC *oidc.Client
|
||||
JWTManager *clusterjwt.Manager
|
||||
JWTIssuer *clusterjwt.Issuer
|
||||
JWTVerifier *clusterjwt.Verifier
|
||||
}
|
||||
|
||||
func SetConfig(c *config.Config) {
|
||||
if c == nil {
|
||||
log.Panic("panic: argument is nil in get.SetConfig(c *config.Config)")
|
||||
return
|
||||
}
|
||||
|
||||
resetJWT()
|
||||
|
||||
conf = c
|
||||
|
||||
photoprism.SetConfig(c)
|
||||
@@ -72,6 +76,7 @@ func SetConfig(c *config.Config) {
|
||||
func Config() *config.Config {
|
||||
if conf == nil {
|
||||
log.Panic("panic: conf is nil in get.Config()")
|
||||
return nil
|
||||
}
|
||||
|
||||
return conf
|
||||
|
@@ -35,7 +35,7 @@ func TestWebDAVFileName_PathTraversalRejected(t *testing.T) {
|
||||
insideFile := filepath.Join(dir, "ok.txt")
|
||||
assert.NoError(t, fs.WriteString(insideFile, "ok"))
|
||||
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
conf := newWebDAVTestConfig(t)
|
||||
conf.Options().OriginalsPath = dir
|
||||
|
||||
r := gin.New()
|
||||
@@ -55,7 +55,7 @@ func TestWebDAVFileName_PathTraversalRejected(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebDAVFileName_MethodNotPut(t *testing.T) {
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
conf := newWebDAVTestConfig(t)
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVOriginals))
|
||||
req := &http.Request{Method: http.MethodGet}
|
||||
@@ -65,7 +65,7 @@ func TestWebDAVFileName_MethodNotPut(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWebDAVFileName_ImportBasePath(t *testing.T) {
|
||||
conf := config.NewTestConfig("server-webdav")
|
||||
conf := newWebDAVTestConfig(t)
|
||||
r := gin.New()
|
||||
grp := r.Group(conf.BaseUri(WebDAVImport))
|
||||
// create a real file under import
|
||||
@@ -88,3 +88,7 @@ func TestWebDAVSetFileMtime_FutureIgnored(t *testing.T) {
|
||||
after, _ := os.Stat(file)
|
||||
assert.Equal(t, before.ModTime().Unix(), after.ModTime().Unix())
|
||||
}
|
||||
|
||||
// newWebDAVTestConfig returns a minimal test config backed by a per-test
// temporary directory so WebDAV tests do not share filesystem state.
func newWebDAVTestConfig(t *testing.T) *config.Config {
	return config.NewMinimalTestConfig(t.TempDir())
}
|
||||
|
@@ -219,6 +219,10 @@ func persistRegistration(c *config.Config, r *cluster.RegisterResponse, wantRota
|
||||
updates["ClusterUUID"] = r.UUID
|
||||
}
|
||||
|
||||
if cidr := strings.TrimSpace(r.ClusterCIDR); cidr != "" {
|
||||
updates["ClusterCIDR"] = cidr
|
||||
}
|
||||
|
||||
// Always persist NodeClientID (client UID) from response for future OAuth token requests.
|
||||
if r.Node.ClientID != "" {
|
||||
updates["NodeClientID"] = r.Node.ClientID
|
||||
@@ -229,9 +233,9 @@ func persistRegistration(c *config.Config, r *cluster.RegisterResponse, wantRota
|
||||
updates["NodeClientSecret"] = r.Secrets.ClientSecret
|
||||
}
|
||||
|
||||
if url := strings.TrimSpace(r.JWKSUrl); url != "" {
|
||||
updates["JWKSUrl"] = url
|
||||
c.SetJWKSUrl(url)
|
||||
if jwksUrl := strings.TrimSpace(r.JWKSUrl); jwksUrl != "" {
|
||||
updates["JWKSUrl"] = jwksUrl
|
||||
c.SetJWKSUrl(jwksUrl)
|
||||
}
|
||||
|
||||
// Persist NodeUUID from portal response if provided and not set locally.
|
||||
|
@@ -19,15 +19,15 @@ import (
|
||||
)
|
||||
|
||||
func TestInitConfig_NoPortal_NoOp(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
c := config.NewTestConfig("bootstrap-np")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
// Default NodeRole() resolves to instance; no Portal configured.
|
||||
assert.Equal(t, cluster.RoleInstance, c.NodeRole())
|
||||
assert.NoError(t, InitConfig(c))
|
||||
}
|
||||
|
||||
func TestRegister_PersistSecretAndDB(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
// Fake Portal server.
|
||||
var jwksURL string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -37,10 +37,11 @@ func TestRegister_PersistSecretAndDB(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
resp := cluster.RegisterResponse{
|
||||
Node: cluster.Node{Name: "pp-node-01"},
|
||||
UUID: rnd.UUID(),
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "SECRET"},
|
||||
JWKSUrl: jwksURL,
|
||||
Node: cluster.Node{Name: "pp-node-01"},
|
||||
UUID: rnd.UUID(),
|
||||
ClusterCIDR: "192.0.2.0/24",
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "SECRET"},
|
||||
JWKSUrl: jwksURL,
|
||||
Database: cluster.RegisterDatabase{
|
||||
Driver: config.MySQL,
|
||||
Host: "db.local",
|
||||
@@ -62,7 +63,9 @@ func TestRegister_PersistSecretAndDB(t *testing.T) {
|
||||
jwksURL = srv.URL + "/.well-known/jwks.json"
|
||||
defer srv.Close()
|
||||
|
||||
c := config.NewTestConfig("bootstrap-reg")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap-reg", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
// Configure Portal.
|
||||
c.Options().PortalUrl = srv.URL
|
||||
c.Options().JoinToken = "t0k3n"
|
||||
@@ -82,10 +85,10 @@ func TestRegister_PersistSecretAndDB(t *testing.T) {
|
||||
assert.Contains(t, c.Options().DatabaseDSN, "@tcp(db.local:3306)/pp_db")
|
||||
assert.Equal(t, config.MySQL, c.Options().DatabaseDriver)
|
||||
assert.Equal(t, srv.URL+"/.well-known/jwks.json", c.JWKSUrl())
|
||||
assert.Equal(t, "192.0.2.0/24", c.ClusterCIDR())
|
||||
}
|
||||
|
||||
func TestThemeInstall_Missing(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
// Build a tiny zip in-memory with one file style.css
|
||||
var buf bytes.Buffer
|
||||
zw := zip.NewWriter(&buf)
|
||||
@@ -100,7 +103,7 @@ func TestThemeInstall_Missing(t *testing.T) {
|
||||
case "/api/v1/cluster/nodes/register":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Return NodeClientID + NodeClientSecret so bootstrap can request OAuth token
|
||||
_ = json.NewEncoder(w).Encode(cluster.RegisterResponse{UUID: rnd.UUID(), Node: cluster.Node{ClientID: "cs5gfen1bgxz7s9i", Name: "pp-node-01"}, Secrets: &cluster.RegisterSecrets{ClientSecret: "s3cr3t"}, JWKSUrl: jwksURL2})
|
||||
_ = json.NewEncoder(w).Encode(cluster.RegisterResponse{UUID: rnd.UUID(), ClusterCIDR: "198.51.100.0/24", Node: cluster.Node{ClientID: "cs5gfen1bgxz7s9i", Name: "pp-node-01"}, Secrets: &cluster.RegisterSecrets{ClientSecret: "s3cr3t"}, JWKSUrl: jwksURL2})
|
||||
case "/api/v1/oauth/token":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{"access_token": "tok", "token_type": "Bearer"})
|
||||
@@ -115,7 +118,9 @@ func TestThemeInstall_Missing(t *testing.T) {
|
||||
jwksURL2 = srv.URL + "/.well-known/jwks.json"
|
||||
defer srv.Close()
|
||||
|
||||
c := config.NewTestConfig("bootstrap-theme")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap-theme", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
// Point Portal.
|
||||
c.Options().PortalUrl = srv.URL
|
||||
c.Options().JoinToken = "t0k3n"
|
||||
@@ -137,7 +142,6 @@ func TestThemeInstall_Missing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestRegister_SQLite_NoDBPersist(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
// Portal responds with DB DSN, but local driver is SQLite → must not persist DB.
|
||||
var jwksURL3 string
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -146,10 +150,11 @@ func TestRegister_SQLite_NoDBPersist(t *testing.T) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
resp := cluster.RegisterResponse{
|
||||
Node: cluster.Node{Name: "pp-node-01"},
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "SECRET"},
|
||||
JWKSUrl: jwksURL3,
|
||||
Database: cluster.RegisterDatabase{Host: "db.local", Port: 3306, Name: "pp_db", User: "pp_user", Password: "pp_pw", DSN: "pp_user:pp_pw@tcp(db.local:3306)/pp_db?charset=utf8mb4&parseTime=true"},
|
||||
Node: cluster.Node{Name: "pp-node-01"},
|
||||
Secrets: &cluster.RegisterSecrets{ClientSecret: "SECRET"},
|
||||
ClusterCIDR: "203.0.113.0/24",
|
||||
JWKSUrl: jwksURL3,
|
||||
Database: cluster.RegisterDatabase{Host: "db.local", Port: 3306, Name: "pp_db", User: "pp_user", Password: "pp_pw", DSN: "pp_user:pp_pw@tcp(db.local:3306)/pp_db?charset=utf8mb4&parseTime=true"},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
default:
|
||||
@@ -159,7 +164,9 @@ func TestRegister_SQLite_NoDBPersist(t *testing.T) {
|
||||
jwksURL3 = srv.URL + "/.well-known/jwks.json"
|
||||
defer srv.Close()
|
||||
|
||||
c := config.NewTestConfig("bootstrap-sqlite")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap-sqlite", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
// SQLite driver by default; set Portal.
|
||||
c.Options().PortalUrl = srv.URL
|
||||
c.Options().JoinToken = "t0k3n"
|
||||
@@ -175,10 +182,10 @@ func TestRegister_SQLite_NoDBPersist(t *testing.T) {
|
||||
assert.Equal(t, config.SQLite3, c.DatabaseDriver())
|
||||
assert.Equal(t, origDSN, c.Options().DatabaseDSN)
|
||||
assert.Equal(t, srv.URL+"/.well-known/jwks.json", c.JWKSUrl())
|
||||
assert.Equal(t, "203.0.113.0/24", c.ClusterCIDR())
|
||||
}
|
||||
|
||||
func TestRegister_404_NoRetry(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
var hits int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/api/v1/cluster/nodes/register" {
|
||||
@@ -190,7 +197,9 @@ func TestRegister_404_NoRetry(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := config.NewTestConfig("bootstrap-404")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
c.Options().PortalUrl = srv.URL
|
||||
c.Options().JoinToken = "t0k3n"
|
||||
|
||||
@@ -201,7 +210,6 @@ func TestRegister_404_NoRetry(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestThemeInstall_SkipWhenAppJsExists(t *testing.T) {
|
||||
t.Setenv("PHOTOPRISM_STORAGE_PATH", t.TempDir())
|
||||
// Portal returns a valid zip, but theme dir already has app.js → skip.
|
||||
var served int
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -218,7 +226,9 @@ func TestThemeInstall_SkipWhenAppJsExists(t *testing.T) {
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
c := config.NewTestConfig("bootstrap-theme-skip")
|
||||
c := config.NewMinimalTestConfigWithDb("bootstrap", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
|
||||
c.Options().PortalUrl = srv.URL
|
||||
c.Options().JoinToken = "t0k3n"
|
||||
|
||||
|
@@ -14,9 +14,8 @@ import (
|
||||
|
||||
// Duplicate names: FindByName should return the most recently updated.
|
||||
func TestClientRegistry_DuplicateNamePrefersLatest(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-dupes")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-dupes", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
// Create two clients directly to simulate duplicates with same name.
|
||||
c1 := entity.NewClient().SetName("pp-dupe").SetRole("instance")
|
||||
@@ -40,9 +39,8 @@ func TestClientRegistry_DuplicateNamePrefersLatest(t *testing.T) {
|
||||
|
||||
// Role change path: Put should update ClientRole via mapping.
|
||||
func TestClientRegistry_RoleChange(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-role")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-role", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
n := &Node{Node: cluster.Node{Name: "pp-role", Role: "service"}}
|
||||
|
@@ -13,11 +13,8 @@ import (
|
||||
)
|
||||
|
||||
func TestClientRegistry_PutFindListRotate(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-client")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-client", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
if err := c.Init(); err != nil {
|
||||
t.Fatalf("init config: %v", err)
|
||||
}
|
||||
|
||||
r, err := NewClientRegistryWithConfig(c)
|
||||
assert.NoError(t, err)
|
||||
|
@@ -15,9 +15,8 @@ import (
|
||||
// rule prevents hijacking: the update applies to the UUID's row and does not move
|
||||
// the ClientID from its original node.
|
||||
func TestClientRegistry_ClientIDReuse_CannotHijackExistingUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-cid-hijack")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-cid-hijack", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
// Seed two independent nodes
|
||||
@@ -51,9 +50,8 @@ func TestClientRegistry_ClientIDReuse_CannotHijackExistingUUID(t *testing.T) {
|
||||
// migrates the row to the new UUID. This mirrors restore flows where a node's ClientID
|
||||
// is reused for a regenerated or reassigned UUID.
|
||||
func TestClientRegistry_ClientIDReuse_ChangesUUIDWhenTargetMissing(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-cid-move")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-cid-move", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
// Seed one node
|
||||
|
@@ -14,9 +14,8 @@ import (
|
||||
|
||||
// Basic FindByClientID flow with Put and DTO mapping.
|
||||
func TestClientRegistry_FindByClientID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-find-clientid")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-find-clientid", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
n := &Node{Node: cluster.Node{Name: "pp-find-client", Role: "instance", UUID: rnd.UUIDv7()}}
|
||||
@@ -34,9 +33,8 @@ func TestClientRegistry_FindByClientID(t *testing.T) {
|
||||
|
||||
// Simulate client ID changing after a restore: old row removed, new row created with same NodeUUID.
|
||||
func TestClientRegistry_ClientIDChangedAfterRestore(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-clientid-restore")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-clientid-restore", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
uuid := rnd.UUIDv7()
|
||||
// Original row
|
||||
@@ -71,9 +69,8 @@ func TestClientRegistry_ClientIDChangedAfterRestore(t *testing.T) {
|
||||
|
||||
// Names swapped between two nodes: UUIDs must remain authoritative.
|
||||
func TestClientRegistry_SwapNames_UUIDAuthoritative(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-swap-names")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-swap-names", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
a := &Node{Node: cluster.Node{UUID: rnd.UUIDv7(), Name: "pp-a", Role: "instance"}}
|
||||
@@ -117,9 +114,8 @@ func TestClientRegistry_SwapNames_UUIDAuthoritative(t *testing.T) {
|
||||
|
||||
// Ensure DB driver and fields round-trip through Put → toNode → BuildClusterNode.
|
||||
func TestClientRegistry_DBDriverAndFields(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-dbdriver")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-dbdriver", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
n := &Node{Node: cluster.Node{UUID: rnd.UUIDv7(), Name: "pp-db", Role: "instance"}}
|
||||
|
@@ -12,9 +12,8 @@ import (
|
||||
|
||||
// Ensure List() excludes clients that look like nodes by role but have no NodeUUID.
|
||||
func TestClientRegistry_ListExcludesNodeRoleWithoutUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-list-exclude-node-role")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-list-exclude-node-role", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
// Bad records: node-like roles but empty NodeUUID
|
||||
bad1 := entity.NewClient().SetName("pp-bad1").SetRole("instance")
|
||||
|
@@ -14,9 +14,8 @@ import (
|
||||
|
||||
// Rotating secret selects the latest row for a UUID and persists rotation timestamp and password.
|
||||
func TestClientRegistry_RotateSecretByUUID_LatestRow(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-rotate-latest")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-rotate-latest", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
uuid := rnd.UUIDv7()
|
||||
|
@@ -26,9 +26,8 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func TestClientRegistry_GetAndDelete(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-delete")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-delete", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
|
||||
@@ -68,9 +67,8 @@ func TestClientRegistry_GetAndDelete(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientRegistry_ListOrderByUpdatedAtDesc(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-order")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-order", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
|
||||
@@ -160,9 +158,8 @@ func TestNodeOptsForSession_AdminVsNonAdmin(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestToNode_Mapping(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-map")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-map", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
m := entity.NewClient().SetName("pp-map").SetRole("instance")
|
||||
m.NodeUUID = rnd.UUIDv7()
|
||||
@@ -191,7 +188,7 @@ func TestToNode_Mapping(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientRegistry_GetClusterNodeByUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-getbyuuid")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-getbyuuid", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
@@ -210,7 +207,7 @@ func TestClientRegistry_GetClusterNodeByUUID(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestClientRegistry_FindByName_NormalizesDNSLabel(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-findname")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-findname", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
|
@@ -14,9 +14,8 @@ import (
|
||||
|
||||
// UUID-first upsert: Put finds existing row by UUID and updates fields.
|
||||
func TestClientRegistry_PutUpdateByUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-put-uuid")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-put-uuid", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
uuid := rnd.UUIDv7()
|
||||
@@ -47,9 +46,8 @@ func TestClientRegistry_PutUpdateByUUID(t *testing.T) {
|
||||
|
||||
// Latest-by-UpdatedAt when multiple rows share the same NodeUUID (historical duplicates).
|
||||
func TestClientRegistry_FindByNodeUUID_PrefersLatest(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-find-uuid-latest")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-find-uuid-latest", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
uuid := rnd.UUIDv7()
|
||||
// Create two raw client rows with the same NodeUUID and different UpdatedAt
|
||||
@@ -74,9 +72,8 @@ func TestClientRegistry_FindByNodeUUID_PrefersLatest(t *testing.T) {
|
||||
|
||||
// DeleteAllByUUID removes all rows that share a NodeUUID.
|
||||
func TestClientRegistry_DeleteAllByUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-delete-all")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-delete-all", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
uuid := rnd.UUIDv7()
|
||||
// Two rows with same UUID
|
||||
@@ -99,9 +96,8 @@ func TestClientRegistry_DeleteAllByUUID(t *testing.T) {
|
||||
|
||||
// List() should only include clients that represent cluster nodes (i.e., have a NodeUUID).
|
||||
func TestClientRegistry_ListOnlyUUID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-list-only-uuid")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-list-only-uuid", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
// Create one client with empty NodeUUID (non-node), and one proper node
|
||||
nonNode := entity.NewClient().SetName("webapp").SetRole("client")
|
||||
@@ -122,9 +118,8 @@ func TestClientRegistry_ListOnlyUUID(t *testing.T) {
|
||||
|
||||
// Put should prefer UUID over ClientID when both are provided, avoiding cross-attachment.
|
||||
func TestClientRegistry_PutPrefersUUIDOverClientID(t *testing.T) {
|
||||
c := cfg.NewTestConfig("cluster-registry-put-prefers-uuid")
|
||||
c := cfg.NewMinimalTestConfigWithDb("cluster-registry-put-prefers-uuid", t.TempDir())
|
||||
defer c.CloseDb()
|
||||
assert.NoError(t, c.Init())
|
||||
|
||||
r, _ := NewClientRegistryWithConfig(c)
|
||||
// Seed two separate records
|
||||
|
@@ -35,10 +35,11 @@ type DatabaseInfo struct {
|
||||
// SummaryResponse is the response type for GET /api/v1/cluster.
|
||||
// swagger:model SummaryResponse
|
||||
type SummaryResponse struct {
|
||||
UUID string `json:"uuid"` // ClusterUUID
|
||||
Nodes int `json:"nodes"`
|
||||
Database DatabaseInfo `json:"database"`
|
||||
Time string `json:"time"`
|
||||
UUID string `json:"uuid"` // ClusterUUID
|
||||
ClusterCIDR string `json:"clusterCidr,omitempty"`
|
||||
Nodes int `json:"nodes"`
|
||||
Database DatabaseInfo `json:"database"`
|
||||
Time string `json:"time"`
|
||||
}
|
||||
|
||||
// RegisterSecrets contains newly issued or rotated node secrets.
|
||||
@@ -65,6 +66,7 @@ type RegisterDatabase struct {
|
||||
// swagger:model RegisterResponse
|
||||
type RegisterResponse struct {
|
||||
UUID string `json:"uuid"` // ClusterUUID
|
||||
ClusterCIDR string `json:"clusterCidr,omitempty"`
|
||||
Node Node `json:"node"`
|
||||
Database RegisterDatabase `json:"database"`
|
||||
Secrets *RegisterSecrets `json:"secrets,omitempty"`
|
||||
|
@@ -221,13 +221,13 @@ func (c *Config) ReSync(token string) (err error) {
|
||||
// interrupt reading of the Response.Body.
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
|
||||
url := ServiceURL
|
||||
endpointUrl := ServiceURL
|
||||
method := http.MethodPost
|
||||
|
||||
var req *http.Request
|
||||
|
||||
if c.Key != "" {
|
||||
url = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
||||
endpointUrl = fmt.Sprintf(ServiceURL+"/%s", c.Key)
|
||||
method = http.MethodPut
|
||||
log.Tracef("config: requesting updated keys for maps and places")
|
||||
} else {
|
||||
@@ -239,7 +239,7 @@ func (c *Config) ReSync(token string) (err error) {
|
||||
|
||||
if j, err = json.Marshal(NewRequest(c.Version, c.Serial, c.Env, c.PartnerID, token)); err != nil {
|
||||
return err
|
||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
||||
} else if req, err = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@@ -67,17 +67,17 @@ func (c *Config) SendFeedback(frm form.Feedback) (err error) {
|
||||
// interrupt reading of the Response.Body.
|
||||
client := &http.Client{Timeout: 60 * time.Second}
|
||||
|
||||
url := fmt.Sprintf(FeedbackURL, c.Key)
|
||||
endpointUrl := fmt.Sprintf(FeedbackURL, c.Key)
|
||||
method := http.MethodPost
|
||||
|
||||
var req *http.Request
|
||||
|
||||
log.Debugf("sending feedback to %s", ApiHost())
|
||||
|
||||
if j, err := json.Marshal(feedback); err != nil {
|
||||
return err
|
||||
} else if req, err = http.NewRequest(method, url, bytes.NewReader(j)); err != nil {
|
||||
return err
|
||||
if j, reqErr := json.Marshal(feedback); reqErr != nil {
|
||||
return reqErr
|
||||
} else if req, reqErr = http.NewRequest(method, endpointUrl, bytes.NewReader(j)); reqErr != nil {
|
||||
return reqErr
|
||||
}
|
||||
|
||||
// Set user agent.
|
||||
|
@@ -9,18 +9,28 @@ import (
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/photoprism"
|
||||
"github.com/photoprism/photoprism/internal/photoprism/get"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
log = logrus.StandardLogger()
|
||||
log.SetLevel(logrus.TraceLevel)
|
||||
|
||||
c := config.NewTestConfig("avatar")
|
||||
tempDir, err := os.MkdirTemp("", "avatar-test")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
c := config.NewMinimalTestConfigWithDb("avatar", tempDir)
|
||||
get.SetConfig(c)
|
||||
photoprism.SetConfig(c)
|
||||
defer c.CloseDb()
|
||||
|
||||
code := m.Run()
|
||||
|
||||
// Remove temporary SQLite files after running the tests.
|
||||
fs.PurgeTestDbFiles(".", false)
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
@@ -6,7 +6,8 @@ import (
|
||||
"github.com/photoprism/photoprism/pkg/list"
|
||||
)
|
||||
|
||||
// Scope sanitizes a string that contains authentication scope identifiers.
|
||||
// Scope sanitizes a string that contains auth scope identifiers.
|
||||
// Callers should use acl.ScopeAttrPermits / acl.ScopePermits for authorization checks.
|
||||
func Scope(s string) string {
|
||||
if s == "" {
|
||||
return ""
|
||||
@@ -15,7 +16,8 @@ func Scope(s string) string {
|
||||
return list.ParseAttr(strings.ToLower(s)).String()
|
||||
}
|
||||
|
||||
// Scopes sanitizes authentication scope identifiers and returns them as string slice.
|
||||
// Scopes sanitizes auth scope identifiers and returns them as strings.
|
||||
// Callers should use acl.ScopeAttrPermits / acl.ScopePermits for authorization checks.
|
||||
func Scopes(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
|
@@ -74,7 +74,7 @@ func (f *KeyValue) Parse(s string) *KeyValue {
|
||||
}
|
||||
|
||||
// Default?
|
||||
if f.Key == All {
|
||||
if f.Key == Any {
|
||||
return f
|
||||
} else if v = Value(v); v == "" {
|
||||
f.Value = True
|
||||
@@ -97,8 +97,8 @@ func (f *KeyValue) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
if f.Key == All {
|
||||
return All
|
||||
if f.Key == Any {
|
||||
return Any
|
||||
}
|
||||
|
||||
if Bool[strings.ToLower(f.Value)] == True {
|
||||
@@ -111,3 +111,8 @@ func (f *KeyValue) String() string {
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Any checks if this represents any value (asterisk).
|
||||
func (f *KeyValue) Any() bool {
|
||||
return f.Key == Any
|
||||
}
|
||||
|
@@ -68,9 +68,9 @@ func (list Attr) Sort() Attr {
|
||||
sort.Slice(list, func(i, j int) bool {
|
||||
if list[i].Key == list[j].Key {
|
||||
return list[i].Value < list[j].Value
|
||||
} else if list[i].Key == All {
|
||||
} else if list[i].Key == Any {
|
||||
return false
|
||||
} else if list[j].Key == All {
|
||||
} else if list[j].Key == Any {
|
||||
return true
|
||||
} else {
|
||||
return list[i].Key < list[j].Key
|
||||
@@ -95,23 +95,25 @@ func (list Attr) Contains(s string) bool {
|
||||
func (list Attr) Find(s string) (a KeyValue) {
|
||||
if len(list) == 0 || s == "" {
|
||||
return a
|
||||
} else if s == All {
|
||||
return KeyValue{Key: All, Value: ""}
|
||||
} else if s == Any {
|
||||
return KeyValue{Key: Any, Value: ""}
|
||||
}
|
||||
|
||||
attr := ParseKeyValue(s)
|
||||
|
||||
// Return nil if key is invalid or all.
|
||||
if attr.Key == "" {
|
||||
// Return if key is invalid.
|
||||
if attr == nil {
|
||||
return a
|
||||
} else if attr.Key == "" {
|
||||
return a
|
||||
}
|
||||
|
||||
// Find and return first match.
|
||||
if attr.Value == "" || attr.Value == All {
|
||||
if attr.Value == "" || attr.Value == Any {
|
||||
for i := range list {
|
||||
if strings.EqualFold(attr.Key, list[i].Key) {
|
||||
return *list[i]
|
||||
} else if list[i].Key == All {
|
||||
} else if list[i].Key == Any {
|
||||
a = *list[i]
|
||||
}
|
||||
}
|
||||
@@ -122,10 +124,10 @@ func (list Attr) Find(s string) (a KeyValue) {
|
||||
return KeyValue{Key: "", Value: ""}
|
||||
} else if attr.Value == list[i].Value {
|
||||
return *list[i]
|
||||
} else if list[i].Value == All {
|
||||
} else if list[i].Value == Any {
|
||||
a = *list[i]
|
||||
}
|
||||
} else if list[i].Key == All && attr.Value != False {
|
||||
} else if list[i].Key == Any && attr.Value != False {
|
||||
a = *list[i]
|
||||
}
|
||||
}
|
||||
|
@@ -164,7 +164,7 @@ func TestAttr_Find(t *testing.T) {
|
||||
assert.Len(t, attr, 1)
|
||||
result := attr.Find("metrics")
|
||||
|
||||
assert.Equal(t, All, result.Key)
|
||||
assert.Equal(t, Any, result.Key)
|
||||
assert.Equal(t, "", result.Value)
|
||||
})
|
||||
t.Run("Empty", func(t *testing.T) {
|
||||
@@ -182,6 +182,7 @@ func TestAttr_Find(t *testing.T) {
|
||||
|
||||
assert.Len(t, attr, 1)
|
||||
result := attr.Find("*")
|
||||
assert.Equal(t, Any, result.Key)
|
||||
assert.Equal(t, All, result.Key)
|
||||
assert.Equal(t, "", result.Value)
|
||||
})
|
||||
@@ -191,6 +192,7 @@ func TestAttr_Find(t *testing.T) {
|
||||
|
||||
assert.Len(t, attr, 1)
|
||||
result := attr.Find("6VU:*")
|
||||
assert.Equal(t, Any, result.Key)
|
||||
assert.Equal(t, All, result.Key)
|
||||
assert.Equal(t, "", result.Value)
|
||||
})
|
||||
@@ -230,7 +232,7 @@ func TestAttr_Find(t *testing.T) {
|
||||
assert.Len(t, attr, 2)
|
||||
|
||||
result := attr.Find("read")
|
||||
assert.Equal(t, All, result.Key)
|
||||
assert.Equal(t, Any, result.Key)
|
||||
assert.Equal(t, "", result.Value)
|
||||
|
||||
result = attr.Find("read:other")
|
||||
@@ -238,7 +240,7 @@ func TestAttr_Find(t *testing.T) {
|
||||
assert.Equal(t, "other", result.Value)
|
||||
|
||||
result = attr.Find("read:true")
|
||||
assert.Equal(t, All, result.Key)
|
||||
assert.Equal(t, Any, result.Key)
|
||||
assert.Equal(t, "", result.Value)
|
||||
|
||||
result = attr.Find("read:false")
|
||||
|
@@ -1,18 +1,22 @@
|
||||
package list
|
||||
|
||||
const All = "*"
|
||||
// Any matches everything.
|
||||
const Any = "*"
|
||||
|
||||
// All is kept for backward compatibility, but deprecated.
|
||||
const All = Any
|
||||
|
||||
// Contains tests if a string is contained in the list.
|
||||
func Contains(list []string, s string) bool {
|
||||
if len(list) == 0 || s == "" {
|
||||
return false
|
||||
} else if s == All {
|
||||
} else if s == Any {
|
||||
return true
|
||||
}
|
||||
|
||||
// Find matches.
|
||||
for i := range list {
|
||||
if s == list[i] || list[i] == All {
|
||||
if s == list[i] || list[i] == Any {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -27,11 +31,11 @@ func ContainsAny(l, s []string) bool {
|
||||
}
|
||||
|
||||
// If second list contains All, it's a wildcard match.
|
||||
if s[0] == All {
|
||||
if s[0] == Any {
|
||||
return true
|
||||
}
|
||||
for j := 1; j < len(s); j++ {
|
||||
if s[j] == All {
|
||||
if s[j] == Any {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
@@ -4,7 +4,7 @@ package list
|
||||
func Remove(list []string, s string) []string {
|
||||
if len(list) == 0 || s == "" {
|
||||
return list
|
||||
} else if s == All {
|
||||
} else if s == Any {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user