mirror of
https://github.com/gravitl/netmaker.git
synced 2025-10-25 09:50:24 +08:00
add ctx to DB funcs (#3435)
This commit is contained in:
@@ -110,7 +110,7 @@ func createUserAccessToken(w http.ResponseWriter, r *http.Request) {
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = req.Create()
|
err = req.Create(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logic.ReturnErrorResponse(
|
logic.ReturnErrorResponse(
|
||||||
w,
|
w,
|
||||||
@@ -140,7 +140,7 @@ func getUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
|
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("username is required"), "badrequest"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(), "fetched api access tokens for user "+username)
|
logic.ReturnSuccessResponseWithJson(w, r, (&schema.UserAccessToken{UserName: username}).ListByUser(r.Context()), "fetched api access tokens for user "+username)
|
||||||
}
|
}
|
||||||
|
|
||||||
// @Summary Authenticate a user to retrieve an authorization token
|
// @Summary Authenticate a user to retrieve an authorization token
|
||||||
@@ -161,7 +161,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
a := schema.UserAccessToken{
|
a := schema.UserAccessToken{
|
||||||
ID: id,
|
ID: id,
|
||||||
}
|
}
|
||||||
err := a.Get()
|
err := a.Get(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||||
return
|
return
|
||||||
@@ -188,7 +188,7 @@ func deleteUserAccessTokens(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = (&schema.UserAccessToken{ID: id}).Delete()
|
err = (&schema.UserAccessToken{ID: id}).Delete(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logic.ReturnErrorResponse(
|
logic.ReturnErrorResponse(
|
||||||
w,
|
w,
|
||||||
@@ -754,7 +754,7 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
logic.AddGlobalNetRolesToAdmins(&userchange)
|
logic.AddGlobalNetRolesToAdmins(&userchange)
|
||||||
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
|
if userchange.PlatformRoleID != user.PlatformRoleID || !logic.CompareMaps(user.UserGroups, userchange.UserGroups) {
|
||||||
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens()
|
(&schema.UserAccessToken{UserName: user.UserName}).DeleteAllUserTokens(r.Context())
|
||||||
}
|
}
|
||||||
user, err = logic.UpdateUser(&userchange, user)
|
user, err = logic.UpdateUser(&userchange, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
4
db/db.go
4
db/db.go
@@ -75,6 +75,10 @@ func Middleware(next http.Handler) http.Handler {
|
|||||||
//
|
//
|
||||||
// The function panics, if a connection does not exist.
|
// The function panics, if a connection does not exist.
|
||||||
func FromContext(ctx context.Context) *gorm.DB {
|
func FromContext(ctx context.Context) *gorm.DB {
|
||||||
|
db, ok := ctx.Value(dbCtxKey).(*gorm.DB)
|
||||||
|
if !ok {
|
||||||
|
panic(ErrDBNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
return db
|
return db
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package logic
|
package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"golang.org/x/exp/slog"
|
"golang.org/x/exp/slog"
|
||||||
|
|
||||||
"github.com/gravitl/netmaker/database"
|
"github.com/gravitl/netmaker/database"
|
||||||
|
"github.com/gravitl/netmaker/db"
|
||||||
"github.com/gravitl/netmaker/logger"
|
"github.com/gravitl/netmaker/logger"
|
||||||
"github.com/gravitl/netmaker/models"
|
"github.com/gravitl/netmaker/models"
|
||||||
"github.com/gravitl/netmaker/schema"
|
"github.com/gravitl/netmaker/schema"
|
||||||
@@ -361,7 +363,7 @@ func DeleteUser(user string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
go RemoveUserFromAclPolicy(user)
|
go RemoveUserFromAclPolicy(user)
|
||||||
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens()
|
return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetAuthSecret(secret string) error {
|
func SetAuthSecret(secret string) error {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package logic
|
package logic
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -8,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
|
||||||
|
"github.com/gravitl/netmaker/db"
|
||||||
"github.com/gravitl/netmaker/logger"
|
"github.com/gravitl/netmaker/logger"
|
||||||
"github.com/gravitl/netmaker/models"
|
"github.com/gravitl/netmaker/models"
|
||||||
"github.com/gravitl/netmaker/schema"
|
"github.com/gravitl/netmaker/schema"
|
||||||
@@ -127,13 +129,13 @@ func GetUserNameFromToken(authtoken string) (username string, err error) {
|
|||||||
if jti != "" {
|
if jti != "" {
|
||||||
a := schema.UserAccessToken{ID: jti}
|
a := schema.UserAccessToken{ID: jti}
|
||||||
// check if access token is active
|
// check if access token is active
|
||||||
err := a.Get()
|
err := a.Get(db.WithContext(context.TODO()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New("token revoked")
|
err = errors.New("token revoked")
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
a.LastUsed = time.Now()
|
a.LastUsed = time.Now()
|
||||||
a.Update()
|
a.Update(db.WithContext(context.TODO()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,13 +173,13 @@ func VerifyUserToken(tokenString string) (username string, issuperadmin, isadmin
|
|||||||
if jti != "" {
|
if jti != "" {
|
||||||
a := schema.UserAccessToken{ID: jti}
|
a := schema.UserAccessToken{ID: jti}
|
||||||
// check if access token is active
|
// check if access token is active
|
||||||
err := a.Get()
|
err := a.Get(db.WithContext(context.TODO()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New("token revoked")
|
err = errors.New("token revoked")
|
||||||
return "", false, false, err
|
return "", false, false, err
|
||||||
}
|
}
|
||||||
a.LastUsed = time.Now()
|
a.LastUsed = time.Now()
|
||||||
a.Update()
|
a.Update(db.WithContext(context.TODO()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if token != nil && token.Valid {
|
if token != nil && token.Valid {
|
||||||
|
|||||||
@@ -7,54 +7,46 @@ import (
|
|||||||
"github.com/gravitl/netmaker/db"
|
"github.com/gravitl/netmaker/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// accessTokenTableName - access tokens table
|
|
||||||
const accessTokenTableName = "user_access_tokens"
|
|
||||||
|
|
||||||
// UserAccessToken - token used to access netmaker
|
// UserAccessToken - token used to access netmaker
|
||||||
type UserAccessToken struct {
|
type UserAccessToken struct {
|
||||||
ID string `gorm:"id,primary_key" json:"id"`
|
ID string `gorm:"primaryKey" json:"id"`
|
||||||
Name string `gorm:"name" json:"name"`
|
Name string `json:"name"`
|
||||||
UserName string `gorm:"user_name" json:"user_name"`
|
UserName string `json:"user_name"`
|
||||||
ExpiresAt time.Time `gorm:"expires_at" json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
LastUsed time.Time `gorm:"last_used" json:"last_used"`
|
LastUsed time.Time `json:"last_used"`
|
||||||
CreatedBy string `gorm:"created_by" json:"created_by"`
|
CreatedBy string `json:"created_by"`
|
||||||
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) Table() string {
|
func (a *UserAccessToken) Get(ctx context.Context) error {
|
||||||
return accessTokenTableName
|
return db.FromContext(ctx).Model(&UserAccessToken{}).First(&a).Where("id = ?", a.ID).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) Get() error {
|
func (a *UserAccessToken) Update(ctx context.Context) error {
|
||||||
return db.FromContext(context.TODO()).Table(a.Table()).First(&a).Where("id = ?", a.ID).Error
|
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Updates(&a).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) Update() error {
|
func (a *UserAccessToken) Create(ctx context.Context) error {
|
||||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Updates(&a).Error
|
return db.FromContext(ctx).Model(&UserAccessToken{}).Create(&a).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) Create() error {
|
func (a *UserAccessToken) List(ctx context.Context) (ats []UserAccessToken, err error) {
|
||||||
return db.FromContext(context.TODO()).Table(a.Table()).Create(&a).Error
|
err = db.FromContext(ctx).Model(&UserAccessToken{}).Find(&ats).Error
|
||||||
}
|
|
||||||
|
|
||||||
func (a *UserAccessToken) List() (ats []UserAccessToken, err error) {
|
|
||||||
err = db.FromContext(context.TODO()).Table(a.Table()).Find(&ats).Error
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) ListByUser() (ats []UserAccessToken) {
|
func (a *UserAccessToken) ListByUser(ctx context.Context) (ats []UserAccessToken) {
|
||||||
db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ?", a.UserName).Find(&ats)
|
db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Find(&ats)
|
||||||
if ats == nil {
|
if ats == nil {
|
||||||
ats = []UserAccessToken{}
|
ats = []UserAccessToken{}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) Delete() error {
|
func (a *UserAccessToken) Delete(ctx context.Context) error {
|
||||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("id = ?", a.ID).Delete(&a).Error
|
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("id = ?", a.ID).Delete(&a).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserAccessToken) DeleteAllUserTokens() error {
|
func (a *UserAccessToken) DeleteAllUserTokens(ctx context.Context) error {
|
||||||
return db.FromContext(context.TODO()).Table(a.Table()).Where("user_name = ? OR created_by = ?", a.UserName, a.UserName).Delete(&a).Error
|
return db.FromContext(ctx).Model(&UserAccessToken{}).Where("user_name = ?", a.UserName).Delete(&a).Error
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user