add ctx to DB funcs (#3435)

This commit is contained in:
Abhishek K
2025-04-29 00:22:02 +04:00
committed by GitHub
parent 119ef4e17e
commit 262803c234
5 changed files with 39 additions and 39 deletions

View File

@@ -110,7 +110,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
)
return
}
err = req.Create()
err = req.Create(r.Context())
if err != nil {
logic.ReturnErrorResponse(
w,
@@ -140,7 +140,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
return
}
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
}
// @Summary Authenticate a user to retrieve an authorization token
@@ -161,7 +161,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
a := schema.UserAccessToken{
ID: id,
}
err := a.Get()
err := a.Get(r.Context())
if err != nil {
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
return
@@ -188,7 +188,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
}
}
err = (&schema.UserAccessToken{ID: id}).Delete()
err = (&schema.UserAccessToken{ID: id}).Delete(r.Context())
if err != nil {
logic.ReturnErrorResponse(
w,
@@ -754,7 +754,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
}
logic.AddGlobalNetRolesToAdmins(&userchange)
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
}
user, err = logic.UpdateUser(&userchange, user)
if err != nil {

View File

@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
//
// The function panics, if a connection does not exist.
func FromContext(ctx context.Context) *gorm.DB {
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
if !ok {
panic(ErrDBNotFound)
}
return db
}

View File

@@ -1,6 +1,7 @@
package logic
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
@@ -12,6 +13,7 @@ import (
"golang.org/x/exp/slog"
"github.com/gravitl/netmaker/database"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
@@ -361,7 +363,7 @@ func DeleteUser(user string) error {
return err
}
go RemoveUserFromAclPolicy(user)
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens()
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
}
func SetAuthSecret(secret string) error {

View File

@@ -1,6 +1,7 @@
package logic
import (
"context"
"errors"
"fmt"
"strings"
@@ -8,6 +9,7 @@ import (
"github.com/golang-jwt/jwt/v4"
"github.com/gravitl/netmaker/db"
"github.com/gravitl/netmaker/logger"
"github.com/gravitl/netmaker/models"
"github.com/gravitl/netmaker/schema"
@@ -127,13 +129,13 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
if jti != "" {
a := schema.UserAccessToken{ID: jti}
// check if access token is active
err := a.Get()
err := a.Get(db.WithContext(context.TODO()))
if err != nil {
err = errors.New("token revoked")
return "", err
}
a.LastUsed = time.Now()
a.Update()
a.Update(db.WithContext(context.TODO()))
}
}
@@ -171,13 +173,13 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
if jti != "" {
a := schema.UserAccessToken{ID: jti}
// check if access token is active
err := a.Get()
err := a.Get(db.WithContext(context.TODO()))
if err != nil {
err = errors.New("token revoked")
return "", false, false, err
}
a.LastUsed = time.Now()
a.Update()
a.Update(db.WithContext(context.TODO()))
}
}
if token != nil && token.Valid {

View File

@@ -7,54 +7,46 @@ import (
"github.com/gravitl/netmaker/db"
)
// accessTokenTableName - access tokens table
const accessTokenTableName = "user_access_tokens"
// UserAccessToken - token used to access netmaker
type UserAccessToken struct {
ID string `gorm:"id,primary_key" json:"id"`
Name string `gorm:"name" json:"name"`
UserName string `gorm:"user_name" json:"user_name"`
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
LastUsed time.Time `gorm:"last_used" json:"last_used"`
CreatedBy string `gorm:"created_by" json:"created_by"`
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
ID string `gorm:"primaryKey" json:"id"`
Name string `json:"name"`
UserName string `json:"user_name"`
ExpiresAt time.Time `json:"expires_at"`
LastUsed time.Time `json:"last_used"`
CreatedBy string `json:"created_by"`
CreatedAt time.Time `json:"created_at"`
}
func (a *UserAccessToken) Table() string {
return accessTokenTableName
func (a *UserAccessToken) Get(ctx context.Context) error {
return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
}
func (a *UserAccessToken) Get() error {
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
func (a *UserAccessToken) Update(ctx context.Context) error {
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
}
func (a *UserAccessToken) Update() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
func (a *UserAccessToken) Create(ctx context.Context) error {
return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
}
func (a *UserAccessToken) Create() error {
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
}
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
return
}
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
if ats == nil {
ats = []UserAccessToken{}
}
return
}
func (a *UserAccessToken) Delete() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
func (a *UserAccessToken) Delete(ctx context.Context) error {
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
}
func (a *UserAccessToken) DeleteAllUserTokens() error {
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
}