OAuth2: Remove client soft delete and fix client add command #213 #3943

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2024-01-29 21:08:01 +01:00
parent fd0006928e
commit d0ad3c23fb
20 changed files with 216 additions and 110 deletions

View File

@@ -195,7 +195,8 @@ services:
## Prometheus Test Server ## Prometheus Test Server
## Docs: https://prometheus.io/docs/prometheus/latest/configuration/configuration/#oauth2 ## Docs: https://prometheus.io/docs/prometheus/latest/configuration/configuration/#oauth2
## Auth: photoprism client add --id=cs5cpu17n6gj2qo5 --secret=xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e -s metrics -n Prometheus -e 600 ## Run the following command in a container terminal to register the preconfigured client credentials for testing (adjust parameters as needed):
## ./photoprism client add --id=cs5cpu17n6gj2qo5 --secret=xcCbOrw6I0vcoXzhnOmXhjpVSyFq0l0e -s metrics -n Prometheus -e 60 -t 1
prometheus: prometheus:
image: prom/prometheus:latest image: prom/prometheus:latest
container_name: prometheus container_name: prometheus

View File

@@ -113,7 +113,7 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Deletes old client sessions above the configured limit. // Deletes old client sessions above the configured limit.
if deleted := client.EnforceAuthTokenLimit(); deleted > 0 { if deleted := client.EnforceAuthTokenLimit(); deleted > 0 {
event.AuditInfo([]string{clientIp, "client %s", "%s deleted"}, f.ClientID, english.Plural(deleted, "old oauth2 session", "old oauth2 sessions")) event.AuditInfo([]string{clientIp, "client %s", "session %s", "oauth2", "deleted %s"}, f.ClientID, sess.RefID, english.Plural(deleted, "previously created client session", "previously created client sessions"))
} }
// Response includes access token, token type, and token lifetime. // Response includes access token, token type, and token lifetime.

View File

