mirror of
https://github.com/gravitl/netmaker.git
synced 2025-11-02 13:04:11 +08:00
add ctx to DB funcs (#3435)
This commit is contained in:
@@ -110,7 +110,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
return
|
||||
}
|
||||
err = req.Create()
|
||||
err = req.Create(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
@@ -140,7 +140,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
|
||||
}
|
||||
|
||||
// @Summary Authenticate a user to retrieve an authorization token
|
||||
@@ -161,7 +161,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
||||
a := schema.UserAccessToken{
|
||||
ID: id,
|
||||
}
|
||||
err := a.Get()
|
||||
err := a.Get(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||
return
|
||||
@@ -188,7 +188,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = (&schema.UserAccessToken{ID: id}).Delete()
|
||||
err = (&schema.UserAccessToken{ID: id}).Delete(r.Context())
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
@@ -754,7 +754,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
logic.AddGlobalNetRolesToAdmins(&userchange)
|
||||
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
|
||||
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
|
||||
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
|
||||
}
|
||||
user, err = logic.UpdateUser(&userchange, user)
|
||||
if err != nil {
|
||||
|
||||
4
db/db.go
4
db/db.go
@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
|
||||
//
|
||||
// The function panics, if a connection does not exist.
|
||||
func FromContext(ctx context.Context) *gorm.DB {
|
||||
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
|
||||
if !ok {
|
||||
panic(ErrDBNotFound)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"golang.org/x/exp/slog"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
@@ -361,7 +363,7 @@ func DeleteUser(user string) error {
|
||||
return err
|
||||
}
|
||||
go RemoveUserFromAclPolicy(user)
|
||||
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens()
|
||||
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
|
||||
}
|
||||
|
||||
func SetAuthSecret(secret string) error {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
@@ -8,6 +9,7 @@ import (
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
@@ -127,13 +129,13 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
|
||||
if jti != "" {
|
||||
a := schema.UserAccessToken{ID: jti}
|
||||
// check if access token is active
|
||||
err := a.Get()
|
||||
err := a.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
err = errors.New("token revoked")
|
||||
return "", err
|
||||
}
|
||||
a.LastUsed = time.Now()
|
||||
a.Update()
|
||||
a.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,13 +173,13 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
|
||||
if jti != "" {
|
||||
a := schema.UserAccessToken{ID: jti}
|
||||
// check if access token is active
|
||||
err := a.Get()
|
||||
err := a.Get(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
err = errors.New("token revoked")
|
||||
return "", false, false, err
|
||||
}
|
||||
a.LastUsed = time.Now()
|
||||
a.Update()
|
||||
a.Update(db.WithContext(context.TODO()))
|
||||
}
|
||||
}
|
||||
if token != nil && token.Valid {
|
||||
|
||||
@@ -7,54 +7,46 @@ import (
|
||||
"github.com/gravitl/netmaker/db"
|
||||
)
|
||||
|
||||
// accessTokenTableName - access tokens table
|
||||
const accessTokenTableName = "user_access_tokens"
|
||||
|
||||
// UserAccessToken - token used to access netmaker
|
||||
type UserAccessToken struct {
|
||||
ID string `gorm:"id,primary_key" json:"id"`
|
||||
Name string `gorm:"name" json:"name"`
|
||||
UserName string `gorm:"user_name" json:"user_name"`
|
||||
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
|
||||
LastUsed time.Time `gorm:"last_used" json:"last_used"`
|
||||
CreatedBy string `gorm:"created_by" json:"created_by"`
|
||||
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `json:"name"`
|
||||
UserName string `json:"user_name"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
LastUsed time.Time `json:"last_used"`
|
||||
CreatedBy string `json:"created_by"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Table() string {
|
||||
return accessTokenTableName
|
||||
func (a *UserAccessToken) Get(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Get() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
|
||||
func (a *UserAccessToken) Update(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Update() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
|
||||
func (a *UserAccessToken) Create(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Create() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
|
||||
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
|
||||
func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
|
||||
err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
|
||||
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
|
||||
func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
|
||||
db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
|
||||
if ats == nil {
|
||||
ats = []UserAccessToken{}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) Delete() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
|
||||
func (a *UserAccessToken) Delete(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
|
||||
}
|
||||
|
||||
func (a *UserAccessToken) DeleteAllUserTokens() error {
|
||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
|
||||
|
||||
func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user