add ctx to DB funcs (#3435)

This commit is contained in:
Abhishek K
2025-04-29 00:22:02 +04:00
committed by GitHub
parent 119ef4e17e
commit 262803c234
5 changed files with 39 additions and 39 deletions

View File

@@ -110,7 +110,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
) )
return return
} }
err = req.Create() err = req.Create(r.Context())
if err != nil { if err != nil {
logic.ReturnErrorResponse( logic.ReturnErrorResponse(
w, w,
@@ -140,7 +140,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
return return
} }
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username) logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
} }
// @Summary Authenticate a user to retrieve an authorization token // @Summary Authenticate a user to retrieve an authorization token
@@ -161,7 +161,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
a := schema.UserAccessToken{ a := schema.UserAccessToken{
ID: id, ID: id,
} }
err := a.Get() err := a.Get(r.Context())
if err != nil { if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest")) logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
return return
@@ -188,7 +188,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
} }
} }
err = (&schema.UserAccessToken{ID: id}).Delete() err = (&schema.UserAccessToken{ID: id}).Delete(r.Context())
if err != nil { if err != nil {
logic.ReturnErrorResponse( logic.ReturnErrorResponse(
w, w,
@@ -754,7 +754,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
} }
logic.AddGlobalNetRolesToAdmins(&userchange) logic.AddGlobalNetRolesToAdmins(&userchange)
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) { if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens() (&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
} }
user, err = logic.UpdateUser(&userchange, user) user, err = logic.UpdateUser(&userchange, user)
if err != nil { if err != nil {

View File

@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
// //
// The function panics, if a connection does not exist. // The function panics, if a connection does not exist.
func FromContext(ctx context.Context) *gorm.DB { func FromContext(ctx context.Context) *gorm.DB {
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
if !ok {
panic(ErrDBNotFound)
}
return db return db
} }

View File

@@ -1,6 +1,7 @@
package logic package logic
import ( import (
"context"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
@@ -12,6 +13,7 @@ import (
"golang.org/x/exp/slog" "golang.org/x/exp/slog"
"github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/schema"
@@ -361,7 +363,7 @@ func DeleteUser(user string) error {
return err return err
} }
go RemoveUserFromAclPolicy(user) go RemoveUserFromAclPolicy(user)
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens() return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
} }
func SetAuthSecret(secret string) error { func SetAuthSecret(secret string) error {

View File

@@ -1,6 +1,7 @@
package logic package logic
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@@ -8,6 +9,7 @@ import (
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/schema"
@@ -127,13 +129,13 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
if jti != "" { if jti != "" {
a := schema.UserAccessToken{ID: jti} a := schema.UserAccessToken{ID: jti}
// check if access token is active // check if access token is active
err := a.Get() err := a.Get(db.WithContext(context.TODO()))
if err != nil { if err != nil {
err = errors.New("token revoked") err = errors.New("token revoked")
return "", err return "", err
} }
a.LastUsed = time.Now() a.LastUsed = time.Now()
a.Update() a.Update(db.WithContext(context.TODO()))
} }
} }
@@ -171,13 +173,13 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
if jti != "" { if jti != "" {
a := schema.UserAccessToken{ID: jti} a := schema.UserAccessToken{ID: jti}
// check if access token is active // check if access token is active
err := a.Get() err := a.Get(db.WithContext(context.TODO()))
if err != nil { if err != nil {
err = errors.New("token revoked") err = errors.New("token revoked")
return "", false, false, err return "", false, false, err
} }
a.LastUsed = time.Now() a.LastUsed = time.Now()
a.Update() a.Update(db.WithContext(context.TODO()))
} }
} }
if token != nil && token.Valid { if token != nil && token.Valid {

View File

@@ -7,54 +7,46 @@ import (
"github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/db"
) )
// accessTokenTableName - access tokens table
const accessTokenTableName = "user_access_tokens"
// UserAccessToken - token used to access netmaker // UserAccessToken - token used to access netmaker
type UserAccessToken struct { type UserAccessToken struct {
ID string `gorm:"id,primary_key" json:"id"` ID string `gorm:"primaryKey" json:"id"`
Name string `gorm:"name" json:"name"` Name string `json:"name"`
UserName string `gorm:"user_name" json:"user_name"` UserName string `json:"user_name"`
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
LastUsed time.Time `gorm:"last_used" json:"last_used"` LastUsed time.Time `json:"last_used"`
CreatedBy string `gorm:"created_by" json:"created_by"` CreatedBy string `json:"created_by"`
CreatedAt time.Time `gorm:"created_at" json:"created_at"` CreatedAt time.Time `json:"created_at"`
} }
func (a *UserAccessToken) Table() string { func (a *UserAccessToken) Get(ctx context.Context) error {
return accessTokenTableName return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
} }
func (a *UserAccessToken) Get() error { func (a *UserAccessToken) Update(ctx context.Context) error {
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
} }
func (a *UserAccessToken) Update() error { func (a *UserAccessToken) Create(ctx context.Context) error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
} }
func (a *UserAccessToken) Create() error { func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
}
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
return return
} }
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) { func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats) db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
if ats == nil { if ats == nil {
ats = []UserAccessToken{} ats = []UserAccessToken{}
} }
return return
} }
func (a *UserAccessToken) Delete() error { func (a *UserAccessToken) Delete(ctx context.Context) error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
} }
func (a *UserAccessToken) DeleteAllUserTokens() error { func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
} }