refactor: split query and mutation

This commit is contained in:
VaalaCat
2025-12-06 08:10:03 +00:00
parent 6742264845
commit a578ae7f0c
50 changed files with 652 additions and 256 deletions

View File

@@ -41,7 +41,7 @@ func LoginHandler(ctx *app.Context, req *pb.LoginRequest) (*pb.LoginResponse, er
userEntity.Role = defs.UserRole_Admin
dao.NewQuery(ctx).AdminUpdateUser(&models.UserEntity{
dao.NewMutation(ctx).AdminUpdateUser(&models.UserEntity{
UserID: user.GetUserID(),
}, userEntity.UserEntity)
}

View File

@@ -56,7 +56,7 @@ func RegisterHandler(c *app.Context, req *pb.RegisterRequest) (*pb.RegisterRespo
newUser.Role = defs.UserRole_Admin
}
err = dao.NewQuery(c).CreateUser(newUser)
err = dao.NewMutation(c).CreateUser(newUser)
if err != nil {
return &pb.RegisterResponse{
Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: err.Error()},

View File

@@ -31,7 +31,7 @@ func InitClientHandler(c *app.Context, req *pb.InitClientRequest) (*pb.InitClien
logger.Logger(c).Infof("start to init client, request:[%s], transformed global client id:[%s]", req.String(), globalClientID)
if err := dao.NewQuery(c).CreateClient(userInfo,
if err := dao.NewMutation(c).CreateClient(userInfo,
&models.ClientEntity{
ClientID: globalClientID,
TenantID: userInfo.GetTenantID(),

View File

@@ -29,11 +29,11 @@ func DeleteClientHandler(ctx *app.Context, req *pb.DeleteClientRequest) (*pb.Del
}, nil
}
if err := dao.NewQuery(ctx).DeleteClient(userInfo, clientID); err != nil {
if err := dao.NewMutation(ctx).DeleteClient(userInfo, clientID); err != nil {
return nil, err
}
if err := dao.NewQuery(ctx).DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, clientID); err != nil {
if err := dao.NewMutation(ctx).DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, clientID); err != nil {
return nil, err
}

View File

@@ -31,7 +31,7 @@ func RemoveFrpcHandler(c *app.Context, req *pb.RemoveFRPCRequest) (*pb.RemoveFRP
return nil, err
}
err = dao.NewQuery(c).DeleteClient(userInfo, clientID)
err = dao.NewMutation(c).DeleteClient(userInfo, clientID)
if err != nil {
logger.Logger(context.Background()).WithError(err).Errorf("cannot delete client, id: [%s]", clientID)
return nil, err

View File

@@ -42,6 +42,7 @@ func ValidateClientRequest(ctx *app.Context, req ValidateableClientRequest) (*mo
func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.ClientEntity) (*models.ClientEntity, error) {
userInfo := common.GetUserInfo(c)
m := dao.NewMutation(c)
var clientID = clientEntity.ClientID
var childClient *models.ClientEntity
@@ -53,7 +54,7 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl
return nil, err
}
if err := dao.NewQuery(c).RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: childClient}); err != nil {
if err := m.RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: childClient}); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot rebuild proxy config from client, id: [%s]", childClient.ClientID)
return nil, err
}
@@ -62,12 +63,12 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl
clientEntity.IsShadow = true
clientEntity.ConfigContent = nil
clientEntity.ServerID = ""
if err := dao.NewQuery(c).UpdateClient(userInfo, clientEntity); err != nil {
if err := m.UpdateClient(userInfo, clientEntity); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", clientID)
return nil, err
}
if err := dao.NewQuery(c).DeleteProxyConfigsByClientID(userInfo, clientID); err != nil {
if err := m.DeleteProxyConfigsByClientID(userInfo, clientID); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot delete proxy configs, id: [%s]", clientID)
return nil, err
}
@@ -79,13 +80,15 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl
// bool 表示是否新建
func ChildClientForServer(c *app.Context, serverID string, clientEntity *models.ClientEntity) (*models.ClientEntity, bool, error) {
userInfo := common.GetUserInfo(c)
q := dao.NewQuery(c)
m := dao.NewMutation(c)
originClientID := clientEntity.ClientID
if len(clientEntity.OriginClientID) != 0 {
originClientID = clientEntity.OriginClientID
}
existClient, err := dao.NewQuery(c).GetClientByFilter(userInfo, &models.ClientEntity{
existClient, err := q.GetClientByFilter(userInfo, &models.ClientEntity{
ServerID: serverID,
OriginClientID: originClientID,
}, lo.ToPtr(false))
@@ -93,7 +96,7 @@ func ChildClientForServer(c *app.Context, serverID string, clientEntity *models.
return existClient, false, nil
}
shadowCount, err := dao.NewQuery(c).CountClientsInShadow(userInfo, originClientID)
shadowCount, err := q.CountClientsInShadow(userInfo, originClientID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot count shadow clients, id: [%s]", originClientID)
return nil, false, err
@@ -109,7 +112,7 @@ func ChildClientForServer(c *app.Context, serverID string, clientEntity *models.
copiedClient.OriginClientID = originClientID
copiedClient.IsShadow = false
copiedClient.Stopped = false
if err := dao.NewQuery(c).CreateClient(userInfo, copiedClient); err != nil {
if err := m.CreateClient(userInfo, copiedClient); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot create child client, id: [%s]", copiedClient.ClientID)
return nil, false, err
}

View File

@@ -21,7 +21,7 @@ func RPCPullConfig(ctx *app.Context, req *pb.PullClientConfigReq) (*pb.PullClien
return nil, err
}
if err := dao.NewQuery(ctx).AdminUpdateClientLastSeen(cli.ClientID); err != nil {
if err := dao.NewMutation(ctx).AdminUpdateClientLastSeen(cli.ClientID); err != nil {
logger.Logger(ctx).WithError(err).Errorf("update client last_seen_at time error, req:[%s] clientId:[%s]",
req.String(), cli.ClientID)
}

View File

@@ -38,7 +38,7 @@ func StartFRPCHandler(ctx *app.Context, req *pb.StartFRPCRequest) (*pb.StartFRPC
client.Stopped = false
if err = dao.NewQuery(ctx).UpdateClient(userInfo, client); err != nil {
if err = dao.NewMutation(ctx).UpdateClient(userInfo, client); err != nil {
return nil, err
}

View File

@@ -38,7 +38,7 @@ func StopFRPCHandler(ctx *app.Context, req *pb.StopFRPCRequest) (*pb.StopFRPCRes
client.Stopped = true
if err = dao.NewQuery(ctx).UpdateClient(userInfo, client); err != nil {
if err = dao.NewMutation(ctx).UpdateClient(userInfo, client); err != nil {
return nil, err
}

View File

@@ -32,7 +32,7 @@ func SyncTunnel(ctx *app.Context, userInfo models.UserInfo) error {
return
}
if err := dao.NewQuery(ctx).UpdateClient(userInfo, cli); err != nil {
if err := dao.NewMutation(ctx).UpdateClient(userInfo, cli); err != nil {
logger.Logger(context.Background()).WithError(err).Errorf("cannot update client, id: [%s]", cli.ClientID)
return
}

View File

@@ -27,6 +27,8 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC
reqClientID = req.GetClientId() // may be shadow or child
userInfo = common.GetUserInfo(c)
)
q := dao.NewQuery(c)
m := dao.NewMutation(c)
cliCfg, err := utils.LoadClientConfigNormal(content, true)
if err != nil {
@@ -36,7 +38,7 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC
}, err
}
cliRecord, err := dao.NewQuery(c).GetClientByClientID(userInfo, reqClientID)
cliRecord, err := q.GetClientByClientID(userInfo, reqClientID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", reqClientID)
return &pb.UpdateFRPCResponse{
@@ -72,7 +74,7 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC
}
}
srv, err := dao.NewQuery(c).GetServerByServerID(userInfo, req.GetServerId())
srv, err := q.GetServerByServerID(userInfo, req.GetServerId())
if err != nil || srv == nil || len(srv.ServerIP) == 0 || len(srv.ConfigContent) == 0 {
logger.Logger(c).WithError(err).Errorf("cannot get server, server is not prepared, id: [%s]", req.GetServerId())
return &pb.UpdateFRPCResponse{
@@ -169,12 +171,12 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC
cli.Comment = req.GetComment()
}
if err := dao.NewQuery(c).UpdateClient(userInfo, cli); err != nil {
if err := m.UpdateClient(userInfo, cli); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", cli.ClientID)
return nil, err
}
if err := dao.NewQuery(c).RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: cli}); err != nil {
if err := m.RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: cli}); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot rebuild proxy config from client, id: [%s]", cli.ClientID)
return nil, err
}

View File

@@ -20,12 +20,14 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De
serverID = req.GetServerId()
proxyName = req.GetName()
)
q := dao.NewQuery(c)
m := dao.NewMutation(c)
if len(clientID) == 0 || len(serverID) == 0 || len(proxyName) == 0 {
return nil, fmt.Errorf("request invalid")
}
cli, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientID)
cli, err := q.GetClientByClientID(userInfo, clientID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", clientID)
return nil, err
@@ -49,7 +51,7 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De
return nil, err
}
if err := dao.NewQuery(c).UpdateClient(userInfo, cli.ClientEntity); err != nil {
if err := m.UpdateClient(userInfo, cli.ClientEntity); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", clientID)
return nil, err
}
@@ -72,7 +74,7 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De
return nil, err
}
if err := dao.NewQuery(c).DeleteProxyConfig(userInfo, clientID, proxyName); err != nil {
if err := m.DeleteProxyConfig(userInfo, clientID, proxyName); err != nil {
logger.Logger(c).WithError(err).Errorf("cannot delete proxy config, id: [%s]", clientID)
return nil, err
}

View File

@@ -45,7 +45,7 @@ func StartProxy(ctx *app.Context, req *pb.StartProxyRequest) (*pb.StartProxyResp
// 1. 更新proxy状态
proxyConfig.Stopped = false
err = dao.NewQuery(ctx).UpdateProxyConfig(userInfo, proxyConfig)
err = dao.NewMutation(ctx).UpdateProxyConfig(userInfo, proxyConfig)
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot update proxy config, client: [%s], server: [%s], proxy name: [%s]", clientID, serverID, proxyName)
return nil, err

View File

@@ -44,7 +44,7 @@ func StopProxy(ctx *app.Context, req *pb.StopProxyRequest) (*pb.StopProxyRespons
// 1. 更新proxy状态
proxyConfig.Stopped = true
err = dao.NewQuery(ctx).UpdateProxyConfig(userInfo, proxyConfig)
err = dao.NewMutation(ctx).UpdateProxyConfig(userInfo, proxyConfig)
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot update proxy config, client: [%s], server: [%s], proxy name: [%s]", clientID, serverID, proxyName)
return nil, err

View File

@@ -41,7 +41,7 @@ func CollectDailyStats(appInstance app.Application) error {
}
})
if err := dao.NewQuery(ctx).AdminMSaveTodyStats(tx, proxyDailyStats); err != nil {
if err := dao.NewMutation(ctx).AdminMSaveTodyStats(tx, proxyDailyStats); err != nil {
logger.Logger(context.Background()).WithError(err).Error("CollectDailyStats cannot save stats")
return err
}

View File

@@ -27,8 +27,10 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up
clientID = req.GetClientId()
serverID = req.GetServerId()
)
q := dao.NewQuery(c)
m := dao.NewMutation(c)
cli, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientID)
cli, err := q.GetClientByClientID(userInfo, clientID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", clientID)
return nil, err
@@ -38,7 +40,7 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up
if clientEntity.ServerID != serverID {
logger.Logger(c).Errorf("client and server not match, find or create client, client: [%s], server: [%s]", clientID, serverID)
originClient, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientEntity.OriginClientID)
originClient, err := q.GetClientByClientID(userInfo, clientEntity.OriginClientID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get origin client, id: [%s]", clientEntity.OriginClientID)
return nil, err
@@ -51,7 +53,7 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up
}
}
_, err = dao.NewQuery(c).GetServerByServerID(userInfo, serverID)
_, err = q.GetServerByServerID(userInfo, serverID)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get server, id: [%s]", serverID)
return nil, err
@@ -84,13 +86,13 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up
return nil, err
}
oldProxyCfg, err := dao.NewQuery(c).GetProxyConfigByOriginClientIDAndName(userInfo, clientID, proxyCfg.Name)
oldProxyCfg, err := q.GetProxyConfigByOriginClientIDAndName(userInfo, clientID, proxyCfg.Name)
if err != nil {
logger.Logger(c).WithError(err).Errorf("cannot get proxy config, id: [%s]", clientID)
return nil, err
}
if dao.NewQuery(c).UpdateProxyConfig(userInfo, &models.ProxyConfig{
if m.UpdateProxyConfig(userInfo, &models.ProxyConfig{
Model: oldProxyCfg.Model,
ProxyConfigEntity: proxyCfg,
}) != nil {

View File

@@ -31,7 +31,7 @@ func InitServerHandler(c *app.Context, req *pb.InitServerRequest) (*pb.InitServe
globalServerID := app.GlobalClientID(userInfo.GetUserName(), "s", userServerID)
if err := dao.NewQuery(c).CreateServer(userInfo,
if err := dao.NewMutation(c).CreateServer(userInfo,
&models.ServerEntity{
ServerID: globalServerID,
TenantID: userInfo.GetTenantID(),

View File

@@ -25,7 +25,7 @@ func DeleteServerHandler(c *app.Context, req *pb.DeleteServerRequest) (*pb.Delet
}, nil
}
if err := dao.NewQuery(c).DeleteServer(userInfo, userServerID); err != nil {
if err := dao.NewMutation(c).DeleteServer(userInfo, userServerID); err != nil {
return nil, err
}

View File

@@ -25,7 +25,7 @@ func RemoveFrpsHandler(c *app.Context, req *pb.RemoveFRPSRequest) (*pb.RemoveFRP
return nil, err
}
if err = dao.NewQuery(c).DeleteServer(userInfo, serverID); err != nil {
if err = dao.NewMutation(c).DeleteServer(userInfo, serverID); err != nil {
logger.Logger(context.Background()).WithError(err).Errorf("cannot delete server, id: [%s]", serverID)
return nil, err
}

View File

@@ -15,7 +15,7 @@ func PushProxyInfo(ctx *app.Context, req *pb.PushProxyInfoReq) (*pb.PushProxyInf
return nil, err
}
if err = dao.NewQuery(ctx).AdminUpdateProxyStats(srv, req.GetProxyInfos()); err != nil {
if err = dao.NewMutation(ctx).AdminUpdateProxyStats(srv, req.GetProxyInfos()); err != nil {
return nil, err
}
return &pb.PushProxyInfoResp{

View File

@@ -27,7 +27,8 @@ func UpdateFrpsHander(c *app.Context, req *pb.UpdateFRPSRequest) (*pb.UpdateFRPS
return nil, fmt.Errorf("request invalid")
}
srv, err := dao.NewQuery(c).GetServerByServerID(userInfo, serverID)
q := dao.NewQuery(c)
srv, err := q.GetServerByServerID(userInfo, serverID)
if srv == nil || err != nil {
logger.Logger(context.Background()).WithError(err).Errorf("cannot get server, id: [%s]", serverID)
return nil, err
@@ -55,7 +56,7 @@ func UpdateFrpsHander(c *app.Context, req *pb.UpdateFRPSRequest) (*pb.UpdateFRPS
srv.FrpsUrls = req.GetFrpsUrls()
}
if err := dao.NewQuery(c).UpdateServer(userInfo, srv); err != nil {
if err := dao.NewMutation(c).UpdateServer(userInfo, srv); err != nil {
logger.Logger(context.Background()).WithError(err).Errorf("cannot update server, id: [%s]", serverID)
return nil, err
}

View File

@@ -47,7 +47,7 @@ func UpdateUserInfoHander(c *app.Context, req *pb.UpdateUserInfoRequest) (*pb.Up
newUserEntity.Token = newUserInfo.GetToken()
}
if err := dao.NewQuery(c).UpdateUser(userInfo, newUserEntity); err != nil {
if err := dao.NewMutation(c).UpdateUser(userInfo, newUserEntity); err != nil {
return &pb.UpdateUserInfoResponse{
Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: err.Error()},
}, err

View File

@@ -21,7 +21,7 @@ func CreateEndpoint(ctx *app.Context, req *pb.CreateEndpointRequest) (*pb.Create
}
entity := &models.EndpointEntity{Host: e.GetHost(), Port: e.GetPort(), ClientID: e.GetClientId(), Type: e.GetType(), Uri: e.GetUri()}
if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, entity); err != nil {
if err := dao.NewMutation(ctx).CreateEndpoint(userInfo, entity); err != nil {
return nil, err
}
return &pb.CreateEndpointResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, Endpoint: &pb.Endpoint{Id: 0, Host: entity.Host, Port: entity.Port, ClientId: entity.ClientID}}, nil

View File

@@ -18,7 +18,7 @@ func DeleteEndpoint(ctx *app.Context, req *pb.DeleteEndpointRequest) (*pb.Delete
if id == 0 {
return nil, errors.New("invalid id")
}
if err := dao.NewQuery(ctx).DeleteEndpoint(userInfo, id); err != nil {
if err := dao.NewMutation(ctx).DeleteEndpoint(userInfo, id); err != nil {
return nil, err
}
return &pb.DeleteEndpointResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil

View File

@@ -19,7 +19,10 @@ func UpdateEndpoint(ctx *app.Context, req *pb.UpdateEndpointRequest) (*pb.Update
return nil, errors.New("invalid endpoint params")
}
oldEndpoint, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(e.GetId()))
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
oldEndpoint, err := q.GetEndpointByID(userInfo, uint(e.GetId()))
if err != nil {
return nil, err
}
@@ -37,7 +40,7 @@ func UpdateEndpoint(ctx *app.Context, req *pb.UpdateEndpointRequest) (*pb.Update
oldEndpoint.Type = e.GetType()
}
if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(e.GetId()), oldEndpoint.EndpointEntity); err != nil {
if err := m.UpdateEndpoint(userInfo, uint(e.GetId()), oldEndpoint.EndpointEntity); err != nil {
return nil, err
}

View File

@@ -23,6 +23,8 @@ func CreateWireGuardLink(ctx *app.Context, req *pb.CreateWireGuardLinkRequest) (
}
// 校验两端属于同一 network
q := dao.NewQuery(ctx)
mut := dao.NewMutation(ctx)
from, err := q.GetWireGuardByID(userInfo, uint(l.GetFromWireguardId()))
if err != nil {
return nil, err
@@ -43,7 +45,7 @@ func CreateWireGuardLink(ctx *app.Context, req *pb.CreateWireGuardLinkRequest) (
reverse.NetworkID = from.NetworkID
reverse.ToEndpointID = 0
if err := q.CreateWireGuardLinks(userInfo, m, reverse); err != nil {
if err := mut.CreateWireGuardLinks(userInfo, m, reverse); err != nil {
return nil, err
}
return &pb.CreateWireGuardLinkResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, WireguardLink: m.ToPB()}, nil
@@ -59,6 +61,8 @@ func UpdateWireGuardLink(ctx *app.Context, req *pb.UpdateWireGuardLinkRequest) (
return nil, errors.New("invalid link params")
}
q := dao.NewQuery(ctx)
mut := dao.NewMutation(ctx)
m, err := q.GetWireGuardLinkByID(userInfo, uint(l.GetId()))
if err != nil {
return nil, err
@@ -73,7 +77,7 @@ func UpdateWireGuardLink(ctx *app.Context, req *pb.UpdateWireGuardLinkRequest) (
m.ToEndpointID = uint(l.GetToEndpoint().GetId())
}
if err := q.UpdateWireGuardLink(userInfo, uint(l.GetId()), m); err != nil {
if err := mut.UpdateWireGuardLink(userInfo, uint(l.GetId()), m); err != nil {
return nil, err
}
return &pb.UpdateWireGuardLinkResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, WireguardLink: m.ToPB()}, nil
@@ -89,21 +93,24 @@ func DeleteWireGuardLink(ctx *app.Context, req *pb.DeleteWireGuardLinkRequest) (
return nil, errors.New("invalid id")
}
link, err := dao.NewQuery(ctx).GetWireGuardLinkByID(userInfo, id)
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
link, err := q.GetWireGuardLinkByID(userInfo, id)
if err != nil {
return nil, err
}
rev, err := dao.NewQuery(ctx).GetWireGuardLinkByClientIDs(userInfo, link.ToWireGuardID, link.FromWireGuardID)
rev, err := q.GetWireGuardLinkByClientIDs(userInfo, link.ToWireGuardID, link.FromWireGuardID)
if err != nil {
return nil, err
}
if err := dao.NewQuery(ctx).DeleteWireGuardLink(userInfo, uint(link.ID)); err != nil {
if err := m.DeleteWireGuardLink(userInfo, uint(link.ID)); err != nil {
return nil, err
}
if err := dao.NewQuery(ctx).DeleteWireGuardLink(userInfo, uint(rev.ID)); err != nil {
if err := m.DeleteWireGuardLink(userInfo, uint(rev.ID)); err != nil {
return nil, err
}

View File

@@ -36,7 +36,7 @@ func CreateNetwork(ctx *app.Context, req *pb.CreateNetworkRequest) (*pb.CreateNe
ACL: models.JSON[*pb.AclConfig]{Data: req.GetNetwork().GetAcl()},
}
if err := dao.NewQuery(ctx).CreateNetwork(userInfo, entity); err != nil {
if err := dao.NewMutation(ctx).CreateNetwork(userInfo, entity); err != nil {
log.WithError(err).Errorf("create network error")
return nil, err
}

View File

@@ -18,7 +18,7 @@ func DeleteNetwork(ctx *app.Context, req *pb.DeleteNetworkRequest) (*pb.DeleteNe
if id == 0 {
return nil, errors.New("invalid id")
}
if err := dao.NewQuery(ctx).DeleteNetwork(userInfo, id); err != nil {
if err := dao.NewMutation(ctx).DeleteNetwork(userInfo, id); err != nil {
return nil, err
}
return &pb.DeleteNetworkResponse{

View File

@@ -20,7 +20,7 @@ func UpdateNetwork(ctx *app.Context, req *pb.UpdateNetworkRequest) (*pb.UpdateNe
return nil, errors.New("invalid network")
}
entity := &models.NetworkEntity{Name: n.GetName(), CIDR: n.GetCidr(), ACL: models.JSON[*pb.AclConfig]{Data: n.GetAcl()}}
if err := dao.NewQuery(ctx).UpdateNetwork(userInfo, uint(n.GetId()), entity); err != nil {
if err := dao.NewMutation(ctx).UpdateNetwork(userInfo, uint(n.GetId()), entity); err != nil {
return nil, err
}

View File

@@ -28,14 +28,16 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea
if cfg == nil || len(cfg.GetClientId()) == 0 || len(cfg.GetInterfaceName()) == 0 || len(cfg.GetLocalAddress()) == 0 {
return nil, errors.New("invalid wireguard config")
}
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
ips, err := dao.NewQuery(ctx).GetWireGuardLocalAddressesByNetworkID(userInfo, uint(cfg.GetNetworkId()))
ips, err := q.GetWireGuardLocalAddressesByNetworkID(userInfo, uint(cfg.GetNetworkId()))
if err != nil {
log.WithError(err).Errorf("get wireguard local addresses by network id failed")
return nil, err
}
network, err := dao.NewQuery(ctx).GetNetworkByID(userInfo, uint(cfg.GetNetworkId()))
network, err := q.GetNetworkByID(userInfo, uint(cfg.GetNetworkId()))
if err != nil {
log.WithError(err).Errorf("get network by id failed")
return nil, err
@@ -72,7 +74,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea
log.Debugf("create wireguard with config: %+v", wgModel)
if err := dao.NewQuery(ctx).CreateWireGuard(userInfo, wgModel); err != nil {
if err := m.CreateWireGuard(userInfo, wgModel); err != nil {
return nil, err
}
@@ -83,7 +85,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea
}
if ep.GetId() > 0 {
// 复用现有 endpoint要求归属同一 client
exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(ep.GetId()))
exist, err := q.GetEndpointByID(userInfo, uint(ep.GetId()))
if err != nil {
return nil, err
}
@@ -92,7 +94,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea
}
exist.WireGuardID = wgModel.ID
if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil {
if err := m.UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil {
return nil, err
}
} else {
@@ -101,19 +103,19 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea
newEp.FromPB(ep)
newEp.ClientID = cfg.GetClientId()
newEp.WireGuardID = wgModel.ID
if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, newEp.EndpointEntity); err != nil {
if err := m.CreateEndpoint(userInfo, newEp.EndpointEntity); err != nil {
return nil, err
}
}
}
go func() {
peers, err := dao.NewQuery(ctx).GetWireGuardsByNetworkID(userInfo, uint(cfg.GetNetworkId()))
peers, err := q.GetWireGuardsByNetworkID(userInfo, uint(cfg.GetNetworkId()))
if err != nil {
log.WithError(err).Errorf("get wireguards by network id failed")
return
}
links, err := dao.NewQuery(ctx).ListWireGuardLinksByNetwork(userInfo, uint(cfg.GetNetworkId()))
links, err := q.ListWireGuardLinksByNetwork(userInfo, uint(cfg.GetNetworkId()))
if err != nil {
log.WithError(err).Errorf("get wireguard links by network id failed")
return

View File

@@ -23,13 +23,16 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele
return &pb.DeleteWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: "invalid id"}}, nil
}
wgToDelete, err := dao.NewQuery(ctx).GetWireGuardByID(userInfo, id)
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
wgToDelete, err := q.GetWireGuardByID(userInfo, id)
if err != nil {
log.WithError(err).Errorf("get wireguard by id failed")
return nil, err
}
if err := dao.NewQuery(ctx).DeleteWireGuard(userInfo, id); err != nil {
if err := m.DeleteWireGuard(userInfo, id); err != nil {
log.WithError(err).Errorf("delete wireguard failed")
return nil, err
}
@@ -51,7 +54,7 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele
log.Errorf("cannot get response, client id: [%s]", wgToDelete.ClientID)
}
peers, err := dao.NewQuery(ctx).GetWireGuardsByNetworkID(userInfo, uint(wgToDelete.NetworkID))
peers, err := q.GetWireGuardsByNetworkID(userInfo, uint(wgToDelete.NetworkID))
if err != nil {
log.WithError(err).Errorf("get wireguards by network id failed")
return
@@ -62,7 +65,7 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele
return
}
links, err := dao.NewQuery(ctx).ListWireGuardLinksByNetwork(userInfo, uint(wgToDelete.NetworkID))
links, err := q.ListWireGuardLinksByNetwork(userInfo, uint(wgToDelete.NetworkID))
if err != nil {
log.WithError(err).Errorf("get wireguard links by network id failed")
return

View File

@@ -20,19 +20,21 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda
if cfg == nil || cfg.GetId() == 0 || len(cfg.GetClientId()) == 0 || len(cfg.GetInterfaceName()) == 0 || len(cfg.GetPrivateKey()) == 0 || len(cfg.GetLocalAddress()) == 0 {
return nil, errors.New("invalid wireguard config")
}
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
model := &models.WireGuard{}
model.FromPB(cfg)
model.UserId = uint32(userInfo.GetUserID())
model.TenantId = uint32(userInfo.GetTenantID())
if err := dao.NewQuery(ctx).UpdateWireGuard(userInfo, uint(cfg.GetId()), model); err != nil {
if err := m.UpdateWireGuard(userInfo, uint(cfg.GetId()), model); err != nil {
return nil, err
}
// 端点赋值与解绑
// 1) 读取当前绑定的端点
currentList, err := dao.NewQuery(ctx).ListEndpointsWithFilters(userInfo, 1, 1000, "", uint(cfg.GetId()), "")
currentList, err := q.ListEndpointsWithFilters(userInfo, 1, 1000, "", uint(cfg.GetId()), "")
if err != nil {
return nil, err
}
@@ -45,7 +47,7 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda
continue
}
if ep.GetId() > 0 {
exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(ep.GetId()))
exist, err := q.GetEndpointByID(userInfo, uint(ep.GetId()))
if err != nil {
return nil, err
}
@@ -55,13 +57,13 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda
}
exist.WireGuardID = uint(cfg.GetId())
if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil {
if err := m.UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil {
return nil, err
}
newSet[uint(exist.ID)] = struct{}{}
} else {
entity := &models.EndpointEntity{Host: ep.GetHost(), Port: ep.GetPort(), ClientID: cfg.GetClientId(), WireGuardID: uint(cfg.GetId())}
if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, entity); err != nil {
if err := m.CreateEndpoint(userInfo, entity); err != nil {
return nil, err
}
// 无法获取新建 id这里不加入 newSet不影响后续解绑逻辑仅解绑 current - new
@@ -73,12 +75,12 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda
if _, ok := newSet[id]; ok {
continue
}
exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, id)
exist, err := q.GetEndpointByID(userInfo, id)
if err != nil {
return nil, err
}
entity := &models.EndpointEntity{Host: exist.Host, Port: exist.Port, ClientID: exist.ClientID, WireGuardID: 0}
if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, id, entity); err != nil {
if err := m.UpdateEndpoint(userInfo, id, entity); err != nil {
return nil, err
}
}

View File

@@ -19,13 +19,15 @@ func CreateWorker(ctx *app.Context, req *pb.CreateWorkerRequest) (*pb.CreateWork
clientId = req.GetClientId()
reqWorker = req.GetWorker()
)
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
if err := validateCreateWorker(req); err != nil {
logger.Logger(ctx).WithError(err).Errorf("invalid create worker request, origin is: [%s]", req.String())
return nil, err
}
cli, err := dao.NewQuery(ctx).GetClientByClientID(userInfo, clientId)
cli, err := q.GetClientByClientID(userInfo, clientId)
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot get client, id: [%s], workerName: [%s]", clientId, reqWorker.GetName())
return nil, err
@@ -38,7 +40,7 @@ func CreateWorker(ctx *app.Context, req *pb.CreateWorkerRequest) (*pb.CreateWork
workerToCreate.Clients = append(workerToCreate.Clients, *cli)
if err := dao.NewQuery(ctx).CreateWorker(userInfo, workerToCreate); err != nil {
if err := m.CreateWorker(userInfo, workerToCreate); err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot create worker, workerName: [%s]", workerToCreate.Name)
return nil, err
}

View File

@@ -19,13 +19,16 @@ func RemoveWorker(ctx *app.Context, req *pb.RemoveWorkerRequest) (*pb.RemoveWork
logger.Logger(ctx).Infof("start remove worker, id: [%s]", workerId)
workerToDelete, err := dao.NewQuery(ctx).GetWorkerByWorkerID(userInfo, workerId)
q := dao.NewQuery(ctx)
mut := dao.NewMutation(ctx)
workerToDelete, err := q.GetWorkerByWorkerID(userInfo, workerId)
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot get worker, id: [%s]", workerId)
return nil, err
}
if ingressesToDelete, err := dao.NewQuery(ctx).GetProxyConfigsByWorkerId(userInfo, workerId); err == nil {
if ingressesToDelete, err := q.GetProxyConfigsByWorkerId(userInfo, workerId); err == nil {
for _, ingressToDelete := range ingressesToDelete {
logger.Logger(ctx).Infof("start to remove worker ingress on server: [%s] client: [%s], name: [%s]", ingressToDelete.ServerID, ingressToDelete.ClientID, ingressToDelete.Name)
@@ -44,7 +47,7 @@ func RemoveWorker(ctx *app.Context, req *pb.RemoveWorkerRequest) (*pb.RemoveWork
}
}
if err := dao.NewQuery(ctx).DeleteWorker(userInfo, workerId); err != nil {
if err := mut.DeleteWorker(userInfo, workerId); err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot remove worker, id: [%s]", workerId)
return nil, err
}

View File

@@ -22,7 +22,10 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork
oldClientIds []string
)
workerToUpdate, err := dao.NewQuery(ctx).GetWorkerByWorkerID(userInfo, wrokerReq.GetWorkerId())
q := dao.NewQuery(ctx)
m := dao.NewMutation(ctx)
workerToUpdate, err := q.GetWorkerByWorkerID(userInfo, wrokerReq.GetWorkerId())
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot get worker, id: [%s]", wrokerReq.GetWorkerId())
return nil, fmt.Errorf("cannot get worker, id: [%s]", wrokerReq.GetWorkerId())
@@ -30,7 +33,7 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork
clis := []*models.Client{}
if len(clientIds) != 0 {
clis, err = dao.NewQuery(ctx).GetClientsByClientIDs(userInfo, clientIds)
clis, err = q.GetClientsByClientIDs(userInfo, clientIds)
if err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot get client, id: [%s]", utils.MarshalForJson(clientIds))
return nil, fmt.Errorf("cannot get client, id: [%s]", utils.MarshalForJson(clientIds))
@@ -64,7 +67,7 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork
updatedFields = append(updatedFields, "config_template")
}
if err := dao.NewQuery(ctx).UpdateWorker(userInfo, workerToUpdate); err != nil {
if err := m.UpdateWorker(userInfo, workerToUpdate); err != nil {
logger.Logger(ctx).WithError(err).Errorf("cannot update worker, id: [%s]", wrokerReq.GetWorkerId())
return nil, fmt.Errorf("cannot update worker, id: [%s]", wrokerReq.GetWorkerId())
}

View File

@@ -151,7 +151,7 @@ func NewDBManager(ctx *app.Context, appInstance app.Application) app.DBManager {
}
func NewMasterTLSConfig(ctx *app.Context) *tls.Config {
return dao.NewQuery(ctx).InitCert(conf.GetCertTemplate(ctx.GetApp().GetConfig()))
return dao.NewMutation(ctx).InitCert(conf.GetCertTemplate(ctx.GetApp().GetConfig()))
}
func NewTLSMasterService(appInstance app.Application, masterTLSConfig *tls.Config) master.MasterService {
@@ -250,7 +250,7 @@ func NewDefaultServerConfig(ctx *app.Context) conf.Config {
logger.Logger(ctx).Infof("init default internal server")
dao.NewQuery(ctx).InitDefaultServer(appInstance.GetConfig().Master.APIHost)
dao.NewMutation(ctx).InitDefaultServer(appInstance.GetConfig().Master.APIHost)
defaultServer, err := dao.NewQuery(ctx).GetDefaultServer()
if err != nil {

View File

@@ -14,13 +14,30 @@ import (
"github.com/VaalaCat/frp-panel/utils/logger"
)
func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config {
type CertQuery interface {
CountCerts() (int64, error)
GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error)
}
type CertMutation interface {
InitCert(template *x509.Certificate) *tls.Config
}
type certQuery struct{ *queryImpl }
type certMutation struct{ *mutationImpl }
func newCertQuery(base *queryImpl) CertQuery { return &certQuery{base} }
func newCertMutation(base *mutationImpl) CertMutation { return &certMutation{base} }
func (m *certMutation) InitCert(template *x509.Certificate) *tls.Config {
ctx := context.Background()
var (
certPem []byte
keyPem []byte
)
cnt, err := q.CountCerts()
query := NewQuery(m.ctx)
cnt, err := query.CountCerts()
if err != nil {
logger.Logger(ctx).Fatal(err)
}
@@ -29,7 +46,7 @@ func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config {
if err != nil {
logger.Logger(ctx).Fatal(err)
}
if err = q.ctx.GetApp().GetDBManager().GetDefaultDB().Create(&models.Cert{
if err = m.ctx.GetApp().GetDBManager().GetDefaultDB().Create(&models.Cert{
Name: "default",
CertFile: certPem,
CaFile: certPem,
@@ -38,7 +55,7 @@ func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config {
logger.Logger(ctx).Fatal(err)
}
} else {
keyPem, certPem, err = q.GetDefaultKeyPair()
keyPem, certPem, err = query.GetDefaultKeyPair()
if err != nil {
logger.Logger(ctx).Fatal(err)
}
@@ -79,7 +96,7 @@ func GenX509Info(template *x509.Certificate) (certPem []byte, keyPem []byte, err
return certBuf.Bytes(), keyBuf.Bytes(), nil
}
func (q *queryImpl) CountCerts() (int64, error) {
func (q *certQuery) CountCerts() (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Cert{}).Count(&count).Error
@@ -89,7 +106,7 @@ func (q *queryImpl) CountCerts() (int64, error) {
return count, nil
}
func (q *queryImpl) GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error) {
func (q *certQuery) GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error) {
resp := &models.Cert{}
err = q.ctx.GetApp().GetDBManager().GetDefaultDB().Model(&models.Cert{}).
Where(&models.Cert{Name: "default"}).First(resp).Error

View File

@@ -9,7 +9,38 @@ import (
"gorm.io/gorm"
)
func (q *queryImpl) ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error) {
type ClientQuery interface {
ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error)
AdminGetClientByClientID(clientID string) (*models.Client, error)
GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error)
GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error)
GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error)
GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error)
ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error)
ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error)
GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error)
CountClients(userInfo models.UserInfo) (int64, error)
CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error)
CountConfiguredClients(userInfo models.UserInfo) (int64, error)
CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error)
GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error)
AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error)
}
type ClientMutation interface {
CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error
DeleteClient(userInfo models.UserInfo, clientID string) error
UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error
AdminUpdateClientLastSeen(clientID string) error
}
type clientQuery struct{ *queryImpl }
type clientMutation struct{ *mutationImpl }
func newClientQuery(base *queryImpl) ClientQuery { return &clientQuery{base} }
func newClientMutation(base *mutationImpl) ClientMutation { return &clientMutation{base} }
func (q *clientQuery) ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error) {
if clientID == "" || clientSecret == "" {
return nil, fmt.Errorf("invalid client id or client secret")
}
@@ -29,7 +60,7 @@ func (q *queryImpl) ValidateClientSecret(clientID, clientSecret string) (*models
return c.ClientEntity, nil
}
func (q *queryImpl) AdminGetClientByClientID(clientID string) (*models.Client, error) {
func (q *clientQuery) AdminGetClientByClientID(clientID string) (*models.Client, error) {
if clientID == "" {
return nil, fmt.Errorf("invalid client id")
}
@@ -46,7 +77,7 @@ func (q *queryImpl) AdminGetClientByClientID(clientID string) (*models.Client, e
return c, nil
}
func (q *queryImpl) GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error) {
func (q *clientQuery) GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error) {
if clientID == "" {
return nil, fmt.Errorf("invalid client id")
}
@@ -65,7 +96,7 @@ func (q *queryImpl) GetClientByClientID(userInfo models.UserInfo, clientID strin
return c, nil
}
func (q *queryImpl) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error) {
func (q *clientQuery) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error) {
if len(clientIDs) == 0 {
return nil, fmt.Errorf("invalid client ids")
}
@@ -80,7 +111,7 @@ func (q *queryImpl) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []
return cs, nil
}
func (q *queryImpl) GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error) {
func (q *clientQuery) GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
filter := &models.ClientEntity{}
if len(client.ClientID) != 0 {
@@ -113,7 +144,7 @@ func (q *queryImpl) GetClientByFilter(userInfo models.UserInfo, client *models.C
return c.ClientEntity, nil
}
func (q *queryImpl) GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error) {
func (q *clientQuery) GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error) {
if originClientID == "" {
return nil, fmt.Errorf("invalid origin client id")
}
@@ -130,21 +161,21 @@ func (q *queryImpl) GetClientByOriginClientID(originClientID string) (*models.Cl
return c.ClientEntity, nil
}
func (q *queryImpl) CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error {
func (m *clientMutation) CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error {
client.UserID = userInfo.GetUserID()
client.TenantID = userInfo.GetTenantID()
c := &models.Client{
ClientEntity: client,
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(c).Error
}
func (q *queryImpl) DeleteClient(userInfo models.UserInfo, clientID string) error {
func (m *clientMutation) DeleteClient(userInfo models.UserInfo, clientID string) error {
if clientID == "" {
return fmt.Errorf("invalid client id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.Client{
ClientEntity: &models.ClientEntity{
ClientID: clientID,
@@ -160,11 +191,11 @@ func (q *queryImpl) DeleteClient(userInfo models.UserInfo, clientID string) erro
}).Delete(&models.Client{}).Error
}
func (q *queryImpl) UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error {
func (m *clientMutation) UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error {
c := &models.Client{
ClientEntity: client,
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Where(&models.Client{
ClientEntity: &models.ClientEntity{
UserID: userInfo.GetUserID(),
@@ -173,7 +204,7 @@ func (q *queryImpl) UpdateClient(userInfo models.UserInfo, client *models.Client
}).Save(c).Error
}
func (q *queryImpl) ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error) {
func (q *clientQuery) ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -202,7 +233,7 @@ func (q *queryImpl) ListClients(userInfo models.UserInfo, page, pageSize int) ([
}), nil
}
func (q *queryImpl) ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error) {
func (q *clientQuery) ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error) {
// 只获取没shadow且config有东西
// 或isShadow的client
if page < 1 || pageSize < 1 || len(keyword) == 0 {
@@ -229,7 +260,7 @@ func (q *queryImpl) ListClientsWithKeyword(userInfo models.UserInfo, page, pageS
}), nil
}
func (q *queryImpl) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error) {
func (q *clientQuery) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var clients []*models.Client
err := db.Where(&models.Client{
@@ -247,7 +278,7 @@ func (q *queryImpl) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEnt
}), nil
}
func (q *queryImpl) CountClients(userInfo models.UserInfo) (int64, error) {
func (q *clientQuery) CountClients(userInfo models.UserInfo) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Client{}).Where(&models.Client{
@@ -263,7 +294,7 @@ func (q *queryImpl) CountClients(userInfo models.UserInfo) (int64, error) {
return count, nil
}
func (q *queryImpl) CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
func (q *clientQuery) CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Client{}).Where(&models.Client{
@@ -279,7 +310,7 @@ func (q *queryImpl) CountClientsWithKeyword(userInfo models.UserInfo, keyword st
return count, nil
}
func (q *queryImpl) CountConfiguredClients(userInfo models.UserInfo) (int64, error) {
func (q *clientQuery) CountConfiguredClients(userInfo models.UserInfo) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Client{}).
@@ -296,7 +327,7 @@ func (q *queryImpl) CountConfiguredClients(userInfo models.UserInfo) (int64, err
return count, nil
}
func (q *queryImpl) CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error) {
func (q *clientQuery) CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Client{}).
@@ -313,7 +344,7 @@ func (q *queryImpl) CountClientsInShadow(userInfo models.UserInfo, clientID stri
return count, nil
}
func (q *queryImpl) GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error) {
func (q *clientQuery) GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var clients []*models.Client
err := db.Where(&models.Client{
@@ -330,7 +361,7 @@ func (q *queryImpl) GetClientIDsInShadowByClientID(userInfo models.UserInfo, cli
}), nil
}
func (q *queryImpl) AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error) {
func (q *clientQuery) AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var clients []*models.Client
err := db.Where(&models.Client{
@@ -345,8 +376,8 @@ func (q *queryImpl) AdminGetClientIDsInShadowByClientID(clientID string) ([]stri
}), nil
}
func (q *queryImpl) AdminUpdateClientLastSeen(clientID string) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *clientMutation) AdminUpdateClientLastSeen(clientID string) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Model(&models.Client{
ClientEntity: &models.ClientEntity{
ClientID: clientID,

View File

@@ -7,7 +7,27 @@ import (
"gorm.io/gorm"
)
func (q *queryImpl) CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error {
type EndpointQuery interface {
GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error)
ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error)
CountEndpoints(userInfo models.UserInfo) (int64, error)
ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error)
CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error)
}
type EndpointMutation interface {
CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error
UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error
DeleteEndpoint(userInfo models.UserInfo, id uint) error
}
type endpointQuery struct{ *queryImpl }
type endpointMutation struct{ *mutationImpl }
func newEndpointQuery(base *queryImpl) EndpointQuery { return &endpointQuery{base} }
func newEndpointMutation(base *mutationImpl) EndpointMutation { return &endpointMutation{base} }
func (m *endpointMutation) CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error {
if endpoint == nil {
return fmt.Errorf("invalid endpoint entity")
}
@@ -15,29 +35,29 @@ func (q *queryImpl) CreateEndpoint(userInfo models.UserInfo, endpoint *models.En
return fmt.Errorf("invalid endpoint host or port")
}
// scope via parent wireguard/client
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(&models.Endpoint{EndpointEntity: endpoint}).Error
}
func (q *queryImpl) UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error {
func (m *endpointMutation) UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error {
if id == 0 || endpoint == nil {
return fmt.Errorf("invalid endpoint id or entity")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Where(&models.Endpoint{
Model: gorm.Model{ID: id},
}).Save(&models.Endpoint{Model: gorm.Model{ID: id}, EndpointEntity: endpoint}).Error
}
func (q *queryImpl) DeleteEndpoint(userInfo models.UserInfo, id uint) error {
func (m *endpointMutation) DeleteEndpoint(userInfo models.UserInfo, id uint) error {
if id == 0 {
return fmt.Errorf("invalid endpoint id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.Endpoint{Model: gorm.Model{ID: id}}).Delete(&models.Endpoint{}).Error
}
func (q *queryImpl) GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error) {
func (q *endpointQuery) GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error) {
if id == 0 {
return nil, fmt.Errorf("invalid endpoint id")
}
@@ -49,7 +69,7 @@ func (q *queryImpl) GetEndpointByID(userInfo models.UserInfo, id uint) (*models.
return &e, nil
}
func (q *queryImpl) ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error) {
func (q *endpointQuery) ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -62,7 +82,7 @@ func (q *queryImpl) ListEndpoints(userInfo models.UserInfo, page, pageSize int)
return list, nil
}
func (q *queryImpl) CountEndpoints(userInfo models.UserInfo) (int64, error) {
func (q *endpointQuery) CountEndpoints(userInfo models.UserInfo) (int64, error) {
var count int64
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
if err := db.Model(&models.Endpoint{}).Count(&count).Error; err != nil {
@@ -72,7 +92,7 @@ func (q *queryImpl) CountEndpoints(userInfo models.UserInfo) (int64, error) {
}
// ListEndpointsWithFilters 根据 clientID / wireguardID / keyword 过滤端点
func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error) {
func (q *endpointQuery) ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -80,7 +100,7 @@ func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pag
// 若指定 clientID先校验归属
if len(clientID) > 0 {
if _, err := q.GetClientByClientID(userInfo, clientID); err != nil {
if _, err := newClientQuery(q.queryImpl).GetClientByClientID(userInfo, clientID); err != nil {
return nil, err
}
}
@@ -103,11 +123,11 @@ func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pag
return list, nil
}
func (q *queryImpl) CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error) {
func (q *endpointQuery) CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
if len(clientID) > 0 {
if _, err := q.GetClientByClientID(userInfo, clientID); err != nil {
if _, err := newClientQuery(q.queryImpl).GetClientByClientID(userInfo, clientID); err != nil {
return 0, err
}
}

View File

@@ -8,7 +8,29 @@ import (
"gorm.io/gorm"
)
func (q *queryImpl) CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error {
type LinkQuery interface {
ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error)
GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error)
GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error)
ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error)
CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error)
AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error)
}
type LinkMutation interface {
CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error
CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error
UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error
DeleteWireGuardLink(userInfo models.UserInfo, id uint) error
}
type linkQuery struct{ *queryImpl }
type linkMutation struct{ *mutationImpl }
func newLinkQuery(base *queryImpl) LinkQuery { return &linkQuery{base} }
func newLinkMutation(base *mutationImpl) LinkMutation { return &linkMutation{base} }
func (m *linkMutation) CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error {
if link == nil {
return fmt.Errorf("invalid wg link")
}
@@ -17,11 +39,11 @@ func (q *queryImpl) CreateWireGuardLink(userInfo models.UserInfo, link *models.W
}
link.UserId = uint32(userInfo.GetUserID())
link.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(link).Error
}
func (q *queryImpl) CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error {
func (m *linkMutation) CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error {
if len(links) == 0 {
return fmt.Errorf("invalid wg links")
}
@@ -29,11 +51,11 @@ func (q *queryImpl) CreateWireGuardLinks(userInfo models.UserInfo, links ...*mod
link.UserId = uint32(userInfo.GetUserID())
link.TenantId = uint32(userInfo.GetTenantID())
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(links).Error
}
func (q *queryImpl) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error {
func (m *linkMutation) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error {
if id == 0 || link == nil {
return fmt.Errorf("invalid wg link id or entity")
}
@@ -43,18 +65,18 @@ func (q *queryImpl) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link
}
link.UserId = uint32(userInfo.GetUserID())
link.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Where(&models.WireGuardLink{
Model: link.Model,
WireGuardLinkEntity: &models.WireGuardLinkEntity{UserId: link.UserId, TenantId: link.TenantId},
}).Save(link).Error
}
func (q *queryImpl) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error {
func (m *linkMutation) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error {
if id == 0 {
return fmt.Errorf("invalid wg link id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.WireGuardLink{
Model: gorm.Model{ID: id},
WireGuardLinkEntity: &models.WireGuardLinkEntity{
@@ -63,7 +85,7 @@ func (q *queryImpl) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error
}}).Delete(&models.WireGuardLink{}).Error
}
func (q *queryImpl) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error) {
func (q *linkQuery) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error) {
if networkID == 0 {
return nil, fmt.Errorf("invalid network id")
}
@@ -81,7 +103,7 @@ func (q *queryImpl) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networ
}
// GetWireGuardLinkByID 根据 ID 查询 Link按租户隔离
func (q *queryImpl) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error) {
func (q *linkQuery) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error) {
if id == 0 {
return nil, fmt.Errorf("invalid wg link id")
}
@@ -99,7 +121,7 @@ func (q *queryImpl) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*mo
return &m, nil
}
func (q *queryImpl) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error) {
func (q *linkQuery) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error) {
if fromClientId == 0 || toClientId == 0 {
return nil, fmt.Errorf("invalid from client id or to client id")
}
@@ -117,7 +139,7 @@ func (q *queryImpl) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromCl
}
// ListWireGuardLinksWithFilters 分页查询 Link支持按 networkID 过滤与关键字(数字时匹配 from/to id
func (q *queryImpl) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error) {
func (q *linkQuery) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -146,7 +168,7 @@ func (q *queryImpl) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page
}
// CountWireGuardLinksWithFilters 统计分页条件下的总数
func (q *queryImpl) CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error) {
func (q *linkQuery) CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
@@ -170,7 +192,7 @@ func (q *queryImpl) CountWireGuardLinksWithFilters(userInfo models.UserInfo, net
return count, nil
}
func (q *queryImpl) AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error) {
func (q *linkQuery) AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var list []*models.WireGuardLink
if err := db.Where("network_id IN ?", networkIDs).Find(&list).Error; err != nil {

View File

@@ -7,7 +7,27 @@ import (
"gorm.io/gorm"
)
func (q *queryImpl) CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error {
type NetworkQuery interface {
GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error)
ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error)
ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error)
CountNetworks(userInfo models.UserInfo) (int64, error)
CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error)
}
type NetworkMutation interface {
CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error
UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error
DeleteNetwork(userInfo models.UserInfo, id uint) error
}
type networkQuery struct{ *queryImpl }
type networkMutation struct{ *mutationImpl }
func newNetworkQuery(base *queryImpl) NetworkQuery { return &networkQuery{base} }
func newNetworkMutation(base *mutationImpl) NetworkMutation { return &networkMutation{base} }
func (m *networkMutation) CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error {
if network == nil {
return fmt.Errorf("invalid network entity")
}
@@ -18,11 +38,11 @@ func (q *queryImpl) CreateNetwork(userInfo models.UserInfo, network *models.Netw
network.UserId = uint32(userInfo.GetUserID())
network.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(&models.Network{NetworkEntity: network}).Error
}
func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error {
func (m *networkMutation) UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error {
if id == 0 || network == nil {
return fmt.Errorf("invalid network id or entity")
}
@@ -30,7 +50,7 @@ func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *mo
network.UserId = uint32(userInfo.GetUserID())
network.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Where(&models.Network{
Model: gorm.Model{ID: id},
NetworkEntity: &models.NetworkEntity{
@@ -43,11 +63,11 @@ func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *mo
}).Error
}
func (q *queryImpl) DeleteNetwork(userInfo models.UserInfo, id uint) error {
func (m *networkMutation) DeleteNetwork(userInfo models.UserInfo, id uint) error {
if id == 0 {
return fmt.Errorf("invalid network id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.Network{
Model: gorm.Model{ID: id},
NetworkEntity: &models.NetworkEntity{
@@ -57,7 +77,7 @@ func (q *queryImpl) DeleteNetwork(userInfo models.UserInfo, id uint) error {
}).Delete(&models.Network{}).Error
}
func (q *queryImpl) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error) {
func (q *networkQuery) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error) {
if id == 0 {
return nil, fmt.Errorf("invalid network id")
}
@@ -75,7 +95,7 @@ func (q *queryImpl) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.N
return &n, nil
}
func (q *queryImpl) ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error) {
func (q *networkQuery) ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -91,7 +111,7 @@ func (q *queryImpl) ListNetworks(userInfo models.UserInfo, page, pageSize int) (
return list, nil
}
func (q *queryImpl) ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error) {
func (q *networkQuery) ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error) {
if page < 1 || pageSize < 1 || len(keyword) == 0 {
return nil, fmt.Errorf("invalid page or page size or keyword")
}
@@ -107,7 +127,7 @@ func (q *queryImpl) ListNetworksWithKeyword(userInfo models.UserInfo, page, page
return list, nil
}
func (q *queryImpl) CountNetworks(userInfo models.UserInfo) (int64, error) {
func (q *networkQuery) CountNetworks(userInfo models.UserInfo) (int64, error) {
var count int64
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
if err := db.Model(&models.Network{}).Where(&models.Network{NetworkEntity: &models.NetworkEntity{
@@ -119,7 +139,7 @@ func (q *queryImpl) CountNetworks(userInfo models.UserInfo) (int64, error) {
return count, nil
}
func (q *queryImpl) CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
func (q *networkQuery) CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
var count int64
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
if err := db.Model(&models.Network{}).Where(&models.Network{NetworkEntity: &models.NetworkEntity{

View File

@@ -16,7 +16,44 @@ import (
"gorm.io/gorm/clause"
)
func (q *queryImpl) GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error) {
type ProxyQuery interface {
GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error)
GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error)
AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error)
AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error)
AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error)
GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error)
GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error)
ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error)
AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error)
ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error)
ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error)
ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error)
GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error)
CountProxyConfigs(userInfo models.UserInfo) (int64, error)
CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error)
CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error)
GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error)
}
type ProxyMutation interface {
AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error
AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error
RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error
CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error
UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error
DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error
DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error
DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error
}
type proxyQuery struct{ *queryImpl }
type proxyMutation struct{ *mutationImpl }
func newProxyQuery(base *queryImpl) ProxyQuery { return &proxyQuery{base} }
func newProxyMutation(base *mutationImpl) ProxyMutation { return &proxyMutation{base} }
func (q *proxyQuery) GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error) {
if clientID == "" {
return nil, fmt.Errorf("invalid client id")
}
@@ -52,7 +89,7 @@ func (q *queryImpl) GetProxyStatsByClientID(userInfo models.UserInfo, clientID s
}), nil
}
func (q *queryImpl) GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error) {
func (q *proxyQuery) GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error) {
if serverID == "" {
return nil, fmt.Errorf("invalid server id")
}
@@ -77,12 +114,12 @@ func (q *queryImpl) GetProxyStatsByServerID(userInfo models.UserInfo, serverID s
}), nil
}
func (q *queryImpl) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error {
func (m *proxyMutation) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error {
if srv.ServerID == "" {
return fmt.Errorf("invalid server id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Transaction(func(tx *gorm.DB) error {
queryResults := make([]interface{}, 3)
@@ -201,7 +238,7 @@ func (q *queryImpl) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb
})
}
func (q *queryImpl) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error) {
func (q *proxyQuery) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
list := []*models.ProxyStats{}
err := db.
@@ -217,7 +254,7 @@ func (q *queryImpl) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStats
}), nil
}
func (q *queryImpl) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error) {
func (q *proxyQuery) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error) {
db := tx
list := []*models.ProxyStats{}
err := db.Clauses(clause.Locking{Strength: "UPDATE"}).
@@ -230,15 +267,16 @@ func (q *queryImpl) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEnti
}), nil
}
func (q *queryImpl) AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *proxyMutation) AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(proxyCfg).Error
}
// RebuildProxyConfigFromClient rebuild proxy from client
// skip stopped proxy
func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *proxyMutation) RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
query := NewQuery(m.ctx)
pxyCfgs, err := utils.LoadProxiesFromContent(client.ConfigContent)
if err != nil {
@@ -251,7 +289,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien
proxyCfg := &models.ProxyConfig{
ProxyConfigEntity: &models.ProxyConfigEntity{},
}
if oldProxyCfg, err := q.GetProxyConfigByOriginClientIDAndName(userInfo, client.ClientID, pxyCfg.GetBaseConfig().Name); err == nil {
if oldProxyCfg, err := query.GetProxyConfigByOriginClientIDAndName(userInfo, client.ClientID, pxyCfg.GetBaseConfig().Name); err == nil {
logger.Logger(context.Background()).WithError(err).Warnf("proxy config already exist, will be override, clientID: [%s], name: [%s]",
client.ClientID, pxyCfg.GetBaseConfig().Name)
proxyCfg.Model = oldProxyCfg.Model
@@ -268,7 +306,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien
proxyConfigEntities = append(proxyConfigEntities, proxyCfg)
}
if err := q.DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, client.ClientID); err != nil {
if err := m.DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, client.ClientID); err != nil {
return err
}
@@ -279,7 +317,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien
return db.Save(proxyConfigEntities).Error
}
func (q *queryImpl) AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error) {
func (q *proxyQuery) AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
proxyCfg := &models.ProxyConfig{}
err := db.
@@ -294,7 +332,7 @@ func (q *queryImpl) AdminGetProxyConfigByClientIDAndName(clientID string, name s
return proxyCfg, nil
}
func (q *queryImpl) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error) {
func (q *proxyQuery) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error) {
if clientID == "" {
return nil, fmt.Errorf("invalid client id")
}
@@ -315,7 +353,7 @@ func (q *queryImpl) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID
}), nil
}
func (q *queryImpl) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error) {
func (q *proxyQuery) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
filter := &models.ProxyConfigEntity{}
@@ -348,7 +386,7 @@ func (q *queryImpl) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig
return respProxyCfg, nil
}
func (q *queryImpl) ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -370,7 +408,7 @@ func (q *queryImpl) ListProxyConfigsWithFilters(userInfo models.UserInfo, page,
return proxyConfigs, nil
}
func (q *queryImpl) AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var proxyConfigs []*models.ProxyConfig
@@ -384,7 +422,7 @@ func (q *queryImpl) AdminListProxyConfigsWithFilters(filters *models.ProxyConfig
return proxyConfigs, nil
}
func (q *queryImpl) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error) {
if page < 1 || pageSize < 1 || len(keyword) == 0 {
return nil, fmt.Errorf("invalid page or page size or keyword")
}
@@ -406,26 +444,26 @@ func (q *queryImpl) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserIn
return proxyConfigs, nil
}
func (q *queryImpl) ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error) {
return q.ListProxyConfigsWithFiltersAndKeyword(userInfo, page, pageSize, &models.ProxyConfigEntity{}, keyword)
}
func (q *queryImpl) ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error) {
return q.ListProxyConfigsWithFilters(userInfo, page, pageSize, &models.ProxyConfigEntity{})
}
func (q *queryImpl) CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *proxyMutation) CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
proxyCfg.UserID = userInfo.GetUserID()
proxyCfg.TenantID = userInfo.GetTenantID()
return db.Create(&models.ProxyConfig{ProxyConfigEntity: proxyCfg}).Error
}
func (q *queryImpl) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error {
func (m *proxyMutation) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error {
if proxyCfg.ID == 0 {
return fmt.Errorf("invalid proxy config id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
proxyCfg.UserID = userInfo.GetUserID()
proxyCfg.TenantID = userInfo.GetTenantID()
return db.Where(&models.ProxyConfig{
@@ -440,11 +478,11 @@ func (q *queryImpl) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models
}).Save(proxyCfg).Error
}
func (q *queryImpl) DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error {
func (m *proxyMutation) DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error {
if clientID == "" || name == "" {
return fmt.Errorf("invalid client id or name")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().
Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{
UserID: userInfo.GetUserID(),
@@ -455,11 +493,11 @@ func (q *queryImpl) DeleteProxyConfig(userInfo models.UserInfo, clientID, name s
Delete(&models.ProxyConfig{}).Error
}
func (q *queryImpl) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error {
func (m *proxyMutation) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error {
if clientID == "" {
return fmt.Errorf("invalid client id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().
Where(
db.Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{
@@ -477,11 +515,11 @@ func (q *queryImpl) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models
Delete(&models.ProxyConfig{}).Error
}
func (q *queryImpl) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error {
func (m *proxyMutation) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error {
if clientID == "" {
return fmt.Errorf("invalid client id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().
Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{
UserID: userInfo.GetUserID(),
@@ -491,7 +529,7 @@ func (q *queryImpl) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clien
Delete(&models.ProxyConfig{}).Error
}
func (q *queryImpl) GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error) {
func (q *proxyQuery) GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error) {
if clientID == "" || name == "" {
return nil, fmt.Errorf("invalid client id or name")
}
@@ -511,11 +549,11 @@ func (q *queryImpl) GetProxyConfigByOriginClientIDAndName(userInfo models.UserIn
return item, nil
}
func (q *queryImpl) CountProxyConfigs(userInfo models.UserInfo) (int64, error) {
func (q *proxyQuery) CountProxyConfigs(userInfo models.UserInfo) (int64, error) {
return q.CountProxyConfigsWithFilters(userInfo, &models.ProxyConfigEntity{})
}
func (q *queryImpl) CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error) {
func (q *proxyQuery) CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
filters.UserID = userInfo.GetUserID()
filters.TenantID = userInfo.GetTenantID()
@@ -530,7 +568,7 @@ func (q *queryImpl) CountProxyConfigsWithFilters(userInfo models.UserInfo, filte
return count, nil
}
func (q *queryImpl) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error) {
func (q *proxyQuery) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error) {
if len(keyword) == 0 {
return q.CountProxyConfigsWithFilters(userInfo, filters)
}
@@ -549,7 +587,7 @@ func (q *queryImpl) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserI
return count, nil
}
func (q *queryImpl) GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error) {
func (q *proxyQuery) GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
items := []*models.ProxyConfig{}

View File

@@ -2,14 +2,110 @@ package dao
import "github.com/VaalaCat/frp-panel/services/app"
type Query interface{}
type Query interface {
CertQuery
ClientQuery
EndpointQuery
LinkQuery
NetworkQuery
ProxyQuery
ServerQuery
StatsQuery
UserQuery
WireGuardQuery
WorkerQuery
}
type Mutation interface {
CertMutation
ClientMutation
EndpointMutation
LinkMutation
NetworkMutation
ProxyMutation
ServerMutation
StatsMutation
UserMutation
WireGuardMutation
WorkerMutation
UserGroupMutation
}
// queryImpl/mutationImpl 是具体表级实现的基础结构(持有 ctx
type queryImpl struct {
ctx *app.Context
}
func NewQuery(ctx *app.Context) *queryImpl {
return &queryImpl{
ctx: ctx,
type mutationImpl struct {
ctx *app.Context
}
// compositeQuery / compositeMutation 组合各子领域实现,对外暴露统一入口。
type compositeQuery struct {
CertQuery
ClientQuery
EndpointQuery
LinkQuery
NetworkQuery
ProxyQuery
ServerQuery
StatsQuery
UserQuery
WireGuardQuery
WorkerQuery
}
type compositeMutation struct {
CertMutation
ClientMutation
EndpointMutation
LinkMutation
NetworkMutation
ProxyMutation
ServerMutation
StatsMutation
UserMutation
WireGuardMutation
WorkerMutation
UserGroupMutation
}
func NewQuery(ctx *app.Context) Query {
base := &queryImpl{ctx: ctx}
return &compositeQuery{
CertQuery: newCertQuery(base),
ClientQuery: newClientQuery(base),
EndpointQuery: newEndpointQuery(base),
LinkQuery: newLinkQuery(base),
NetworkQuery: newNetworkQuery(base),
ProxyQuery: newProxyQuery(base),
ServerQuery: newServerQuery(base),
StatsQuery: newStatsQuery(base),
UserQuery: newUserQuery(base),
WireGuardQuery: newWireGuardQuery(base),
WorkerQuery: newWorkerQuery(base),
}
}
func NewMutation(ctx *app.Context) Mutation {
base := &mutationImpl{ctx: ctx}
return &compositeMutation{
CertMutation: newCertMutation(base),
ClientMutation: newClientMutation(base),
EndpointMutation: newEndpointMutation(base),
LinkMutation: newLinkMutation(base),
NetworkMutation: newNetworkMutation(base),
ProxyMutation: newProxyMutation(base),
ServerMutation: newServerMutation(base),
StatsMutation: newStatsMutation(base),
UserMutation: newUserMutation(base),
WireGuardMutation: newWireGuardMutation(base),
WorkerMutation: newWorkerMutation(base),
UserGroupMutation: newUserGroupMutation(base),
}
}
var (
_ Query = (*compositeQuery)(nil)
_ Mutation = (*compositeMutation)(nil)
)

View File

@@ -9,8 +9,34 @@ import (
"github.com/samber/lo"
)
func (q *queryImpl) InitDefaultServer(serverIP string) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
type ServerQuery interface {
GetDefaultServer() (*models.ServerEntity, error)
ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error)
AdminGetServerByServerID(serverID string) (*models.ServerEntity, error)
GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error)
ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error)
ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error)
CountServers(userInfo models.UserInfo) (int64, error)
CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error)
CountConfiguredServers(userInfo models.UserInfo) (int64, error)
}
type ServerMutation interface {
InitDefaultServer(serverIP string)
UpdateDefaultServer(c *models.Server) error
CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error
DeleteServer(userInfo models.UserInfo, serverID string) error
UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error
}
type serverQuery struct{ *queryImpl }
type serverMutation struct{ *mutationImpl }
func newServerQuery(base *queryImpl) ServerQuery { return &serverQuery{base} }
func newServerMutation(base *mutationImpl) ServerMutation { return &serverMutation{base} }
func (m *serverMutation) InitDefaultServer(serverIP string) {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
db.Where(&models.Server{
ServerEntity: &models.ServerEntity{
ServerID: defs.DefaultServerID,
@@ -24,7 +50,7 @@ func (q *queryImpl) InitDefaultServer(serverIP string) {
}).FirstOrCreate(&models.Server{})
}
func (q *queryImpl) GetDefaultServer() (*models.ServerEntity, error) {
func (q *serverQuery) GetDefaultServer() (*models.ServerEntity, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
c := &models.Server{}
err := db.
@@ -38,8 +64,8 @@ func (q *queryImpl) GetDefaultServer() (*models.ServerEntity, error) {
return c.ServerEntity, nil
}
func (q *queryImpl) UpdateDefaultServer(c *models.Server) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *serverMutation) UpdateDefaultServer(c *models.Server) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
c.ServerID = defs.DefaultServerID
err := db.Where(&models.Server{
ServerEntity: &models.ServerEntity{
@@ -51,7 +77,7 @@ func (q *queryImpl) UpdateDefaultServer(c *models.Server) error {
return nil
}
func (q *queryImpl) ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error) {
func (q *serverQuery) ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error) {
if serverID == "" || secret == "" {
return nil, fmt.Errorf("invalid request")
}
@@ -71,7 +97,7 @@ func (q *queryImpl) ValidateServerSecret(serverID string, secret string) (*model
return c.ServerEntity, nil
}
func (q *queryImpl) AdminGetServerByServerID(serverID string) (*models.ServerEntity, error) {
func (q *serverQuery) AdminGetServerByServerID(serverID string) (*models.ServerEntity, error) {
if serverID == "" {
return nil, fmt.Errorf("invalid server id")
}
@@ -88,7 +114,7 @@ func (q *queryImpl) AdminGetServerByServerID(serverID string) (*models.ServerEnt
return c.ServerEntity, nil
}
func (q *queryImpl) GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error) {
func (q *serverQuery) GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error) {
if serverID == "" {
return nil, fmt.Errorf("invalid server id")
}
@@ -110,21 +136,21 @@ func (q *queryImpl) GetServerByServerID(userInfo models.UserInfo, serverID strin
return c.ServerEntity, nil
}
func (q *queryImpl) CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error {
func (m *serverMutation) CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error {
server.UserID = userInfo.GetUserID()
server.TenantID = userInfo.GetTenantID()
c := &models.Server{
ServerEntity: server,
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(c).Error
}
func (q *queryImpl) DeleteServer(userInfo models.UserInfo, serverID string) error {
func (m *serverMutation) DeleteServer(userInfo models.UserInfo, serverID string) error {
if serverID == "" {
return fmt.Errorf("invalid server id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(
&models.Server{
ServerEntity: &models.ServerEntity{
@@ -139,14 +165,14 @@ func (q *queryImpl) DeleteServer(userInfo models.UserInfo, serverID string) erro
}).Error
}
func (q *queryImpl) UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error {
func (m *serverMutation) UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error {
c := &models.Server{
ServerEntity: server,
}
if userInfo.GetUserID() == defs.DefaultAdminUserID && server.ServerID == defs.DefaultServerID {
return q.UpdateDefaultServer(c)
return m.UpdateDefaultServer(c)
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Where(
&models.Server{
ServerEntity: &models.ServerEntity{
@@ -157,7 +183,7 @@ func (q *queryImpl) UpdateServer(userInfo models.UserInfo, server *models.Server
).Save(c).Error
}
func (q *queryImpl) ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error) {
func (q *serverQuery) ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -187,7 +213,7 @@ func (q *queryImpl) ListServers(userInfo models.UserInfo, page, pageSize int) ([
}), nil
}
func (q *queryImpl) ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error) {
func (q *serverQuery) ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error) {
if page < 1 || pageSize < 1 || len(keyword) == 0 {
return nil, fmt.Errorf("invalid page or page size or keyword")
}
@@ -214,7 +240,7 @@ func (q *queryImpl) ListServersWithKeyword(userInfo models.UserInfo, page, pageS
}), nil
}
func (q *queryImpl) CountServers(userInfo models.UserInfo) (int64, error) {
func (q *serverQuery) CountServers(userInfo models.UserInfo) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Server{}).Where(
@@ -231,7 +257,7 @@ func (q *queryImpl) CountServers(userInfo models.UserInfo) (int64, error) {
return count, nil
}
func (q *queryImpl) CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
func (q *serverQuery) CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Server{}).Where(
@@ -248,7 +274,7 @@ func (q *queryImpl) CountServersWithKeyword(userInfo models.UserInfo, keyword st
return count, nil
}
func (q *queryImpl) CountConfiguredServers(userInfo models.UserInfo) (int64, error) {
func (q *serverQuery) CountConfiguredServers(userInfo models.UserInfo) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Server{}).Where(

View File

@@ -11,12 +11,29 @@ const (
MSetBatchSize = 100
)
func (q *queryImpl) AdminSaveTodyStats(s *models.HistoryProxyStats) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
type StatsQuery interface {
GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error)
GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error)
GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error)
}
type StatsMutation interface {
AdminSaveTodyStats(s *models.HistoryProxyStats) error
AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error
}
type statsQuery struct{ *queryImpl }
type statsMutation struct{ *mutationImpl }
func newStatsQuery(base *queryImpl) StatsQuery { return &statsQuery{base} }
func newStatsMutation(base *mutationImpl) StatsMutation { return &statsMutation{base} }
func (m *statsMutation) AdminSaveTodyStats(s *models.HistoryProxyStats) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Save(s).Error
}
func (q *queryImpl) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error {
func (m *statsMutation) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error {
if len(s) == 0 {
return nil
}
@@ -31,7 +48,7 @@ func (q *queryImpl) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxySta
return nil
}
func (q *queryImpl) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error) {
func (q *statsQuery) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var stats []*models.HistoryProxyStats
err := db.Where(&models.HistoryProxyStats{
@@ -45,7 +62,7 @@ func (q *queryImpl) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID i
return stats, nil
}
func (q *queryImpl) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error) {
func (q *statsQuery) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var stats []*models.HistoryProxyStats
err := db.Where(&models.HistoryProxyStats{
@@ -59,7 +76,7 @@ func (q *queryImpl) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID
return stats, nil
}
func (q *queryImpl) GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error) {
func (q *statsQuery) GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var stats []*models.HistoryProxyStats
err := db.Where(&models.HistoryProxyStats{

View File

@@ -8,7 +8,28 @@ import (
"github.com/samber/lo"
)
func (q *queryImpl) AdminGetAllUsers() ([]*models.UserEntity, error) {
type UserQuery interface {
AdminGetAllUsers() ([]*models.UserEntity, error)
AdminCountUsers() (int64, error)
GetUserByUserID(userID int) (*models.UserEntity, error)
GetUserByUserName(userName string) (*models.UserEntity, error)
CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error)
CheckUserNameAndEmail(userName, email string) error
}
type UserMutation interface {
UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error
AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error
CreateUser(user *models.UserEntity) error
}
type userQuery struct{ *queryImpl }
type userMutation struct{ *mutationImpl }
func newUserQuery(base *queryImpl) UserQuery { return &userQuery{base} }
func newUserMutation(base *mutationImpl) UserMutation { return &userMutation{base} }
func (q *userQuery) AdminGetAllUsers() ([]*models.UserEntity, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
users := make([]*models.User, 0)
err := db.Find(&users).Error
@@ -21,7 +42,7 @@ func (q *queryImpl) AdminGetAllUsers() ([]*models.UserEntity, error) {
}), nil
}
func (q *queryImpl) AdminCountUsers() (int64, error) {
func (q *userQuery) AdminCountUsers() (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.User{}).Count(&count).Error
@@ -31,7 +52,7 @@ func (q *queryImpl) AdminCountUsers() (int64, error) {
return count, nil
}
func (q *queryImpl) GetUserByUserID(userID int) (*models.UserEntity, error) {
func (q *userQuery) GetUserByUserID(userID int) (*models.UserEntity, error) {
if userID == 0 {
return nil, fmt.Errorf("invalid user id")
}
@@ -48,8 +69,8 @@ func (q *queryImpl) GetUserByUserID(userID int) (*models.UserEntity, error) {
return u.UserEntity, nil
}
func (q *queryImpl) UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *userMutation) UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
user.UserID = userInfo.GetUserID()
return db.Model(&models.User{}).Where(
&models.User{
@@ -62,8 +83,8 @@ func (q *queryImpl) UpdateUser(userInfo models.UserInfo, user *models.UserEntity
}).Error
}
func (q *queryImpl) AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *userMutation) AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
user.UserID = userInfo.GetUserID()
return db.Model(&models.User{}).Where(
&models.User{
@@ -76,7 +97,7 @@ func (q *queryImpl) AdminUpdateUser(userInfo models.UserInfo, user *models.UserE
}).Error
}
func (q *queryImpl) GetUserByUserName(userName string) (*models.UserEntity, error) {
func (q *userQuery) GetUserByUserName(userName string) (*models.UserEntity, error) {
if userName == "" {
return nil, fmt.Errorf("invalid user name")
}
@@ -93,7 +114,7 @@ func (q *queryImpl) GetUserByUserName(userName string) (*models.UserEntity, erro
return u.UserEntity, nil
}
func (q *queryImpl) CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error) {
func (q *userQuery) CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error) {
var user models.User
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
@@ -111,7 +132,7 @@ func (q *queryImpl) CheckUserPassword(userNameOrEmail, password string) (bool, m
return utils.CheckPasswordHash(password, user.Password), user, nil
}
func (q *queryImpl) CheckUserNameAndEmail(userName, email string) error {
func (q *userQuery) CheckUserNameAndEmail(userName, email string) error {
var user models.User
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
@@ -129,10 +150,10 @@ func (q *queryImpl) CheckUserNameAndEmail(userName, email string) error {
return nil
}
func (q *queryImpl) CreateUser(user *models.UserEntity) error {
func (m *userMutation) CreateUser(user *models.UserEntity) error {
u := &models.User{
UserEntity: user,
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(u).Error
}

View File

@@ -7,7 +7,16 @@ import (
"github.com/VaalaCat/frp-panel/models"
)
func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error) {
type UserGroupMutation interface {
CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error)
DeleteGroup(userInfo models.UserInfo, groupID string) error
}
type userGroupMutation struct{ *mutationImpl }
func newUserGroupMutation(base *mutationImpl) UserGroupMutation { return &userGroupMutation{base} }
func (m *userGroupMutation) CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error) {
if groupID == "" || groupName == "" {
return nil, fmt.Errorf("invalid group id or group name")
}
@@ -16,7 +25,7 @@ func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, co
return nil, fmt.Errorf("only admin can create group")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
g := &models.UserGroup{
TenantID: userInfo.GetTenantID(),
@@ -32,12 +41,12 @@ func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, co
return g, nil
}
func (q *queryImpl) DeleteGroup(userInfo models.UserInfo, groupID string) error {
func (m *userGroupMutation) DeleteGroup(userInfo models.UserInfo, groupID string) error {
if userInfo.GetRole() != defs.UserRole_Admin {
return fmt.Errorf("only admin can delete group")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.UserGroup{
TenantID: userInfo.GetTenantID(),
GroupID: groupID,

View File

@@ -7,7 +7,30 @@ import (
"gorm.io/gorm"
)
func (q *queryImpl) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error {
type WireGuardQuery interface {
GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error)
AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error)
GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error)
GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error)
ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error)
AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error)
AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error)
CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error)
}
type WireGuardMutation interface {
CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error
UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error
DeleteWireGuard(userInfo models.UserInfo, id uint) error
}
type wireGuardQuery struct{ *queryImpl }
type wireGuardMutation struct{ *mutationImpl }
func newWireGuardQuery(base *queryImpl) WireGuardQuery { return &wireGuardQuery{base} }
func newWireGuardMutation(base *mutationImpl) WireGuardMutation { return &wireGuardMutation{base} }
func (m *wireGuardMutation) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error {
if wg == nil || wg.WireGuardEntity == nil {
return fmt.Errorf("invalid wireguard entity")
}
@@ -18,11 +41,11 @@ func (q *queryImpl) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGua
wg.UserId = uint32(userInfo.GetUserID())
wg.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Create(wg).Error
}
func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error {
func (m *wireGuardMutation) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error {
if id == 0 || wg == nil || wg.WireGuardEntity == nil {
return fmt.Errorf("invalid wireguard id or entity")
}
@@ -30,7 +53,7 @@ func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *model
wg.UserId = uint32(userInfo.GetUserID())
wg.TenantId = uint32(userInfo.GetTenantID())
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
// clear endpoints and resave if provided
if wg.AdvertisedEndpoints != nil {
@@ -49,11 +72,11 @@ func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *model
}}).Save(wg).Error
}
func (q *queryImpl) DeleteWireGuard(userInfo models.UserInfo, id uint) error {
func (m *wireGuardMutation) DeleteWireGuard(userInfo models.UserInfo, id uint) error {
if id == 0 {
return fmt.Errorf("invalid wireguard id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.WireGuard{
Model: gorm.Model{ID: id},
WireGuardEntity: &models.WireGuardEntity{
@@ -63,7 +86,7 @@ func (q *queryImpl) DeleteWireGuard(userInfo models.UserInfo, id uint) error {
}).Delete(&models.WireGuard{}).Error
}
func (q *queryImpl) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error) {
func (q *wireGuardQuery) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error) {
if id == 0 {
return nil, fmt.Errorf("invalid wireguard id")
}
@@ -84,7 +107,7 @@ func (q *queryImpl) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models
return &m, nil
}
func (q *queryImpl) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error) {
func (q *wireGuardQuery) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error) {
if clientID == "" || interfaceName == "" {
return nil, fmt.Errorf("invalid client id or interface name")
}
@@ -100,7 +123,7 @@ func (q *queryImpl) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interf
return &m, nil
}
func (q *queryImpl) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error) {
func (q *wireGuardQuery) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error) {
if networkID == 0 {
return nil, fmt.Errorf("invalid network id")
}
@@ -119,7 +142,7 @@ func (q *queryImpl) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID
return list, nil
}
func (q *queryImpl) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error) {
func (q *wireGuardQuery) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error) {
if networkID == 0 {
return nil, fmt.Errorf("invalid network id")
}
@@ -135,7 +158,7 @@ func (q *queryImpl) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserIn
return list, nil
}
func (q *queryImpl) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error) {
func (q *wireGuardQuery) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error) {
if page < 1 || pageSize < 1 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -170,7 +193,7 @@ func (q *queryImpl) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pa
return list, nil
}
func (q *queryImpl) AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error) {
func (q *wireGuardQuery) AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var list []*models.WireGuard
if err := db.Where(&models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ClientID: clientID}}).
@@ -180,7 +203,7 @@ func (q *queryImpl) AdminListWireGuardsWithClientID(clientID string) ([]*models.
return list, nil
}
func (q *queryImpl) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error) {
func (q *wireGuardQuery) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var list []*models.WireGuard
if err := db.Where("network_id IN ?", networkIDs).
@@ -190,7 +213,7 @@ func (q *queryImpl) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*mod
return list, nil
}
func (q *queryImpl) CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error) {
func (q *wireGuardQuery) CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
base := db.Model(&models.WireGuard{}).Where(&models.WireGuard{WireGuardEntity: &models.WireGuardEntity{

View File

@@ -6,8 +6,29 @@ import (
"github.com/VaalaCat/frp-panel/models"
)
func (q *queryImpl) CreateWorker(userInfo models.UserInfo, worker *models.Worker) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
type WorkerQuery interface {
GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error)
ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error)
AdminListWorkersByClientID(clientID string) ([]*models.Worker, error)
ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error)
CountWorkers(userInfo models.UserInfo) (int64, error)
CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error)
}
type WorkerMutation interface {
CreateWorker(userInfo models.UserInfo, worker *models.Worker) error
DeleteWorker(userInfo models.UserInfo, workerID string) error
UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error
}
type workerQuery struct{ *queryImpl }
type workerMutation struct{ *mutationImpl }
func newWorkerQuery(base *queryImpl) WorkerQuery { return &workerQuery{base} }
func newWorkerMutation(base *mutationImpl) WorkerMutation { return &workerMutation{base} }
func (m *workerMutation) CreateWorker(userInfo models.UserInfo, worker *models.Worker) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
worker.UserId = uint32(userInfo.GetUserID())
worker.TenantId = uint32(userInfo.GetTenantID())
@@ -17,8 +38,8 @@ func (q *queryImpl) CreateWorker(userInfo models.UserInfo, worker *models.Worker
return db.Create(worker).Error
}
func (q *queryImpl) DeleteWorker(userInfo models.UserInfo, workerID string) error {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
func (m *workerMutation) DeleteWorker(userInfo models.UserInfo, workerID string) error {
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
return db.Unscoped().Where(&models.Worker{
WorkerEntity: &models.WorkerEntity{
@@ -29,7 +50,7 @@ func (q *queryImpl) DeleteWorker(userInfo models.UserInfo, workerID string) erro
}).Delete(&models.Worker{}).Error
}
func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error {
func (m *workerMutation) UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error {
if worker.WorkerEntity == nil {
return fmt.Errorf("invalid worker entity")
}
@@ -37,7 +58,7 @@ func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker
return fmt.Errorf("invalid worker id")
}
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
db := m.ctx.GetApp().GetDBManager().GetDefaultDB()
if err := db.Unscoped().Model(&models.Worker{
WorkerEntity: &models.WorkerEntity{
@@ -58,7 +79,7 @@ func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker
}).Save(worker).Error
}
func (q *queryImpl) GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error) {
func (q *workerQuery) GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
w := &models.Worker{}
err := db.Where(&models.Worker{
@@ -74,7 +95,7 @@ func (q *queryImpl) GetWorkerByWorkerID(userInfo models.UserInfo, workerID strin
return w, nil
}
func (q *queryImpl) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error) {
func (q *workerQuery) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error) {
if page < 1 || pageSize < 1 || pageSize > 100 {
return nil, fmt.Errorf("invalid page or page size")
}
@@ -96,9 +117,9 @@ func (q *queryImpl) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([
return workers, nil
}
func (q *queryImpl) AdminListWorkersByClientID(clientID string) ([]*models.Worker, error) {
func (q *workerQuery) AdminListWorkersByClientID(clientID string) ([]*models.Worker, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
client, err := q.AdminGetClientByClientID(clientID)
client, err := newClientQuery(q.queryImpl).AdminGetClientByClientID(clientID)
if err != nil {
return nil, err
}
@@ -111,7 +132,7 @@ func (q *queryImpl) AdminListWorkersByClientID(clientID string) ([]*models.Worke
return client.Workers, nil
}
func (q *queryImpl) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error) {
func (q *workerQuery) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error) {
if page < 1 || pageSize < 1 || len(keyword) == 0 || pageSize > 100 {
return nil, fmt.Errorf("invalid page or page size or keyword")
}
@@ -134,7 +155,7 @@ func (q *queryImpl) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageS
return workers, nil
}
func (q *queryImpl) CountWorkers(userInfo models.UserInfo) (int64, error) {
func (q *workerQuery) CountWorkers(userInfo models.UserInfo) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Worker{}).Where(&models.Worker{
@@ -149,7 +170,7 @@ func (q *queryImpl) CountWorkers(userInfo models.UserInfo) (int64, error) {
return count, nil
}
func (q *queryImpl) CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
func (q *workerQuery) CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) {
db := q.ctx.GetApp().GetDBManager().GetDefaultDB()
var count int64
err := db.Model(&models.Worker{}).Where("name like ?", "%"+keyword+"%").

View File

@@ -208,7 +208,7 @@ func (s *server) ServerSend(sender pb.Master_ServerSendServer) error {
}
if cliType == defs.CliTypeClient {
if err := dao.NewQuery(ctx).AdminUpdateClientLastSeen(req.GetClientId()); err != nil {
if err := dao.NewMutation(ctx).AdminUpdateClientLastSeen(req.GetClientId()); err != nil {
logger.Logger(ctx).Errorf("cannot update client last seen, %s id: [%s]", req.GetEvent().String(), req.GetClientId())
}
}