@@ -34,7 +34,7 @@ func authShowAction(ctx *cli.Context) error {
sess, err := query.Session(id) sess, err := query.Session(id)
if err != nil { if err != nil {
return fmt.Errorf("session %s not found: %s", clean.LogQuote(id), err) return fmt.Errorf("session %s not found: %s", clean.Log(id), err)
} }
// Get session information. // Get session information.

View File

@@ -3,6 +3,7 @@ package commands
import ( import (
"fmt" "fmt"
"github.com/dustin/go-humanize/english"
"github.com/manifoldco/promptui" "github.com/manifoldco/promptui"
"github.com/urfave/cli" "github.com/urfave/cli"
@@ -92,7 +93,7 @@ func clientsAddAction(ctx *cli.Context) error {
if client.AuthTokens > 0 { if client.AuthTokens > 0 {
if authExpires != "" { if authExpires != "" {
authExpires = fmt.Sprintf("%s; up to %d tokens", authExpires, client.AuthTokens) authExpires = fmt.Sprintf("%s; up to %s", authExpires, english.Plural(int(client.Tokens()), "token", "tokens"))
} else { } else {
authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens) authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens)
} }

View File

@@ -52,7 +52,7 @@ func clientsListAction(ctx *cli.Context) error {
if client.AuthTokens > 0 { if client.AuthTokens > 0 {
if authExpires != "" { if authExpires != "" {
authExpires = fmt.Sprintf("%s; up to %d tokens", authExpires, client.AuthTokens) authExpires = fmt.Sprintf("%s; up to %s", authExpires, english.Plural(int(client.Tokens()), "token", "tokens"))
} else { } else {
authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens) authExpires = fmt.Sprintf("up to %d tokens", client.AuthTokens)
} }

View File

@@ -40,7 +40,7 @@ func clientsModAction(ctx *cli.Context) error {
client = entity.FindClientByUID(frm.ID()) client = entity.FindClientByUID(frm.ID())
if client == nil { if client == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(frm.ID())) return fmt.Errorf("client %s not found", clean.Log(frm.ID()))
} }
// Update client from form values. // Update client from form values.

View File

@@ -23,21 +23,6 @@ func TestClientsModCommand(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, output) assert.Empty(t, output)
}) })
t.Run("ModDeletedClient", func(t *testing.T) {
var err error
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"mod", "--name=New", "--scope=test", "cs5cpu17n6gj2gf7"})
// Run command with test context.
output := capture.Output(func() {
err = ClientsModCommand.Run(ctx)
})
// Check command output for plausibility.
assert.Error(t, err)
assert.Empty(t, output)
})
t.Run("DisableEnableAuth", func(t *testing.T) { t.Run("DisableEnableAuth", func(t *testing.T) {
var err error var err error

View File

@@ -44,9 +44,9 @@ func clientsRemoveAction(ctx *cli.Context) error {
m = entity.FindClientByUID(id) m = entity.FindClientByUID(id)
if m == nil { if m == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(id)) return fmt.Errorf("client %s not found", clean.Log(id))
} else if m.Deleted() { } else if m.Deleted() {
return fmt.Errorf("client %s has already been deleted", clean.LogQuote(id)) return fmt.Errorf("client %s has already been deleted", clean.Log(id))
} }
if !ctx.Bool("force") { if !ctx.Bool("force") {

View File

@@ -19,8 +19,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0) //t.Logf(output0)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotContains(t, output0, "| DeletedAt | time.Date") assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "| DeletedAt | <nil>") assert.Contains(t, output0, "client_credentials")
// Create test context with flags and arguments. // Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "cs7pvt5h8rw9aaqj"}) ctx := NewTestContext([]string{"rm", "cs7pvt5h8rw9aaqj"})
@@ -43,8 +43,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output2) //t.Logf(output2)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotContains(t, output2, "| DeletedAt | time.Date") assert.NotContains(t, output2, "not found")
assert.Contains(t, output2, "| DeletedAt | <nil>") assert.Contains(t, output2, "client_credentials")
}) })
t.Run("RemoveClient", func(t *testing.T) { t.Run("RemoveClient", func(t *testing.T) {
var err error var err error
@@ -57,8 +57,8 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0) //t.Logf(output0)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotContains(t, output0, "| DeletedAt | time.Date") assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "| DeletedAt | <nil>") assert.Contains(t, output0, "client_credentials")
// Create test context with flags and arguments. // Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "--force", "cs7pvt5h8rw9aaqj"}) ctx := NewTestContext([]string{"rm", "--force", "cs7pvt5h8rw9aaqj"})
@@ -69,7 +69,6 @@ func TestCientsRemoveCommand(t *testing.T) {
}) })
// Check command output for plausibility. // Check command output for plausibility.
//t.Logf(output)
assert.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, output) assert.Empty(t, output)
@@ -79,10 +78,8 @@ func TestCientsRemoveCommand(t *testing.T) {
err = ClientsShowCommand.Run(ctx2) err = ClientsShowCommand.Run(ctx2)
}) })
//t.Logf(output2) assert.Error(t, err)
assert.NoError(t, err) assert.Empty(t, output2)
assert.Contains(t, output2, "| DeletedAt | time.Date")
assert.NotContains(t, output2, "| DeletedAt | <nil>")
}) })
t.Run("NotFound", func(t *testing.T) { t.Run("NotFound", func(t *testing.T) {
var err error var err error
@@ -96,23 +93,6 @@ func TestCientsRemoveCommand(t *testing.T) {
}) })
// Check command output for plausibility. // Check command output for plausibility.
//t.Logf(output)
assert.Error(t, err)
assert.Empty(t, output)
})
t.Run("AlreadyDeleted", func(t *testing.T) {
var err error
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "--force", "cs5cpu17n6gj2gf7"})
// Run command with test context.
output := capture.Output(func() {
err = ClientsRemoveCommand.Run(ctx)
})
// Check command output for plausibility.
//t.Logf(output)
assert.Error(t, err) assert.Error(t, err)
assert.Empty(t, output) assert.Empty(t, output)
}) })

View File

