OAuth2: Remove client soft delete and fix client add command #213 #3943

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2024-01-29 21:08:01 +01:00
parent fd0006928e
commit d0ad3c23fb
20 changed files with 216 additions and 110 deletions

View File

@@ -195,7 +195,8 @@ services:
## Prometheus Test Server
## Docs: https://prometheus.io/docs/prometheus/latest/configuration/configuration/#oauth2
## Auth: photoprism client add --id=cs5cpu17n6gj2qo5 --secret=xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e -s metrics -n Prometheus -e 600
## Run the following command in a container terminal to register the preconfigured client credentials for testing (adjust parameters as needed):
## ./photoprism client add --id=cs5cpu17n6gj2qo5 --secret=xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e -s metrics -n Prometheus -e 60 -t 1
prometheus:
image: prom/prometheus:latest
container_name: prometheus

View File

@@ -113,7 +113,7 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Deletes old client sessions above the configured limit.
if deleted := client.EnforceAuthTokenLimit(); deleted > 0 {
event.AuditInfo([]string{clientIp, "client %s", "%s deleted"}, f.ClientID, english.Plural(deleted, "old oauth2 session", "old oauth2 sessions"))
event.AuditInfo([]string{clientIp, "client %s", "session %s", "oauth2", "deleted %s"}, f.ClientID, sess.RefID, english.Plural(deleted, "previously created client session", "previously created client sessions"))
}
// Response includes access token, token type, and token lifetime.

View File

@@ -34,7 +34,7 @@ func authShowAction(ctx *cli.Context) error {
sess, err := query.Session(id)
if err != nil {
return fmt.Errorf("session %s not found: %s", clean.LogQuote(id), err)
return fmt.Errorf("session %s not found: %s", clean.Log(id), err)
}
// Get session information.

View File

@@ -3,6 +3,7 @@ package commands
import (
"fmt"
"github.com/dustin/go-humanize/english"
"github.com/manifoldco/promptui"
"github.com/urfave/cli"
@@ -92,7 +93,7 @@ func clientsAddAction(ctx *cli.Context) error {
if client.AuthTokens > 0 {
if authExpires != "" {
authExpires = fmt.Sprintf("%s; up to %d tokens", authExpires, client.AuthTokens)
authExpires = fmt.Sprintf("%s; up to %s", authExpires, english.Plural(int(client.Tokens()), "token", "tokens"))
} else {
authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens)
}

View File

@@ -52,7 +52,7 @@ func clientsListAction(ctx *cli.Context) error {
if client.AuthTokens > 0 {
if authExpires != "" {
authExpires = fmt.Sprintf("%s; up to %d tokens", authExpires, client.AuthTokens)
authExpires = fmt.Sprintf("%s; up to %s", authExpires, english.Plural(int(client.Tokens()), "token", "tokens"))
} else {
authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens)
}

View File

@@ -40,7 +40,7 @@ func clientsModAction(ctx *cli.Context) error {
client = entity.FindClientByUID(frm.ID())
if client == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(frm.ID()))
return fmt.Errorf("client %s not found", clean.Log(frm.ID()))
}
// Update client from form values.

View File

@@ -23,21 +23,6 @@ func TestClientsModCommand(t *testing.T) {
assert.Error(t, err)
assert.Empty(t, output)
})
t.Run("ModDeletedClient", func(t *testing.T) {
var err error
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"mod", "--name=New", "--scope=test", "cs5cpu17n6gj2gf7"})
// Run command with test context.
output := capture.Output(func() {
err = ClientsModCommand.Run(ctx)
})
// Check command output for plausibility.
assert.Error(t, err)
assert.Empty(t, output)
})
t.Run("DisableEnableAuth", func(t *testing.T) {
var err error

View File

@@ -44,9 +44,9 @@ func clientsRemoveAction(ctx *cli.Context) error {
m = entity.FindClientByUID(id)
if m == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(id))
return fmt.Errorf("client %s not found", clean.Log(id))
} else if m.Deleted() {
return fmt.Errorf("client %s has already been deleted", clean.LogQuote(id))
return fmt.Errorf("client %s has already been deleted", clean.Log(id))
}
if !ctx.Bool("force") {

View File

@@ -19,8 +19,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0)
assert.NoError(t, err)
assert.NotContains(t, output0, "| DeletedAt | time.Date")
assert.Contains(t, output0, "| DeletedAt | <nil>")
assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "client_credentials")
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "cs7pvt5h8rw9aaqj"})
@@ -43,8 +43,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output2)
assert.NoError(t, err)
assert.NotContains(t, output2, "| DeletedAt | time.Date")
assert.Contains(t, output2, "| DeletedAt | <nil>")
assert.NotContains(t, output2, "not found")
assert.Contains(t, output2, "client_credentials")
})
t.Run("RemoveClient", func(t *testing.T) {
var err error
@@ -57,8 +57,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0)
assert.NoError(t, err)
assert.NotContains(t, output0, "| DeletedAt | time.Date")
assert.Contains(t, output0, "| DeletedAt | <nil>")
assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "client_credentials")
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "--force", "cs7pvt5h8rw9aaqj"})
@@ -69,7 +69,6 @@ func TestCientsRemoveCommand(t *testing.T) {
})
// Check command output for plausibility.
//t.Logf(output)
assert.NoError(t, err)
assert.Empty(t, output)
@@ -79,10 +78,8 @@ func TestCientsRemoveCommand(t *testing.T) {
err = ClientsShowCommand.Run(ctx2)
})
//t.Logf(output2)
assert.NoError(t, err)
assert.Contains(t, output2, "| DeletedAt | time.Date")
assert.NotContains(t, output2, "| DeletedAt | <nil>")
assert.Error(t, err)
assert.Empty(t, output2)
})
t.Run("NotFound", func(t *testing.T) {
var err error
@@ -96,23 +93,6 @@ func TestCientsRemoveCommand(t *testing.T) {
})
// Check command output for plausibility.
//t.Logf(output)
assert.Error(t, err)
assert.Empty(t, output)
})
t.Run("AlreadyDeleted", func(t *testing.T) {
var err error
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "--force", "cs5cpu17n6gj2gf7"})
// Run command with test context.
output := capture.Output(func() {
err = ClientsRemoveCommand.Run(ctx)
})
// Check command output for plausibility.
//t.Logf(output)
assert.Error(t, err)
assert.Empty(t, output)
})

View File

@@ -36,7 +36,7 @@ func clientsShowAction(ctx *cli.Context) error {
m = entity.FindClientByUID(id)
if m == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(id))
return fmt.Errorf("client %s not found", clean.Log(id))
}
// Get client information.

View File