@@ -36,7 +36,7 @@ func clientsShowAction(ctx *cli.Context) error {
m = entity.FindClientByUID(id) m = entity.FindClientByUID(id)
if m == nil { if m == nil {
return fmt.Errorf("client %s not found", clean.LogQuote(id)) return fmt.Errorf("client %s not found", clean.Log(id))
} }
// Get client information. // Get client information.

View File

@@ -29,25 +29,24 @@ type Clients []Client
// Client represents a client application. // Client represents a client application.
type Client struct { type Client struct {
ClientUID string `gorm:"type:VARBINARY(42);primary_key;auto_increment:false;" json:"-" yaml:"ClientUID"` ClientUID string `gorm:"type:VARBINARY(42);primary_key;auto_increment:false;" json:"-" yaml:"ClientUID"`
UserUID string `gorm:"type:VARBINARY(42);index;default:'';" json:"UserUID" yaml:"UserUID,omitempty"` UserUID string `gorm:"type:VARBINARY(42);index;default:'';" json:"UserUID" yaml:"UserUID,omitempty"`
UserName string `gorm:"size:200;index;" json:"UserName" yaml:"UserName,omitempty"` UserName string `gorm:"size:200;index;" json:"UserName" yaml:"UserName,omitempty"`
user *User `gorm:"-" yaml:"-"` user *User `gorm:"-" yaml:"-"`
ClientName string `gorm:"size:200;" json:"ClientName" yaml:"ClientName,omitempty"` ClientName string `gorm:"size:200;" json:"ClientName" yaml:"ClientName,omitempty"`
ClientRole string `gorm:"size:64;default:'';" json:"ClientRole" yaml:"ClientRole,omitempty"` ClientRole string `gorm:"size:64;default:'';" json:"ClientRole" yaml:"ClientRole,omitempty"`
ClientType string `gorm:"type:VARBINARY(16)" json:"ClientType" yaml:"ClientType,omitempty"` ClientType string `gorm:"type:VARBINARY(16)" json:"ClientType" yaml:"ClientType,omitempty"`
ClientURL string `gorm:"type:VARBINARY(255);default:'';column:client_url;" json:"ClientURL" yaml:"ClientURL,omitempty"` ClientURL string `gorm:"type:VARBINARY(255);default:'';column:client_url;" json:"ClientURL" yaml:"ClientURL,omitempty"`
CallbackURL string `gorm:"type:VARBINARY(255);default:'';column:callback_url;" json:"CallbackURL" yaml:"CallbackURL,omitempty"` CallbackURL string `gorm:"type:VARBINARY(255);default:'';column:callback_url;" json:"CallbackURL" yaml:"CallbackURL,omitempty"`
AuthProvider string `gorm:"type:VARBINARY(128);default:'';" json:"AuthProvider" yaml:"AuthProvider,omitempty"` AuthProvider string `gorm:"type:VARBINARY(128);default:'';" json:"AuthProvider" yaml:"AuthProvider,omitempty"`
AuthMethod string `gorm:"type:VARBINARY(128);default:'';" json:"AuthMethod" yaml:"AuthMethod,omitempty"` AuthMethod string `gorm:"type:VARBINARY(128);default:'';" json:"AuthMethod" yaml:"AuthMethod,omitempty"`
AuthScope string `gorm:"size:1024;default:'';" json:"AuthScope" yaml:"AuthScope,omitempty"` AuthScope string `gorm:"size:1024;default:'';" json:"AuthScope" yaml:"AuthScope,omitempty"`
AuthExpires int64 `json:"AuthExpires" yaml:"AuthExpires,omitempty"` AuthExpires int64 `json:"AuthExpires" yaml:"AuthExpires,omitempty"`
AuthTokens int64 `json:"AuthTokens" yaml:"AuthTokens,omitempty"` AuthTokens int64 `json:"AuthTokens" yaml:"AuthTokens,omitempty"`
AuthEnabled bool `json:"AuthEnabled" yaml:"AuthEnabled,omitempty"` AuthEnabled bool `json:"AuthEnabled" yaml:"AuthEnabled,omitempty"`
LastActive int64 `json:"LastActive" yaml:"LastActive,omitempty"` LastActive int64 `json:"LastActive" yaml:"LastActive,omitempty"`
CreatedAt time.Time `json:"CreatedAt" yaml:"-"` CreatedAt time.Time `json:"CreatedAt" yaml:"-"`
UpdatedAt time.Time `json:"UpdatedAt" yaml:"-"` UpdatedAt time.Time `json:"UpdatedAt" yaml:"-"`
DeletedAt *time.Time `sql:"index" json:"DeletedAt,omitempty" yaml:"-"`
} }
// TableName returns the entity table name. // TableName returns the entity table name.
@@ -106,16 +105,44 @@ func (m *Client) UID() string {
return m.ClientUID return m.ClientUID
} }
// HasUID tests if the entity has a valid uid. // HasUID tests if the client has a valid uid.
func (m *Client) HasUID() bool { func (m *Client) HasUID() bool {
return rnd.IsUID(m.ClientUID, ClientUID) return rnd.IsUID(m.ClientUID, ClientUID)
} }
// NoUID tests if the client does not have a valid uid.
func (m *Client) NoUID() bool {
return !m.HasUID()
}
// Name returns the client name string. // Name returns the client name string.
func (m *Client) Name() string { func (m *Client) Name() string {
return m.ClientName return m.ClientName
} }
// HasName tests if the client has a name.
func (m *Client) HasName() bool {
return m.ClientName != ""
}
// NoName tests if the client does not have a name.
func (m *Client) NoName() bool {
return !m.HasName()
}
// String returns the client id or name for use in logs and reports.
func (m *Client) String() string {
if m == nil {
return report.NotAssigned
} else if m.HasUID() {
return m.UID()
} else if m.HasName() {
return m.Name()
}
return report.NotAssigned
}
// SetName sets a custom client name. // SetName sets a custom client name.
func (m *Client) SetName(s string) *Client { func (m *Client) SetName(s string) *Client {
if s = clean.Name(s); s != "" { if s = clean.Name(s); s != "" {
@@ -152,7 +179,7 @@ func (m *Client) AclRole() acl.Role {
return acl.RoleNone return acl.RoleNone
} }
// User returns the related user account, if any. // User returns the user who owns the client, if any.
func (m *Client) User() *User { func (m *Client) User() *User {
if m.user != nil { if m.user != nil {
return m.user return m.user
@@ -168,7 +195,12 @@ func (m *Client) User() *User {
return &User{} return &User{}
} }
// SetUser updates the related user account. // HasUser checks the client belongs to a user.
func (m *Client) HasUser() bool {
return rnd.IsUID(m.UserUID, UserUID)
}
// SetUser sets the user to which the client belongs.
func (m *Client) SetUser(u *User) *Client { func (m *Client) SetUser(u *User) *Client {
if u == nil { if u == nil {
return m return m
@@ -227,7 +259,7 @@ func (m *Client) Save() error {
} }
// Delete related sessions if authentication is disabled. // Delete related sessions if authentication is disabled.
if m.AuthEnabled && m.DeletedAt == nil { if m.AuthEnabled {
return nil return nil
} else if _, err := m.DeleteSessions(); err != nil { } else if _, err := m.DeleteSessions(); err != nil {
return err return err
@@ -257,8 +289,8 @@ func (m *Client) DeleteSessions() (deleted int, err error) {
return 0, fmt.Errorf("client uid is missing") return 0, fmt.Errorf("client uid is missing")
} }
if deleted = DeleteClientSessions(m.UID(), "", 0); deleted > 0 { if deleted = DeleteClientSessions(m, "", 0); deleted > 0 {
event.AuditInfo([]string{"client %s", "%s deleted"}, m.ClientUID, english.Plural(deleted, "session", "sessions")) event.AuditInfo([]string{"client %s", "deleted %s"}, m.String(), english.Plural(deleted, "session", "sessions"))
} }
return deleted, nil return deleted, nil
@@ -266,11 +298,20 @@ func (m *Client) DeleteSessions() (deleted int, err error) {
// Deleted checks if the client has been deleted. // Deleted checks if the client has been deleted.
func (m *Client) Deleted() bool { func (m *Client) Deleted() bool {
if m.DeletedAt == nil { if m == nil {
return false return true
} }
return !m.DeletedAt.IsZero() return false
}
// Disabled checks if the client authentication has been disabled.
func (m *Client) Disabled() bool {
if m == nil {
return true
}
return !m.AuthEnabled
} }
// Updates multiple properties in the database. // Updates multiple properties in the database.
@@ -412,7 +453,7 @@ func (m *Client) EnforceAuthTokenLimit() (deleted int) {
return 0 return 0
} }
return DeleteClientSessions(m.ClientUID, authn.MethodOAuth2, m.AuthTokens) return DeleteClientSessions(m, authn.MethodOAuth2, m.AuthTokens)
} }
// Expires returns the auth expiration duration. // Expires returns the auth expiration duration.
@@ -430,8 +471,12 @@ func (m *Client) SetExpires(i int64) *Client {
} }
// Tokens returns maximum number of access tokens this client can create. // Tokens returns maximum number of access tokens this client can create.
func (m *Client) Tokens() time.Duration { func (m *Client) Tokens() int64 {
return time.Duration(m.AuthExpires) * time.Second if m.AuthTokens == 0 {
return 1
}
return m.AuthTokens
} }
// SetTokens sets a custom access token limit for this client. // SetTokens sets a custom access token limit for this client.

View File

@@ -9,7 +9,7 @@ import (
// AddClient creates a new client and returns it if successful. // AddClient creates a new client and returns it if successful.
func AddClient(frm form.Client) (client *Client, err error) { func AddClient(frm form.Client) (client *Client, err error) {
if found := FindClientByUID(frm.ID()); found != nil { if found := FindClientByUID(frm.ID()); found != nil {
return found, fmt.Errorf("client id %s already exists", found.ClientUID) return found, fmt.Errorf("client %s already exists", found.ClientUID)
} }
client = NewClient().SetFormValues(frm) client = NewClient().SetFormValues(frm)

View File

@@ -97,7 +97,7 @@ var ClientFixtures = ClientMap{
AuthEnabled: true, AuthEnabled: true,
LastActive: 0, LastActive: 0,
}, },
"deleted": { "disabled": {
ClientUID: "cs5cpu17n6gj2gf7", ClientUID: "cs5cpu17n6gj2gf7",
UserUID: "", UserUID: "",
UserName: "", UserName: "",
@@ -112,9 +112,8 @@ var ClientFixtures = ClientMap{
AuthScope: "metrics", AuthScope: "metrics",
AuthExpires: unix.Hour, AuthExpires: unix.Hour,
AuthTokens: 2, AuthTokens: 2,
AuthEnabled: true, AuthEnabled: false,
LastActive: 0, LastActive: 0,
DeletedAt: TimePointer(),
}, },
"analytics": { "analytics": {
ClientUID: "cs7pvt5h8rw9aaqj", ClientUID: "cs7pvt5h8rw9aaqj",

View File

@@ -199,8 +199,17 @@ func TestClient_Delete(t *testing.T) {
} }
func TestClient_Deleted(t *testing.T) { func TestClient_Deleted(t *testing.T) {
var ptr *Client
assert.False(t, ClientFixtures.Pointer("alice").Deleted()) assert.False(t, ClientFixtures.Pointer("alice").Deleted())
assert.True(t, ClientFixtures.Pointer("deleted").Deleted()) assert.False(t, ClientFixtures.Pointer("deleted").Deleted())
assert.True(t, ptr.Deleted())
}
func TestClient_Disabled(t *testing.T) {
var ptr *Client
assert.False(t, ClientFixtures.Pointer("alice").Disabled())
assert.True(t, ClientFixtures.Pointer("deleted").Disabled())
assert.True(t, ptr.Disabled())
} }
func TestClient_Updates(t *testing.T) { func TestClient_Updates(t *testing.T) {
@@ -297,6 +306,31 @@ func TestClient_UpdateLastActive(t *testing.T) {
}) })
} }
func TestClient_Tokens(t *testing.T) {
t.Run("Set", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 0}
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(0)
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(1)
assert.Equal(t, int64(1), m.Tokens())
m.SetTokens(10)
assert.Equal(t, int64(10), m.Tokens())
})
t.Run("Unlimited", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: -1}
assert.Equal(t, int64(-1), m.Tokens())
})
t.Run("One", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 1}
assert.Equal(t, int64(1), m.Tokens())
})
t.Run("Many", func(t *testing.T) {
var m = Client{ClientName: "cs5cpu17n6gj2bbb", AuthTokens: 10}
assert.Equal(t, int64(10), m.Tokens())
})
}
func TestClient_EnforceAuthTokenLimit(t *testing.T) { func TestClient_EnforceAuthTokenLimit(t *testing.T) {
t.Run("EmptyUID", func(t *testing.T) { t.Run("EmptyUID", func(t *testing.T) {
var m = Client{ClientName: "No UUID"} var m = Client{ClientName: "No UUID"}
@@ -379,6 +413,30 @@ func TestClient_Expires(t *testing.T) {
}) })
} }
func TestClient_String(t *testing.T) {
t.Run("Default", func(t *testing.T) {
m := &Client{}
assert.Equal(t, m.String(), "n/a")
})
t.Run("NewClient", func(t *testing.T) {
m := NewClient()
assert.Equal(t, m.String(), "n/a")
})
t.Run("Metrics", func(t *testing.T) {
m := ClientFixtures.Get("metrics")
assert.Equal(t, m.String(), "cs5cpu17n6gj2qo5")
})
t.Run("Alice", func(t *testing.T) {
m := ClientFixtures.Get("alice")
assert.Equal(t, m.String(), "cs5gfen1bgxz7s9i")
})
t.Run("Name", func(t *testing.T) {
m := NewClient()
m.ClientName = "Foo Bar"
assert.Equal(t, m.String(), "Foo Bar")
})
}
func TestClient_UserInfo(t *testing.T) { func TestClient_UserInfo(t *testing.T) {
t.Run("New", func(t *testing.T) { t.Run("New", func(t *testing.T) {
assert.Equal(t, report.NotAssigned, NewClient().UserInfo()) assert.Equal(t, report.NotAssigned, NewClient().UserInfo())

View File

@@ -7,6 +7,8 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/dustin/go-humanize/english"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
@@ -120,21 +122,36 @@ func DeleteExpiredSessions() (deleted int) {
} }
// DeleteClientSessions deletes client sessions above the specified limit. // DeleteClientSessions deletes client sessions above the specified limit.
func DeleteClientSessions(clientUID string, authMethod authn.MethodType, limit int64) (deleted int) { func DeleteClientSessions(client *Client, authMethod authn.MethodType, limit int64) (deleted int) {
if !rnd.IsUID(clientUID, ClientUID) || limit < 0 { if limit < 0 {
return 0
} else if client == nil {
return 0 return 0
} }
found := Sessions{} q := Db()
q := Db().Where("client_uid = ?", clientUID) if client.HasUID() {
q = q.Where("client_uid = ?", client.UID())
} else if client.HasName() {
q = q.Where("client_name = ?", client.Name())
} else {
return 0
}
if client.HasUser() {
q = q.Where("user_uid = ?", client.UserUID)
}
if !authMethod.IsDefault() { if !authMethod.IsDefault() {
q = q.Where("auth_method = ?", authMethod.String()) q = q.Where("auth_method = ?", authMethod.String())
} }
if err := q.Order("created_at DESC").Limit(2147483648).Offset(limit). q = q.Order("created_at DESC").Limit(2147483648).Offset(limit)
Find(&found).Error; err != nil {
found := Sessions{}
if err := q.Find(&found).Error; err != nil {
event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err) event.AuditErr([]string{"failed to fetch client sessions", "%s"}, err)
return deleted return deleted
} }
@@ -255,6 +272,17 @@ func (m *Session) Save() error {
m.Cache() m.Cache()
} }
// Limit the number of sessions that are created with an app password.
if !m.Method().IsSession() {
return nil
} else if !m.Provider().IsApplication() {
return nil
} else if client := m.Client(); client.NoName() || client.Tokens() < 1 {
return nil
} else if deleted := DeleteClientSessions(client, authn.MethodSession, client.Tokens()); deleted > 0 {
event.AuditInfo([]string{m.IP(), "session %s", "deleted %s"}, m.RefID, english.Plural(deleted, "previously created client session", "previously created client sessions"))
}
return nil return nil
} }
@@ -348,11 +376,7 @@ func (m *Session) ClientRole() acl.Role {
// ClientInfo returns the session's client identifier string. // ClientInfo returns the session's client identifier string.
func (m *Session) ClientInfo() string { func (m *Session) ClientInfo() string {
if m.HasClient() { if m.HasClient() {
if uid := m.Client().UID(); uid != "" { return m.Client().String()
return uid
} else if name := m.Client().Name(); name != "" {
return name
}
} else if m.ClientName != "" { } else if m.ClientName != "" {
return m.ClientName return m.ClientName
} }

View File

@@ -107,13 +107,13 @@ func TestDeleteClientSessions(t *testing.T) {
// Create new test client. // Create new test client.
client := NewClient() client := NewClient()
client.ClientUID = "cs5gfen1bgx00000" client.ClientUID = clientUID
// Make sure no sessions exist yet and test missing arguments. // Make sure no sessions exist yet and test missing arguments.
assert.Equal(t, 0, DeleteClientSessions("", "", -1)) assert.Equal(t, 0, DeleteClientSessions(&Client{}, "", -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions("", authn.MethodDefault, 0)) assert.Equal(t, 0, DeleteClientSessions(&Client{}, authn.MethodDefault, 0))
// Create 10 test client sessions. // Create 10 test client sessions.
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
@@ -126,11 +126,11 @@ func TestDeleteClientSessions(t *testing.T) {
} }
// Check if the expected number of sessions is deleted until none are left. // Check if the expected number of sessions is deleted until none are left.
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, -1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOIDC, 1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOIDC, 1))
assert.Equal(t, 9, DeleteClientSessions(clientUID, authn.MethodOAuth2, 1)) assert.Equal(t, 9, DeleteClientSessions(client, authn.MethodOAuth2, 1))
assert.Equal(t, 1, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) assert.Equal(t, 1, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(clientUID, authn.MethodOAuth2, 0)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
} }
func TestSessionStatusUnauthorized(t *testing.T) { func TestSessionStatusUnauthorized(t *testing.T) {

View File

@@ -94,6 +94,9 @@ func AddClientFromCli(ctx *cli.Context) Client {
f.AuthScope = "*" f.AuthScope = "*"
} }
f.AuthExpires = ctx.Int64("expires")
f.AuthTokens = ctx.Int64("tokens")
return f return f
} }

View File

@@ -31,6 +31,11 @@ func (t MethodType) IsDefault() bool {
return t.String() == MethodDefault.String() return t.String() == MethodDefault.String()
} }
// IsSession checks if this is the session method.
func (t MethodType) IsSession() bool {
return t.String() == MethodSession.String()
}
// String returns the provider identifier as a string. // String returns the provider identifier as a string.
func (t MethodType) String() string { func (t MethodType) String() string {
switch t { switch t {

View File

@@ -63,6 +63,11 @@ func (t ProviderType) IsClient() bool {
return list.Contains(ClientProviders, string(t)) return list.Contains(ClientProviders, string(t))
} }
// IsApplication checks if the authentication is provided for an application.
func (t ProviderType) IsApplication() bool {
return t == ProviderApplication
}
// IsDefault checks if this is the default provider. // IsDefault checks if this is the default provider.
func (t ProviderType) IsDefault() bool { func (t ProviderType) IsDefault() bool {
return t.String() == ProviderDefault.String() return t.String() == ProviderDefault.String()

View File

@@ -1,5 +1,5 @@
global: global:
scrape_interval: 15s scrape_interval: 60s
scrape_timeout: 10s scrape_timeout: 10s
scrape_configs: scrape_configs:
- job_name: "photoprism" - job_name: "photoprism"