@@ -47,7 +47,6 @@ type Client struct {
LastActive int64 `json:"LastActive" yaml:"LastActive,omitempty"`
CreatedAt time.Time `json:"CreatedAt" yaml:"-"`
UpdatedAt time.Time `json:"UpdatedAt" yaml:"-"`
DeletedAt *time.Time `sql:"index" json:"DeletedAt,omitempty" yaml:"-"`
}
// TableName returns the entity table name.
@@ -106,16 +105,44 @@ func (m *Client) UID() string {
return m.ClientUID
}
// HasUID tests if the entity has a valid uid.
// HasUID tests if the client has a valid uid.
func (m *Client) HasUID() bool {
return rnd.IsUID(m.ClientUID, ClientUID)
}
// NoUID tests if the client does not have a valid uid.
func (m *Client) NoUID() bool {
return !m.HasUID()
}
// Name returns the client name string.
func (m *Client) Name() string {
return m.ClientName
}
// HasName tests if the client has a name.
func (m *Client) HasName() bool {
return m.ClientName != ""
}
// NoName tests if the client does not have a name.
func (m *Client) NoName() bool {
return !m.HasName()
}
// String returns the client id or name for use in logs and reports.
func (m *Client) String() string {
if m == nil {
return report.NotAssigned
} else if m.HasUID() {
return m.UID()
} else if m.HasName() {
return m.Name()
}
return report.NotAssigned
}
// SetName sets a custom client name.
func (m *Client) SetName(s string) *Client {
if s = clean.Name(s); s != "" {
@@ -152,7 +179,7 @@ func (m *Client) AclRole() acl.Role {
return acl.RoleNone
}
// User returns the related user account, if any.
// User returns the user who owns the client, if any.
func (m *Client) User() *User {
if m.user != nil {
return m.user
@@ -168,7 +195,12 @@ func (m *Client) User() *User {
return &User{}
}
// SetUser updates the related user account.
// HasUser checks the client belongs to a user.
func (m *Client) HasUser() bool {
return rnd.IsUID(m.UserUID, UserUID)
}
// SetUser sets the user to which the client belongs.
func (m *Client) SetUser(u *User) *Client {
if u == nil {
return m
@@ -227,7 +259,7 @@ func (m *Client) Save() error {
}
// Delete related sessions if authentication is disabled.
if m.AuthEnabled && m.DeletedAt == nil {
if m.AuthEnabled {
return nil
} else if _, err := m.DeleteSessions(); err != nil {
return err
@@ -257,8 +289,8 @@ func (m *Client) DeleteSessions() (deleted int, err error) {
return 0, fmt.Errorf("client uid is missing")
}
if deleted = DeleteClientSessions(m.UID(), "", 0); deleted > 0 {
event.AuditInfo([]string{"client %s", "%s deleted"}, m.ClientUID, english.Plural(deleted, "session", "sessions"))
if deleted = DeleteClientSessions(m, "", 0); deleted > 0 {
event.AuditInfo([]string{"client %s", "deleted %s"}, m.String(), english.Plural(deleted, "session", "sessions"))
}
return deleted, nil
@@ -266,11 +298,20 @@ func (m *Client) DeleteSessions() (deleted int, err error) {
// Deleted checks if the client has been deleted.
func (m *Client) Deleted() bool {
if m.DeletedAt == nil {
return false
if m == nil {
return true
}
return !m.DeletedAt.IsZero()
return false
}
// Disabled checks if the client authentication has been disabled.
func (m *Client) Disabled() bool {
if m == nil {
return true
}
return !m.AuthEnabled
}
// Updates multiple properties in the database.
@@ -412,7 +453,7 @@ func (m *Client) EnforceAuthTokenLimit() (deleted int) {
return 0
}
return DeleteClientSessions(m.ClientUID, authn.MethodOAuth2, m.AuthTokens)
return DeleteClientSessions(m, authn.MethodOAuth2, m.AuthTokens)
}
// Expires returns the auth expiration duration.
@@ -430,8 +471,12 @@ func (m *Client) SetExpires(i int64) *Client {
}
// Tokens returns maximum number of access tokens this client can create.
func (m *Client) Tokens() time.Duration {
return time.Duration(m.AuthExpires) * time.Second
func (m *Client) Tokens() int64 {
if m.AuthTokens == 0 {
return 1
}
return m.AuthTokens
}
// SetTokens sets a custom access token limit for this client.

View File

@@ -9,7 +9,7 @@ import (
// AddClient creates a new client and returns it if successful.
func AddClient(frm form.Client) (client *Client, err error) {
if found := FindClientByUID(frm.ID()); found != nil {
return found, fmt.Errorf("client id %s already exists", found.ClientUID)
return found, fmt.Errorf("client %s already exists", found.ClientUID)
}
client = NewClient().SetFormValues(frm)

View File

@@ -97,7 +97,7 @@ var ClientFixtures = ClientMap{
AuthEnabled: true,
LastActive: 0,
},
"deleted": {
"disabled": {
ClientUID: "cs5cpu17n6gj2gf7",
UserUID: "",
UserName: "",
@@ -112,9 +112,8 @@ var ClientFixtures = ClientMap{
AuthScope: "metrics",
AuthExpires: unix.Hour,
AuthTokens: 2,
AuthEnabled: true,
AuthEnabled: false,
LastActive: 0,
DeletedAt: TimePointer(),
},
"analytics": {
ClientUID: "cs7pvt5h8rw9aaqj",

View File

@@ -199,8 +199,17 @@ func TestClient_Delete(t *testing.T) {
}
func TestClient_Deleted(t *testing.T) {
var ptr *Client
assert.False(t, ClientFixtures.Pointer("alice").Deleted())
assert.True(t, ClientFixtures.Pointer("deleted").Deleted())
assert.False(t, ClientFixtures.Pointer("deleted").Deleted())
assert.True(t, ptr.Deleted())
}
func TestClient_Disabled(t *testing.T) {
var ptr *Client
assert.False(t, ClientFixtures.Pointer("alice").Disabled())
assert.True(t, ClientFixtures.Pointer("deleted").Disabled())
assert.True(t, ptr.Disabled())
}
func TestClient_Updates(t *testing.T) {
@@ -297,6 +306,31 @@ func TestClient_UpdateLastActive(t *testing.T) {
})
}
func TestClient_Tokens(t *testing.T) {
t.Run("Set", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 0}
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(0)
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(1)
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(10)
assert.Equal(t, int64(10), m.Tokens())
})
t.Run("Unlimited", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: -1}
assert.Equal(t, int64(-1), m.Tokens())
})
t.Run("One", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 1}
assert.Equal(t, int64(1), m.Tokens())
})
t.Run("Many", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 10}
assert.Equal(t, int64(10), m.Tokens())
})
}
func TestClient_EnforceAuthTokenLimit(t *testing.T) {
t.Run("EmptyUID", func(t *testing.T) {
var m = Client{ClientName: "No UUID"}
@@ -379,6 +413,30 @@ func TestClient_Expires(t *testing.T) {
})
}
func TestClient_String(t *testing.T) {
t.Run("Default", func(t *testing.T) {
m := &Client{}
assert.Equal(t, m.String(), "n/a")
})
t.Run("NewClient", func(t *testing.T) {
m := NewClient()
assert.Equal(t, m.String(), "n/a")
})
t.Run("Metrics", func(t *testing.T) {
m := ClientFixtures.Get("metrics")
assert.Equal(t, m.String(), "cs5cpu17n6gj2qo5")
})
t.Run("Alice", func(t *testing.T) {
m := ClientFixtures.Get("alice")
assert.Equal(t, m.String(), "cs5gfen1bgxz7s9i")
})
t.Run("Name", func(t *testing.T) {
m := NewClient()
m.ClientName = "Foo Bar"
assert.Equal(t, m.String(), "Foo Bar")
})
}
func TestClient_UserInfo(t *testing.T) {
t.Run("New", func(t *testing.T) {
assert.Equal(t, report.NotAssigned, NewClient().UserInfo())

View File

@@ -7,6 +7,8 @@ import (
"net/http"
"time"
"github.com/dustin/go-humanize/english"
"github.com/gin-gonic/gin"
"github.com/jinzhu/gorm"
@@ -120,21 +122,36 @@ func DeleteExpiredSessions() (deleted int) {
}
// DeleteClientSessions deletes client sessions above the specified limit.
func DeleteClientSessions(clientUID string, authMethod authn.MethodType, limit int64) (deleted int) {
if !rnd.IsUID(clientUID, ClientUID) || limit < 0 {
func DeleteClientSessions(client *Client, authMethod authn.MethodType, limit int64) (deleted int) {
if limit < 0 {
return 0
} else if client == nil {
return 0
}
found := Sessions{}
q := Db()
q := Db().Where("client_uid = ?", clientUID)
if client.HasUID() {
q = q.Where("client_uid = ?", client.UID())
} else if client.HasName() {
q = q.Where("client_name = ?", client.Name())
} else {
return 0
}
if client.HasUser() {
q = q.Where("user_uid = ?", client.UserUID)
}
if !authMethod.IsDefault() {
q = q.Where("auth_method = ?", authMethod.String())
}
if err := q.Order("created_at DESC").Limit(2147483648).Offset(limit).
Find(&found).Error; err != nil {
q = q.Order("created_at DESC").Limit(2147483648).Offset(limit)
found := Sessions{}
if err := q.Find(&found).Error; err != nil {
event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err)
return deleted
}
@@ -255,6 +272,17 @@ func (m *Session) Save() error {
m.Cache()
}
// Limit the number of sessions that are created with an app password.
if !m.Method().IsSession() {
return nil
} else if !m.Provider().IsApplication() {
return nil
} else if client := m.Client(); client.NoName() || client.Tokens() < 1 {
return nil
} else if deleted := DeleteClientSessions(client, authn.MethodSession, client.Tokens()); deleted > 0 {
event.AuditInfo([]string{m.IP(), "session %s", "deleted %s"}, m.RefID, english.Plural(deleted, "previously created client session", "previously created client sessions"))
}
return nil
}
@@ -348,11 +376,7 @@ func (m *Session) ClientRole() acl.Role {
// ClientInfo returns the session's client identifier string.
func (m *Session) ClientInfo() string {
if m.HasClient() {
if uid := m.Client().UID(); uid != "" {
return uid
} else if name := m.Client().Name(); name != "" {
return name
}
return m.Client().String()
} else if m.ClientName != "" {
return m.ClientName
}

View File

@@ -107,13 +107,13 @@ func TestDeleteClientSessions(t *testing.T) {
// Create new test client.
client := NewClient()
client.ClientUID = "cs5gfen1bgx00000"
client.ClientUID = clientUID
// Make sure no sessions exist yet and test missing arguments.
assert.Equal(t, 0, DeleteClientSessions("", "", -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions("", authn.MethodDefault, 0))
assert.Equal(t, 0, DeleteClientSessions(&Client{}, "", -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(&Client{}, authn.MethodDefault, 0))
// Create 10 test client sessions.
for i := 0; i < 10; i++ {
@@ -126,11 +126,11 @@ func TestDeleteClientSessions(t *testing.T) {
}
// Check if the expected number of sessions is deleted until none are left.
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOIDC, 1))
assert.Equal(t, 9, DeleteClientSessions(clientUID, authn.MethodOAuth2, 1))
assert.Equal(t, 1, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOIDC, 1))
assert.Equal(t, 9, DeleteClientSessions(client, authn.MethodOAuth2, 1))
assert.Equal(t, 1, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
}
func TestSessionStatusUnauthorized(t *testing.T) {

View File

@@ -94,6 +94,9 @@ func AddClientFromCli(ctx *cli.Context) Client {
f.AuthScope = "*"
}
f.AuthExpires = ctx.Int64("expires")
f.AuthTokens = ctx.Int64("tokens")
return f
}

View File

@@ -31,6 +31,11 @@ func (t MethodType) IsDefault() bool {
return t.String() == MethodDefault.String()
}
// IsSession checks if this is the session method.
func (t MethodType) IsSession() bool {
return t.String() == MethodSession.String()
}
// String returns the provider identifier as a string.
func (t MethodType) String() string {
switch t {

View File

@@ -63,6 +63,11 @@ func (t ProviderType) IsClient() bool {
return list.Contains(ClientProviders, string(t))
}
// IsApplication checks if the authentication is provided for an application.
func (t ProviderType) IsApplication() bool {
return t == ProviderApplication
}
// IsDefault checks if this is the default provider.
func (t ProviderType) IsDefault() bool {
return t.String() == ProviderDefault.String()

View File

@@ -1,5 +1,5 @@
global:
scrape_interval: 15s
scrape_interval: 60s
scrape_timeout: 10s
scrape_configs:
- job_name: "photoprism"