diff --git a/biz/master/auth/login.go b/biz/master/auth/login.go index 738926e..cc7227e 100644 --- a/biz/master/auth/login.go +++ b/biz/master/auth/login.go @@ -41,7 +41,7 @@ func LoginHandler(ctx *app.Context, req *pb.LoginRequest) (*pb.LoginResponse, er userEntity.Role = defs.UserRole_Admin - dao.NewQuery(ctx).AdminUpdateUser(&models.UserEntity{ + dao.NewMutation(ctx).AdminUpdateUser(&models.UserEntity{ UserID: user.GetUserID(), }, userEntity.UserEntity) } diff --git a/biz/master/auth/register.go b/biz/master/auth/register.go index fc44ea9..bc0c8f8 100644 --- a/biz/master/auth/register.go +++ b/biz/master/auth/register.go @@ -56,7 +56,7 @@ func RegisterHandler(c *app.Context, req *pb.RegisterRequest) (*pb.RegisterRespo newUser.Role = defs.UserRole_Admin } - err = dao.NewQuery(c).CreateUser(newUser) + err = dao.NewMutation(c).CreateUser(newUser) if err != nil { return &pb.RegisterResponse{ Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: err.Error()}, diff --git a/biz/master/client/create_client.go b/biz/master/client/create_client.go index 57c22d4..8a3ba71 100644 --- a/biz/master/client/create_client.go +++ b/biz/master/client/create_client.go @@ -31,7 +31,7 @@ func InitClientHandler(c *app.Context, req *pb.InitClientRequest) (*pb.InitClien logger.Logger(c).Infof("start to init client, request:[%s], transformed global client id:[%s]", req.String(), globalClientID) - if err := dao.NewQuery(c).CreateClient(userInfo, + if err := dao.NewMutation(c).CreateClient(userInfo, &models.ClientEntity{ ClientID: globalClientID, TenantID: userInfo.GetTenantID(), diff --git a/biz/master/client/delete_client.go b/biz/master/client/delete_client.go index 96fad8c..ea837f7 100644 --- a/biz/master/client/delete_client.go +++ b/biz/master/client/delete_client.go @@ -29,11 +29,11 @@ func DeleteClientHandler(ctx *app.Context, req *pb.DeleteClientRequest) (*pb.Del }, nil } - if err := dao.NewQuery(ctx).DeleteClient(userInfo, clientID); err != nil { + if err := dao.NewMutation(ctx).DeleteClient(userInfo, clientID); err != nil { return nil, err } - if err := dao.NewQuery(ctx).DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, clientID); err != nil { + if err := dao.NewMutation(ctx).DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, clientID); err != nil { return nil, err } diff --git a/biz/master/client/delete_tunnel.go b/biz/master/client/delete_tunnel.go index defad77..96abb6b 100644 --- a/biz/master/client/delete_tunnel.go +++ b/biz/master/client/delete_tunnel.go @@ -31,7 +31,7 @@ func RemoveFrpcHandler(c *app.Context, req *pb.RemoveFRPCRequest) (*pb.RemoveFRP return nil, err } - err = dao.NewQuery(c).DeleteClient(userInfo, clientID) + err = dao.NewMutation(c).DeleteClient(userInfo, clientID) if err != nil { logger.Logger(context.Background()).WithError(err).Errorf("cannot delete client, id: [%s]", clientID) return nil, err diff --git a/biz/master/client/helper.go b/biz/master/client/helper.go index 934a12a..f637975 100644 --- a/biz/master/client/helper.go +++ b/biz/master/client/helper.go @@ -42,6 +42,7 @@ func ValidateClientRequest(ctx *app.Context, req ValidateableClientRequest) (*mo func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.ClientEntity) (*models.ClientEntity, error) { userInfo := common.GetUserInfo(c) + m := dao.NewMutation(c) var clientID = clientEntity.ClientID var childClient *models.ClientEntity @@ -53,7 +54,7 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl return nil, err } - if err := dao.NewQuery(c).RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: childClient}); err != nil { + if err := m.RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: childClient}); err != nil { logger.Logger(c).WithError(err).Errorf("cannot rebuild proxy config from client, id: [%s]", childClient.ClientID) return nil, err } @@ -62,12 +63,12 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl clientEntity.IsShadow = true clientEntity.ConfigContent = nil clientEntity.ServerID = "" - if err := dao.NewQuery(c).UpdateClient(userInfo, clientEntity); err != nil { + if err := m.UpdateClient(userInfo, clientEntity); err != nil { logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", clientID) return nil, err } - if err := dao.NewQuery(c).DeleteProxyConfigsByClientID(userInfo, clientID); err != nil { + if err := m.DeleteProxyConfigsByClientID(userInfo, clientID); err != nil { logger.Logger(c).WithError(err).Errorf("cannot delete proxy configs, id: [%s]", clientID) return nil, err } @@ -79,13 +80,15 @@ func MakeClientShadowed(c *app.Context, serverID string, clientEntity *models.Cl // bool 表示是否新建 func ChildClientForServer(c *app.Context, serverID string, clientEntity *models.ClientEntity) (*models.ClientEntity, bool, error) { userInfo := common.GetUserInfo(c) + q := dao.NewQuery(c) + m := dao.NewMutation(c) originClientID := clientEntity.ClientID if len(clientEntity.OriginClientID) != 0 { originClientID = clientEntity.OriginClientID } - existClient, err := dao.NewQuery(c).GetClientByFilter(userInfo, &models.ClientEntity{ + existClient, err := q.GetClientByFilter(userInfo, &models.ClientEntity{ ServerID: serverID, OriginClientID: originClientID, }, lo.ToPtr(false)) @@ -93,7 +96,7 @@ func ChildClientForServer(c *app.Context, serverID string, clientEntity *models. return existClient, false, nil } - shadowCount, err := dao.NewQuery(c).CountClientsInShadow(userInfo, originClientID) + shadowCount, err := q.CountClientsInShadow(userInfo, originClientID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot count shadow clients, id: [%s]", originClientID) return nil, false, err @@ -109,7 +112,7 @@ func ChildClientForServer(c *app.Context, serverID string, clientEntity *models. copiedClient.OriginClientID = originClientID copiedClient.IsShadow = false copiedClient.Stopped = false - if err := dao.NewQuery(c).CreateClient(userInfo, copiedClient); err != nil { + if err := m.CreateClient(userInfo, copiedClient); err != nil { logger.Logger(c).WithError(err).Errorf("cannot create child client, id: [%s]", copiedClient.ClientID) return nil, false, err } diff --git a/biz/master/client/rpc_pull_config.go b/biz/master/client/rpc_pull_config.go index f473a49..285aaed 100644 --- a/biz/master/client/rpc_pull_config.go +++ b/biz/master/client/rpc_pull_config.go @@ -21,7 +21,7 @@ func RPCPullConfig(ctx *app.Context, req *pb.PullClientConfigReq) (*pb.PullClien return nil, err } - if err := dao.NewQuery(ctx).AdminUpdateClientLastSeen(cli.ClientID); err != nil { + if err := dao.NewMutation(ctx).AdminUpdateClientLastSeen(cli.ClientID); err != nil { logger.Logger(ctx).WithError(err).Errorf("update client last_seen_at time error, req:[%s] clientId:[%s]", req.String(), cli.ClientID) } diff --git a/biz/master/client/start_client.go b/biz/master/client/start_client.go index 4ae70f0..eaebe44 100644 --- a/biz/master/client/start_client.go +++ b/biz/master/client/start_client.go @@ -38,7 +38,7 @@ func StartFRPCHandler(ctx *app.Context, req *pb.StartFRPCRequest) (*pb.StartFRPC client.Stopped = false - if err = dao.NewQuery(ctx).UpdateClient(userInfo, client); err != nil { + if err = dao.NewMutation(ctx).UpdateClient(userInfo, client); err != nil { return nil, err } diff --git a/biz/master/client/stop_client.go b/biz/master/client/stop_client.go index b2aa596..acbe163 100644 --- a/biz/master/client/stop_client.go +++ b/biz/master/client/stop_client.go @@ -38,7 +38,7 @@ func StopFRPCHandler(ctx *app.Context, req *pb.StopFRPCRequest) (*pb.StopFRPCRes client.Stopped = true - if err = dao.NewQuery(ctx).UpdateClient(userInfo, client); err != nil { + if err = dao.NewMutation(ctx).UpdateClient(userInfo, client); err != nil { return nil, err } diff --git a/biz/master/client/sync_tunnel.go b/biz/master/client/sync_tunnel.go index 67b7e1f..76927c8 100644 --- a/biz/master/client/sync_tunnel.go +++ b/biz/master/client/sync_tunnel.go @@ -32,7 +32,7 @@ func SyncTunnel(ctx *app.Context, userInfo models.UserInfo) error { return } - if err := dao.NewQuery(ctx).UpdateClient(userInfo, cli); err != nil { + if err := dao.NewMutation(ctx).UpdateClient(userInfo, cli); err != nil { logger.Logger(context.Background()).WithError(err).Errorf("cannot update client, id: [%s]", cli.ClientID) return } diff --git a/biz/master/client/update_tunnel.go b/biz/master/client/update_tunnel.go index 55ec95b..1056777 100644 --- a/biz/master/client/update_tunnel.go +++ b/biz/master/client/update_tunnel.go @@ -27,6 +27,8 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC reqClientID = req.GetClientId() // may be shadow or child userInfo = common.GetUserInfo(c) ) + q := dao.NewQuery(c) + m := dao.NewMutation(c) cliCfg, err := utils.LoadClientConfigNormal(content, true) if err != nil { @@ -36,7 +38,7 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC }, err } - cliRecord, err := dao.NewQuery(c).GetClientByClientID(userInfo, reqClientID) + cliRecord, err := q.GetClientByClientID(userInfo, reqClientID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", reqClientID) return &pb.UpdateFRPCResponse{ @@ -72,7 +74,7 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC } } - srv, err := dao.NewQuery(c).GetServerByServerID(userInfo, req.GetServerId()) + srv, err := q.GetServerByServerID(userInfo, req.GetServerId()) if err != nil || srv == nil || len(srv.ServerIP) == 0 || len(srv.ConfigContent) == 0 { logger.Logger(c).WithError(err).Errorf("cannot get server, server is not prepared, id: [%s]", req.GetServerId()) return &pb.UpdateFRPCResponse{ @@ -169,12 +171,12 @@ func UpdateFrpcHander(c *app.Context, req *pb.UpdateFRPCRequest) (*pb.UpdateFRPC cli.Comment = req.GetComment() } - if err := dao.NewQuery(c).UpdateClient(userInfo, cli); err != nil { + if err := m.UpdateClient(userInfo, cli); err != nil { logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", cli.ClientID) return nil, err } - if err := dao.NewQuery(c).RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: cli}); err != nil { + if err := m.RebuildProxyConfigFromClient(userInfo, &models.Client{ClientEntity: cli}); err != nil { logger.Logger(c).WithError(err).Errorf("cannot rebuild proxy config from client, id: [%s]", cli.ClientID) return nil, err } diff --git a/biz/master/proxy/delete_proxy_config.go b/biz/master/proxy/delete_proxy_config.go index b91b026..bac2921 100644 --- a/biz/master/proxy/delete_proxy_config.go +++ b/biz/master/proxy/delete_proxy_config.go @@ -20,12 +20,14 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De serverID = req.GetServerId() proxyName = req.GetName() ) + q := dao.NewQuery(c) + m := dao.NewMutation(c) if len(clientID) == 0 || len(serverID) == 0 || len(proxyName) == 0 { return nil, fmt.Errorf("request invalid") } - cli, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientID) + cli, err := q.GetClientByClientID(userInfo, clientID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", clientID) return nil, err @@ -49,7 +51,7 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De return nil, err } - if err := dao.NewQuery(c).UpdateClient(userInfo, cli.ClientEntity); err != nil { + if err := m.UpdateClient(userInfo, cli.ClientEntity); err != nil { logger.Logger(c).WithError(err).Errorf("cannot update client, id: [%s]", clientID) return nil, err } @@ -72,7 +74,7 @@ func DeleteProxyConfig(c *app.Context, req *pb.DeleteProxyConfigRequest) (*pb.De return nil, err } - if err := dao.NewQuery(c).DeleteProxyConfig(userInfo, clientID, proxyName); err != nil { + if err := m.DeleteProxyConfig(userInfo, clientID, proxyName); err != nil { logger.Logger(c).WithError(err).Errorf("cannot delete proxy config, id: [%s]", clientID) return nil, err } diff --git a/biz/master/proxy/start_proxy.go b/biz/master/proxy/start_proxy.go index 27afc23..29db2a2 100644 --- a/biz/master/proxy/start_proxy.go +++ b/biz/master/proxy/start_proxy.go @@ -45,7 +45,7 @@ func StartProxy(ctx *app.Context, req *pb.StartProxyRequest) (*pb.StartProxyResp // 1. 更新proxy状态 proxyConfig.Stopped = false - err = dao.NewQuery(ctx).UpdateProxyConfig(userInfo, proxyConfig) + err = dao.NewMutation(ctx).UpdateProxyConfig(userInfo, proxyConfig) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot update proxy config, client: [%s], server: [%s], proxy name: [%s]", clientID, serverID, proxyName) return nil, err diff --git a/biz/master/proxy/stop_proxy.go b/biz/master/proxy/stop_proxy.go index f2f1289..7b147aa 100644 --- a/biz/master/proxy/stop_proxy.go +++ b/biz/master/proxy/stop_proxy.go @@ -44,7 +44,7 @@ func StopProxy(ctx *app.Context, req *pb.StopProxyRequest) (*pb.StopProxyRespons // 1. 更新proxy状态 proxyConfig.Stopped = true - err = dao.NewQuery(ctx).UpdateProxyConfig(userInfo, proxyConfig) + err = dao.NewMutation(ctx).UpdateProxyConfig(userInfo, proxyConfig) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot update proxy config, client: [%s], server: [%s], proxy name: [%s]", clientID, serverID, proxyName) return nil, err diff --git a/biz/master/proxy/task_collect_daily_stats.go b/biz/master/proxy/task_collect_daily_stats.go index d3a281f..662dabd 100644 --- a/biz/master/proxy/task_collect_daily_stats.go +++ b/biz/master/proxy/task_collect_daily_stats.go @@ -41,7 +41,7 @@ func CollectDailyStats(appInstance app.Application) error { } }) - if err := dao.NewQuery(ctx).AdminMSaveTodyStats(tx, proxyDailyStats); err != nil { + if err := dao.NewMutation(ctx).AdminMSaveTodyStats(tx, proxyDailyStats); err != nil { logger.Logger(context.Background()).WithError(err).Error("CollectDailyStats cannot save stats") return err } diff --git a/biz/master/proxy/update_proxy_config.go b/biz/master/proxy/update_proxy_config.go index 0636e01..f7672ef 100644 --- a/biz/master/proxy/update_proxy_config.go +++ b/biz/master/proxy/update_proxy_config.go @@ -27,8 +27,10 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up clientID = req.GetClientId() serverID = req.GetServerId() ) + q := dao.NewQuery(c) + m := dao.NewMutation(c) - cli, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientID) + cli, err := q.GetClientByClientID(userInfo, clientID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get client, id: [%s]", clientID) return nil, err @@ -38,7 +40,7 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up if clientEntity.ServerID != serverID { logger.Logger(c).Errorf("client and server not match, find or create client, client: [%s], server: [%s]", clientID, serverID) - originClient, err := dao.NewQuery(c).GetClientByClientID(userInfo, clientEntity.OriginClientID) + originClient, err := q.GetClientByClientID(userInfo, clientEntity.OriginClientID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get origin client, id: [%s]", clientEntity.OriginClientID) return nil, err @@ -51,7 +53,7 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up } } - _, err = dao.NewQuery(c).GetServerByServerID(userInfo, serverID) + _, err = q.GetServerByServerID(userInfo, serverID) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get server, id: [%s]", serverID) return nil, err @@ -84,13 +86,13 @@ func UpdateProxyConfig(c *app.Context, req *pb.UpdateProxyConfigRequest) (*pb.Up return nil, err } - oldProxyCfg, err := dao.NewQuery(c).GetProxyConfigByOriginClientIDAndName(userInfo, clientID, proxyCfg.Name) + oldProxyCfg, err := q.GetProxyConfigByOriginClientIDAndName(userInfo, clientID, proxyCfg.Name) if err != nil { logger.Logger(c).WithError(err).Errorf("cannot get proxy config, id: [%s]", clientID) return nil, err } - if dao.NewQuery(c).UpdateProxyConfig(userInfo, &models.ProxyConfig{ + if m.UpdateProxyConfig(userInfo, &models.ProxyConfig{ Model: oldProxyCfg.Model, ProxyConfigEntity: proxyCfg, }) != nil { diff --git a/biz/master/server/create_server.go b/biz/master/server/create_server.go index f85ebbc..9f2f6de 100644 --- a/biz/master/server/create_server.go +++ b/biz/master/server/create_server.go @@ -31,7 +31,7 @@ func InitServerHandler(c *app.Context, req *pb.InitServerRequest) (*pb.InitServe globalServerID := app.GlobalClientID(userInfo.GetUserName(), "s", userServerID) - if err := dao.NewQuery(c).CreateServer(userInfo, + if err := dao.NewMutation(c).CreateServer(userInfo, &models.ServerEntity{ ServerID: globalServerID, TenantID: userInfo.GetTenantID(), diff --git a/biz/master/server/delete_server.go b/biz/master/server/delete_server.go index 30b8a5c..22e08a7 100644 --- a/biz/master/server/delete_server.go +++ b/biz/master/server/delete_server.go @@ -25,7 +25,7 @@ func DeleteServerHandler(c *app.Context, req *pb.DeleteServerRequest) (*pb.Delet }, nil } - if err := dao.NewQuery(c).DeleteServer(userInfo, userServerID); err != nil { + if err := dao.NewMutation(c).DeleteServer(userInfo, userServerID); err != nil { return nil, err } diff --git a/biz/master/server/delete_tunnel.go b/biz/master/server/delete_tunnel.go index 1f8955e..54cc2d3 100644 --- a/biz/master/server/delete_tunnel.go +++ b/biz/master/server/delete_tunnel.go @@ -25,7 +25,7 @@ func RemoveFrpsHandler(c *app.Context, req *pb.RemoveFRPSRequest) (*pb.RemoveFRP return nil, err } - if err = dao.NewQuery(c).DeleteServer(userInfo, serverID); err != nil { + if err = dao.NewMutation(c).DeleteServer(userInfo, serverID); err != nil { logger.Logger(context.Background()).WithError(err).Errorf("cannot delete server, id: [%s]", serverID) return nil, err } diff --git a/biz/master/server/rpc_push_proxy_info.go b/biz/master/server/rpc_push_proxy_info.go index cf3a4c5..dbf7601 100644 --- a/biz/master/server/rpc_push_proxy_info.go +++ b/biz/master/server/rpc_push_proxy_info.go @@ -15,7 +15,7 @@ func PushProxyInfo(ctx *app.Context, req *pb.PushProxyInfoReq) (*pb.PushProxyInf return nil, err } - if err = dao.NewQuery(ctx).AdminUpdateProxyStats(srv, req.GetProxyInfos()); err != nil { + if err = dao.NewMutation(ctx).AdminUpdateProxyStats(srv, req.GetProxyInfos()); err != nil { return nil, err } return &pb.PushProxyInfoResp{ diff --git a/biz/master/server/update_tunnel.go b/biz/master/server/update_tunnel.go index 4ed1799..0cd8c55 100644 --- a/biz/master/server/update_tunnel.go +++ b/biz/master/server/update_tunnel.go @@ -27,7 +27,8 @@ func UpdateFrpsHander(c *app.Context, req *pb.UpdateFRPSRequest) (*pb.UpdateFRPS return nil, fmt.Errorf("request invalid") } - srv, err := dao.NewQuery(c).GetServerByServerID(userInfo, serverID) + q := dao.NewQuery(c) + srv, err := q.GetServerByServerID(userInfo, serverID) if srv == nil || err != nil { logger.Logger(context.Background()).WithError(err).Errorf("cannot get server, id: [%s]", serverID) return nil, err @@ -55,7 +56,7 @@ func UpdateFrpsHander(c *app.Context, req *pb.UpdateFRPSRequest) (*pb.UpdateFRPS srv.FrpsUrls = req.GetFrpsUrls() } - if err := dao.NewQuery(c).UpdateServer(userInfo, srv); err != nil { + if err := dao.NewMutation(c).UpdateServer(userInfo, srv); err != nil { logger.Logger(context.Background()).WithError(err).Errorf("cannot update server, id: [%s]", serverID) return nil, err } diff --git a/biz/master/user/update_user_info.go b/biz/master/user/update_user_info.go index e81d064..0180773 100644 --- a/biz/master/user/update_user_info.go +++ b/biz/master/user/update_user_info.go @@ -47,7 +47,7 @@ func UpdateUserInfoHander(c *app.Context, req *pb.UpdateUserInfoRequest) (*pb.Up newUserEntity.Token = newUserInfo.GetToken() } - if err := dao.NewQuery(c).UpdateUser(userInfo, newUserEntity); err != nil { + if err := dao.NewMutation(c).UpdateUser(userInfo, newUserEntity); err != nil { return &pb.UpdateUserInfoResponse{ Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: err.Error()}, }, err diff --git a/biz/master/wg/endpoint_create.go b/biz/master/wg/endpoint_create.go index 0c9962f..2944fd2 100644 --- a/biz/master/wg/endpoint_create.go +++ b/biz/master/wg/endpoint_create.go @@ -21,7 +21,7 @@ func CreateEndpoint(ctx *app.Context, req *pb.CreateEndpointRequest) (*pb.Create } entity := &models.EndpointEntity{Host: e.GetHost(), Port: e.GetPort(), ClientID: e.GetClientId(), Type: e.GetType(), Uri: e.GetUri()} - if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, entity); err != nil { + if err := dao.NewMutation(ctx).CreateEndpoint(userInfo, entity); err != nil { return nil, err } return &pb.CreateEndpointResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, Endpoint: &pb.Endpoint{Id: 0, Host: entity.Host, Port: entity.Port, ClientId: entity.ClientID}}, nil diff --git a/biz/master/wg/endpoint_delete.go b/biz/master/wg/endpoint_delete.go index 61a15f1..cc5f194 100644 --- a/biz/master/wg/endpoint_delete.go +++ b/biz/master/wg/endpoint_delete.go @@ -18,7 +18,7 @@ func DeleteEndpoint(ctx *app.Context, req *pb.DeleteEndpointRequest) (*pb.Delete if id == 0 { return nil, errors.New("invalid id") } - if err := dao.NewQuery(ctx).DeleteEndpoint(userInfo, id); err != nil { + if err := dao.NewMutation(ctx).DeleteEndpoint(userInfo, id); err != nil { return nil, err } return &pb.DeleteEndpointResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}}, nil diff --git a/biz/master/wg/endpoint_update.go b/biz/master/wg/endpoint_update.go index 15e3030..d830db3 100644 --- a/biz/master/wg/endpoint_update.go +++ b/biz/master/wg/endpoint_update.go @@ -19,7 +19,10 @@ func UpdateEndpoint(ctx *app.Context, req *pb.UpdateEndpointRequest) (*pb.Update return nil, errors.New("invalid endpoint params") } - oldEndpoint, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(e.GetId())) + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) + + oldEndpoint, err := q.GetEndpointByID(userInfo, uint(e.GetId())) if err != nil { return nil, err } @@ -37,7 +40,7 @@ func UpdateEndpoint(ctx *app.Context, req *pb.UpdateEndpointRequest) (*pb.Update oldEndpoint.Type = e.GetType() } - if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(e.GetId()), oldEndpoint.EndpointEntity); err != nil { + if err := m.UpdateEndpoint(userInfo, uint(e.GetId()), oldEndpoint.EndpointEntity); err != nil { return nil, err } diff --git a/biz/master/wg/link_query.go b/biz/master/wg/link_query.go index 6280c1f..5de8676 100644 --- a/biz/master/wg/link_query.go +++ b/biz/master/wg/link_query.go @@ -23,6 +23,8 @@ func CreateWireGuardLink(ctx *app.Context, req *pb.CreateWireGuardLinkRequest) ( } // 校验两端属于同一 network q := dao.NewQuery(ctx) + mut := dao.NewMutation(ctx) + from, err := q.GetWireGuardByID(userInfo, uint(l.GetFromWireguardId())) if err != nil { return nil, err @@ -43,7 +45,7 @@ func CreateWireGuardLink(ctx *app.Context, req *pb.CreateWireGuardLinkRequest) ( reverse.NetworkID = from.NetworkID reverse.ToEndpointID = 0 - if err := q.CreateWireGuardLinks(userInfo, m, reverse); err != nil { + if err := mut.CreateWireGuardLinks(userInfo, m, reverse); err != nil { return nil, err } return &pb.CreateWireGuardLinkResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, WireguardLink: m.ToPB()}, nil @@ -59,6 +61,8 @@ func UpdateWireGuardLink(ctx *app.Context, req *pb.UpdateWireGuardLinkRequest) ( return nil, errors.New("invalid link params") } q := dao.NewQuery(ctx) + mut := dao.NewMutation(ctx) + m, err := q.GetWireGuardLinkByID(userInfo, uint(l.GetId())) if err != nil { return nil, err @@ -73,7 +77,7 @@ func UpdateWireGuardLink(ctx *app.Context, req *pb.UpdateWireGuardLinkRequest) ( m.ToEndpointID = uint(l.GetToEndpoint().GetId()) } - if err := q.UpdateWireGuardLink(userInfo, uint(l.GetId()), m); err != nil { + if err := mut.UpdateWireGuardLink(userInfo, uint(l.GetId()), m); err != nil { return nil, err } return &pb.UpdateWireGuardLinkResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_SUCCESS, Message: "success"}, WireguardLink: m.ToPB()}, nil @@ -89,21 +93,24 @@ func DeleteWireGuardLink(ctx *app.Context, req *pb.DeleteWireGuardLinkRequest) ( return nil, errors.New("invalid id") } - link, err := dao.NewQuery(ctx).GetWireGuardLinkByID(userInfo, id) + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) + + link, err := q.GetWireGuardLinkByID(userInfo, id) if err != nil { return nil, err } - rev, err := dao.NewQuery(ctx).GetWireGuardLinkByClientIDs(userInfo, link.ToWireGuardID, link.FromWireGuardID) + rev, err := q.GetWireGuardLinkByClientIDs(userInfo, link.ToWireGuardID, link.FromWireGuardID) if err != nil { return nil, err } - if err := dao.NewQuery(ctx).DeleteWireGuardLink(userInfo, uint(link.ID)); err != nil { + if err := m.DeleteWireGuardLink(userInfo, uint(link.ID)); err != nil { return nil, err } - if err := dao.NewQuery(ctx).DeleteWireGuardLink(userInfo, uint(rev.ID)); err != nil { + if err := m.DeleteWireGuardLink(userInfo, uint(rev.ID)); err != nil { return nil, err } diff --git a/biz/master/wg/network_create.go b/biz/master/wg/network_create.go index ffea49f..8ceb2b4 100644 --- a/biz/master/wg/network_create.go +++ b/biz/master/wg/network_create.go @@ -36,7 +36,7 @@ func CreateNetwork(ctx *app.Context, req *pb.CreateNetworkRequest) (*pb.CreateNe ACL: models.JSON[*pb.AclConfig]{Data: req.GetNetwork().GetAcl()}, } - if err := dao.NewQuery(ctx).CreateNetwork(userInfo, entity); err != nil { + if err := dao.NewMutation(ctx).CreateNetwork(userInfo, entity); err != nil { log.WithError(err).Errorf("create network error") return nil, err } diff --git a/biz/master/wg/network_delete.go b/biz/master/wg/network_delete.go index 998635d..be429f9 100644 --- a/biz/master/wg/network_delete.go +++ b/biz/master/wg/network_delete.go @@ -18,7 +18,7 @@ func DeleteNetwork(ctx *app.Context, req *pb.DeleteNetworkRequest) (*pb.DeleteNe if id == 0 { return nil, errors.New("invalid id") } - if err := dao.NewQuery(ctx).DeleteNetwork(userInfo, id); err != nil { + if err := dao.NewMutation(ctx).DeleteNetwork(userInfo, id); err != nil { return nil, err } return &pb.DeleteNetworkResponse{ diff --git a/biz/master/wg/network_update.go b/biz/master/wg/network_update.go index ff286d8..bba1bf6 100644 --- a/biz/master/wg/network_update.go +++ b/biz/master/wg/network_update.go @@ -20,7 +20,7 @@ func UpdateNetwork(ctx *app.Context, req *pb.UpdateNetworkRequest) (*pb.UpdateNe return nil, errors.New("invalid network") } entity := &models.NetworkEntity{Name: n.GetName(), CIDR: n.GetCidr(), ACL: models.JSON[*pb.AclConfig]{Data: n.GetAcl()}} - if err := dao.NewQuery(ctx).UpdateNetwork(userInfo, uint(n.GetId()), entity); err != nil { + if err := dao.NewMutation(ctx).UpdateNetwork(userInfo, uint(n.GetId()), entity); err != nil { return nil, err } diff --git a/biz/master/wg/wireguard_create.go b/biz/master/wg/wireguard_create.go index fcda357..ba665a0 100644 --- a/biz/master/wg/wireguard_create.go +++ b/biz/master/wg/wireguard_create.go @@ -28,14 +28,16 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea if cfg == nil || len(cfg.GetClientId()) == 0 || len(cfg.GetInterfaceName()) == 0 || len(cfg.GetLocalAddress()) == 0 { return nil, errors.New("invalid wireguard config") } + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) - ips, err := dao.NewQuery(ctx).GetWireGuardLocalAddressesByNetworkID(userInfo, uint(cfg.GetNetworkId())) + ips, err := q.GetWireGuardLocalAddressesByNetworkID(userInfo, uint(cfg.GetNetworkId())) if err != nil { log.WithError(err).Errorf("get wireguard local addresses by network id failed") return nil, err } - network, err := dao.NewQuery(ctx).GetNetworkByID(userInfo, uint(cfg.GetNetworkId())) + network, err := q.GetNetworkByID(userInfo, uint(cfg.GetNetworkId())) if err != nil { log.WithError(err).Errorf("get network by id failed") return nil, err @@ -72,7 +74,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea log.Debugf("create wireguard with config: %+v", wgModel) - if err := dao.NewQuery(ctx).CreateWireGuard(userInfo, wgModel); err != nil { + if err := m.CreateWireGuard(userInfo, wgModel); err != nil { return nil, err } @@ -83,7 +85,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea } if ep.GetId() > 0 { // 复用现有 endpoint,要求归属同一 client - exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(ep.GetId())) + exist, err := q.GetEndpointByID(userInfo, uint(ep.GetId())) if err != nil { return nil, err } @@ -92,7 +94,7 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea } exist.WireGuardID = wgModel.ID - if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil { + if err := m.UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil { return nil, err } } else { @@ -101,19 +103,19 @@ func CreateWireGuard(ctx *app.Context, req *pb.CreateWireGuardRequest) (*pb.Crea newEp.FromPB(ep) newEp.ClientID = cfg.GetClientId() newEp.WireGuardID = wgModel.ID - if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, newEp.EndpointEntity); err != nil { + if err := m.CreateEndpoint(userInfo, newEp.EndpointEntity); err != nil { return nil, err } } } go func() { - peers, err := dao.NewQuery(ctx).GetWireGuardsByNetworkID(userInfo, uint(cfg.GetNetworkId())) + peers, err := q.GetWireGuardsByNetworkID(userInfo, uint(cfg.GetNetworkId())) if err != nil { log.WithError(err).Errorf("get wireguards by network id failed") return } - links, err := dao.NewQuery(ctx).ListWireGuardLinksByNetwork(userInfo, uint(cfg.GetNetworkId())) + links, err := q.ListWireGuardLinksByNetwork(userInfo, uint(cfg.GetNetworkId())) if err != nil { log.WithError(err).Errorf("get wireguard links by network id failed") return diff --git a/biz/master/wg/wireguard_delete.go b/biz/master/wg/wireguard_delete.go index 846a2d9..e801187 100644 --- a/biz/master/wg/wireguard_delete.go +++ b/biz/master/wg/wireguard_delete.go @@ -23,13 +23,16 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele return &pb.DeleteWireGuardResponse{Status: &pb.Status{Code: pb.RespCode_RESP_CODE_INVALID, Message: "invalid id"}}, nil } - wgToDelete, err := dao.NewQuery(ctx).GetWireGuardByID(userInfo, id) + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) + + wgToDelete, err := q.GetWireGuardByID(userInfo, id) if err != nil { log.WithError(err).Errorf("get wireguard by id failed") return nil, err } - if err := dao.NewQuery(ctx).DeleteWireGuard(userInfo, id); err != nil { + if err := m.DeleteWireGuard(userInfo, id); err != nil { log.WithError(err).Errorf("delete wireguard failed") return nil, err } @@ -51,7 +54,7 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele log.Errorf("cannot get response, client id: [%s]", wgToDelete.ClientID) } - peers, err := dao.NewQuery(ctx).GetWireGuardsByNetworkID(userInfo, uint(wgToDelete.NetworkID)) + peers, err := q.GetWireGuardsByNetworkID(userInfo, uint(wgToDelete.NetworkID)) if err != nil { log.WithError(err).Errorf("get wireguards by network id failed") return @@ -62,7 +65,7 @@ func DeleteWireGuard(ctx *app.Context, req *pb.DeleteWireGuardRequest) (*pb.Dele return } - links, err := dao.NewQuery(ctx).ListWireGuardLinksByNetwork(userInfo, uint(wgToDelete.NetworkID)) + links, err := q.ListWireGuardLinksByNetwork(userInfo, uint(wgToDelete.NetworkID)) if err != nil { log.WithError(err).Errorf("get wireguard links by network id failed") return diff --git a/biz/master/wg/wireguard_update.go b/biz/master/wg/wireguard_update.go index 22d6220..71c782f 100644 --- a/biz/master/wg/wireguard_update.go +++ b/biz/master/wg/wireguard_update.go @@ -20,19 +20,21 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda if cfg == nil || cfg.GetId() == 0 || len(cfg.GetClientId()) == 0 || len(cfg.GetInterfaceName()) == 0 || len(cfg.GetPrivateKey()) == 0 || len(cfg.GetLocalAddress()) == 0 { return nil, errors.New("invalid wireguard config") } + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) model := &models.WireGuard{} model.FromPB(cfg) model.UserId = uint32(userInfo.GetUserID()) model.TenantId = uint32(userInfo.GetTenantID()) - if err := dao.NewQuery(ctx).UpdateWireGuard(userInfo, uint(cfg.GetId()), model); err != nil { + if err := m.UpdateWireGuard(userInfo, uint(cfg.GetId()), model); err != nil { return nil, err } // 端点赋值与解绑 // 1) 读取当前绑定的端点 - currentList, err := dao.NewQuery(ctx).ListEndpointsWithFilters(userInfo, 1, 1000, "", uint(cfg.GetId()), "") + currentList, err := q.ListEndpointsWithFilters(userInfo, 1, 1000, "", uint(cfg.GetId()), "") if err != nil { return nil, err } @@ -45,7 +47,7 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda continue } if ep.GetId() > 0 { - exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, uint(ep.GetId())) + exist, err := q.GetEndpointByID(userInfo, uint(ep.GetId())) if err != nil { return nil, err } @@ -55,13 +57,13 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda } exist.WireGuardID = uint(cfg.GetId()) - if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil { + if err := m.UpdateEndpoint(userInfo, uint(exist.ID), exist.EndpointEntity); err != nil { return nil, err } newSet[uint(exist.ID)] = struct{}{} } else { entity := &models.EndpointEntity{Host: ep.GetHost(), Port: ep.GetPort(), ClientID: cfg.GetClientId(), WireGuardID: uint(cfg.GetId())} - if err := dao.NewQuery(ctx).CreateEndpoint(userInfo, entity); err != nil { + if err := m.CreateEndpoint(userInfo, entity); err != nil { return nil, err } // 无法获取新建 id,这里不加入 newSet;不影响后续解绑逻辑(仅解绑 current - new) @@ -73,12 +75,12 @@ func UpdateWireGuard(ctx *app.Context, req *pb.UpdateWireGuardRequest) (*pb.Upda if _, ok := newSet[id]; ok { continue } - exist, err := dao.NewQuery(ctx).GetEndpointByID(userInfo, id) + exist, err := q.GetEndpointByID(userInfo, id) if err != nil { return nil, err } entity := &models.EndpointEntity{Host: exist.Host, Port: exist.Port, ClientID: exist.ClientID, WireGuardID: 0} - if err := dao.NewQuery(ctx).UpdateEndpoint(userInfo, id, entity); err != nil { + if err := m.UpdateEndpoint(userInfo, id, entity); err != nil { return nil, err } } diff --git a/biz/master/worker/create_worker.go b/biz/master/worker/create_worker.go index a58e3b9..442eea2 100644 --- a/biz/master/worker/create_worker.go +++ b/biz/master/worker/create_worker.go @@ -19,13 +19,15 @@ func CreateWorker(ctx *app.Context, req *pb.CreateWorkerRequest) (*pb.CreateWork clientId = req.GetClientId() reqWorker = req.GetWorker() ) + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) if err := validateCreateWorker(req); err != nil { logger.Logger(ctx).WithError(err).Errorf("invalid create worker request, origin is: [%s]", req.String()) return nil, err } - cli, err := dao.NewQuery(ctx).GetClientByClientID(userInfo, clientId) + cli, err := q.GetClientByClientID(userInfo, clientId) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot get client, id: [%s], workerName: [%s]", clientId, reqWorker.GetName()) return nil, err @@ -38,7 +40,7 @@ func CreateWorker(ctx *app.Context, req *pb.CreateWorkerRequest) (*pb.CreateWork workerToCreate.Clients = append(workerToCreate.Clients, *cli) - if err := dao.NewQuery(ctx).CreateWorker(userInfo, workerToCreate); err != nil { + if err := m.CreateWorker(userInfo, workerToCreate); err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot create worker, workerName: [%s]", workerToCreate.Name) return nil, err } diff --git a/biz/master/worker/remove_worker.go b/biz/master/worker/remove_worker.go index 36c1a9e..b68c3d3 100644 --- a/biz/master/worker/remove_worker.go +++ b/biz/master/worker/remove_worker.go @@ -19,13 +19,16 @@ func RemoveWorker(ctx *app.Context, req *pb.RemoveWorkerRequest) (*pb.RemoveWork logger.Logger(ctx).Infof("start remove worker, id: [%s]", workerId) - workerToDelete, err := dao.NewQuery(ctx).GetWorkerByWorkerID(userInfo, workerId) + q := dao.NewQuery(ctx) + mut := dao.NewMutation(ctx) + + workerToDelete, err := q.GetWorkerByWorkerID(userInfo, workerId) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot get worker, id: [%s]", workerId) return nil, err } - if ingressesToDelete, err := dao.NewQuery(ctx).GetProxyConfigsByWorkerId(userInfo, workerId); err == nil { + if ingressesToDelete, err := q.GetProxyConfigsByWorkerId(userInfo, workerId); err == nil { for _, ingressToDelete := range ingressesToDelete { logger.Logger(ctx).Infof("start to remove worker ingress on server: [%s] client: [%s], name: [%s]", ingressToDelete.ServerID, ingressToDelete.ClientID, ingressToDelete.Name) @@ -44,7 +47,7 @@ func RemoveWorker(ctx *app.Context, req *pb.RemoveWorkerRequest) (*pb.RemoveWork } } - if err := dao.NewQuery(ctx).DeleteWorker(userInfo, workerId); err != nil { + if err := mut.DeleteWorker(userInfo, workerId); err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot remove worker, id: [%s]", workerId) return nil, err } diff --git a/biz/master/worker/update_worker.go b/biz/master/worker/update_worker.go index 29077d3..29e7b8c 100644 --- a/biz/master/worker/update_worker.go +++ b/biz/master/worker/update_worker.go @@ -22,7 +22,10 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork oldClientIds []string ) - workerToUpdate, err := dao.NewQuery(ctx).GetWorkerByWorkerID(userInfo, wrokerReq.GetWorkerId()) + q := dao.NewQuery(ctx) + m := dao.NewMutation(ctx) + + workerToUpdate, err := q.GetWorkerByWorkerID(userInfo, wrokerReq.GetWorkerId()) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot get worker, id: [%s]", wrokerReq.GetWorkerId()) return nil, fmt.Errorf("cannot get worker, id: [%s]", wrokerReq.GetWorkerId()) @@ -30,7 +33,7 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork clis := []*models.Client{} if len(clientIds) != 0 { - clis, err = dao.NewQuery(ctx).GetClientsByClientIDs(userInfo, clientIds) + clis, err = q.GetClientsByClientIDs(userInfo, clientIds) if err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot get client, id: [%s]", utils.MarshalForJson(clientIds)) return nil, fmt.Errorf("cannot get client, id: [%s]", utils.MarshalForJson(clientIds)) @@ -64,7 +67,7 @@ func UpdateWorker(ctx *app.Context, req *pb.UpdateWorkerRequest) (*pb.UpdateWork updatedFields = append(updatedFields, "config_template") } - if err := dao.NewQuery(ctx).UpdateWorker(userInfo, workerToUpdate); err != nil { + if err := m.UpdateWorker(userInfo, workerToUpdate); err != nil { logger.Logger(ctx).WithError(err).Errorf("cannot update worker, id: [%s]", wrokerReq.GetWorkerId()) return nil, fmt.Errorf("cannot update worker, id: [%s]", wrokerReq.GetWorkerId()) } diff --git a/cmd/frpp/shared/providers.go b/cmd/frpp/shared/providers.go index 62c622a..fbde9a3 100644 --- a/cmd/frpp/shared/providers.go +++ b/cmd/frpp/shared/providers.go @@ -151,7 +151,7 @@ func NewDBManager(ctx *app.Context, appInstance app.Application) app.DBManager { } func NewMasterTLSConfig(ctx *app.Context) *tls.Config { - return dao.NewQuery(ctx).InitCert(conf.GetCertTemplate(ctx.GetApp().GetConfig())) + return dao.NewMutation(ctx).InitCert(conf.GetCertTemplate(ctx.GetApp().GetConfig())) } func NewTLSMasterService(appInstance app.Application, masterTLSConfig *tls.Config) master.MasterService { @@ -250,7 +250,7 @@ func NewDefaultServerConfig(ctx *app.Context) conf.Config { logger.Logger(ctx).Infof("init default internal server") - dao.NewQuery(ctx).InitDefaultServer(appInstance.GetConfig().Master.APIHost) + dao.NewMutation(ctx).InitDefaultServer(appInstance.GetConfig().Master.APIHost) defaultServer, err := dao.NewQuery(ctx).GetDefaultServer() if err != nil { diff --git a/services/dao/cert.go b/services/dao/cert.go index 106415a..361010f 100644 --- a/services/dao/cert.go +++ b/services/dao/cert.go @@ -14,13 +14,30 @@ import ( "github.com/VaalaCat/frp-panel/utils/logger" ) -func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config { +type CertQuery interface { + CountCerts() (int64, error) + GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error) +} + +type CertMutation interface { + InitCert(template *x509.Certificate) *tls.Config +} + +type certQuery struct{ *queryImpl } +type certMutation struct{ *mutationImpl } + +func newCertQuery(base *queryImpl) CertQuery { return &certQuery{base} } +func newCertMutation(base *mutationImpl) CertMutation { return &certMutation{base} } + +func (m *certMutation) InitCert(template *x509.Certificate) *tls.Config { ctx := context.Background() var ( certPem []byte keyPem []byte ) - cnt, err := q.CountCerts() + query := NewQuery(m.ctx) + + cnt, err := query.CountCerts() if err != nil { logger.Logger(ctx).Fatal(err) } @@ -29,7 +46,7 @@ func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config { if err != nil { logger.Logger(ctx).Fatal(err) } - if err = q.ctx.GetApp().GetDBManager().GetDefaultDB().Create(&models.Cert{ + if err = m.ctx.GetApp().GetDBManager().GetDefaultDB().Create(&models.Cert{ Name: "default", CertFile: certPem, CaFile: certPem, @@ -38,7 +55,7 @@ func (q *queryImpl) InitCert(template *x509.Certificate) *tls.Config { logger.Logger(ctx).Fatal(err) } } else { - keyPem, certPem, err = q.GetDefaultKeyPair() + keyPem, certPem, err = query.GetDefaultKeyPair() if err != nil { logger.Logger(ctx).Fatal(err) } @@ -79,7 +96,7 @@ func GenX509Info(template *x509.Certificate) (certPem []byte, keyPem []byte, err return certBuf.Bytes(), keyBuf.Bytes(), nil } -func (q *queryImpl) CountCerts() (int64, error) { +func (q *certQuery) CountCerts() (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Cert{}).Count(&count).Error @@ -89,7 +106,7 @@ func (q *queryImpl) CountCerts() (int64, error) { return count, nil } -func (q *queryImpl) GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error) { +func (q *certQuery) GetDefaultKeyPair() (keyPem []byte, certPem []byte, err error) { resp := &models.Cert{} err = q.ctx.GetApp().GetDBManager().GetDefaultDB().Model(&models.Cert{}). Where(&models.Cert{Name: "default"}).First(resp).Error diff --git a/services/dao/client.go b/services/dao/client.go index 087fb02..767c1d4 100644 --- a/services/dao/client.go +++ b/services/dao/client.go @@ -9,7 +9,38 @@ import ( "gorm.io/gorm" ) -func (q *queryImpl) ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error) { +type ClientQuery interface { + ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error) + AdminGetClientByClientID(clientID string) (*models.Client, error) + GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error) + GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error) + GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error) + GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error) + ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error) + ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error) + GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error) + CountClients(userInfo models.UserInfo) (int64, error) + CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) + CountConfiguredClients(userInfo models.UserInfo) (int64, error) + CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error) + GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error) + AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error) +} + +type ClientMutation interface { + CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error + DeleteClient(userInfo models.UserInfo, clientID string) error + UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error + AdminUpdateClientLastSeen(clientID string) error +} + +type clientQuery struct{ *queryImpl } +type clientMutation struct{ *mutationImpl } + +func newClientQuery(base *queryImpl) ClientQuery { return &clientQuery{base} } +func newClientMutation(base *mutationImpl) ClientMutation { return &clientMutation{base} } + +func (q *clientQuery) ValidateClientSecret(clientID, clientSecret string) (*models.ClientEntity, error) { if clientID == "" || clientSecret == "" { return nil, fmt.Errorf("invalid client id or client secret") } @@ -29,7 +60,7 @@ func (q *queryImpl) ValidateClientSecret(clientID, clientSecret string) (*models return c.ClientEntity, nil } -func (q *queryImpl) AdminGetClientByClientID(clientID string) (*models.Client, error) { +func (q *clientQuery) AdminGetClientByClientID(clientID string) (*models.Client, error) { if clientID == "" { return nil, fmt.Errorf("invalid client id") } @@ -46,7 +77,7 @@ func (q *queryImpl) AdminGetClientByClientID(clientID string) (*models.Client, e return c, nil } -func (q *queryImpl) GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error) { +func (q *clientQuery) GetClientByClientID(userInfo models.UserInfo, clientID string) (*models.Client, error) { if clientID == "" { return nil, fmt.Errorf("invalid client id") } @@ -65,7 +96,7 @@ func (q *queryImpl) GetClientByClientID(userInfo models.UserInfo, clientID strin return c, nil } -func (q *queryImpl) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error) { +func (q *clientQuery) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs []string) ([]*models.Client, error) { if len(clientIDs) == 0 { return nil, fmt.Errorf("invalid client ids") } @@ -80,7 +111,7 @@ func (q *queryImpl) GetClientsByClientIDs(userInfo models.UserInfo, clientIDs [] return cs, nil } -func (q *queryImpl) GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error) { +func (q *clientQuery) GetClientByFilter(userInfo models.UserInfo, client *models.ClientEntity, shadow *bool) (*models.ClientEntity, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() filter := &models.ClientEntity{} if len(client.ClientID) != 0 { @@ -113,7 +144,7 @@ func (q *queryImpl) GetClientByFilter(userInfo models.UserInfo, client *models.C return c.ClientEntity, nil } -func (q *queryImpl) GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error) { +func (q *clientQuery) GetClientByOriginClientID(originClientID string) (*models.ClientEntity, error) { if originClientID == "" { return nil, fmt.Errorf("invalid origin client id") } @@ -130,21 +161,21 @@ func (q *queryImpl) GetClientByOriginClientID(originClientID string) (*models.Cl return c.ClientEntity, nil } -func (q *queryImpl) CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error { +func (m *clientMutation) CreateClient(userInfo models.UserInfo, client *models.ClientEntity) error { client.UserID = userInfo.GetUserID() client.TenantID = userInfo.GetTenantID() c := &models.Client{ ClientEntity: client, } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(c).Error } -func (q *queryImpl) DeleteClient(userInfo models.UserInfo, clientID string) error { +func (m *clientMutation) DeleteClient(userInfo models.UserInfo, clientID string) error { if clientID == "" { return fmt.Errorf("invalid client id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.Client{ ClientEntity: &models.ClientEntity{ ClientID: clientID, @@ -160,11 +191,11 @@ func (q *queryImpl) DeleteClient(userInfo models.UserInfo, clientID string) erro }).Delete(&models.Client{}).Error } -func (q *queryImpl) UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error { +func (m *clientMutation) UpdateClient(userInfo models.UserInfo, client *models.ClientEntity) error { c := &models.Client{ ClientEntity: client, } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Where(&models.Client{ ClientEntity: &models.ClientEntity{ UserID: userInfo.GetUserID(), @@ -173,7 +204,7 @@ func (q *queryImpl) UpdateClient(userInfo models.UserInfo, client *models.Client }).Save(c).Error } -func (q *queryImpl) ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error) { +func (q *clientQuery) ListClients(userInfo models.UserInfo, page, pageSize int) ([]*models.ClientEntity, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -202,7 +233,7 @@ func (q *queryImpl) ListClients(userInfo models.UserInfo, page, pageSize int) ([ }), nil } -func (q *queryImpl) ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error) { +func (q *clientQuery) ListClientsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ClientEntity, error) { // 只获取没shadow且config有东西 // 或isShadow的client if page < 1 || pageSize < 1 || len(keyword) == 0 { @@ -229,7 +260,7 @@ func (q *queryImpl) ListClientsWithKeyword(userInfo models.UserInfo, page, pageS }), nil } -func (q *queryImpl) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error) { +func (q *clientQuery) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEntity, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var clients []*models.Client err := db.Where(&models.Client{ @@ -247,7 +278,7 @@ func (q *queryImpl) GetAllClients(userInfo models.UserInfo) ([]*models.ClientEnt }), nil } -func (q *queryImpl) CountClients(userInfo models.UserInfo) (int64, error) { +func (q *clientQuery) CountClients(userInfo models.UserInfo) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Client{}).Where(&models.Client{ @@ -263,7 +294,7 @@ func (q *queryImpl) CountClients(userInfo models.UserInfo) (int64, error) { return count, nil } -func (q *queryImpl) CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { +func (q *clientQuery) CountClientsWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Client{}).Where(&models.Client{ @@ -279,7 +310,7 @@ func (q *queryImpl) CountClientsWithKeyword(userInfo models.UserInfo, keyword st return count, nil } -func (q *queryImpl) CountConfiguredClients(userInfo models.UserInfo) (int64, error) { +func (q *clientQuery) CountConfiguredClients(userInfo models.UserInfo) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Client{}). @@ -296,7 +327,7 @@ func (q *queryImpl) CountConfiguredClients(userInfo models.UserInfo) (int64, err return count, nil } -func (q *queryImpl) CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error) { +func (q *clientQuery) CountClientsInShadow(userInfo models.UserInfo, clientID string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Client{}). @@ -313,7 +344,7 @@ func (q *queryImpl) CountClientsInShadow(userInfo models.UserInfo, clientID stri return count, nil } -func (q *queryImpl) GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error) { +func (q *clientQuery) GetClientIDsInShadowByClientID(userInfo models.UserInfo, clientID string) ([]string, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var clients []*models.Client err := db.Where(&models.Client{ @@ -330,7 +361,7 @@ func (q *queryImpl) GetClientIDsInShadowByClientID(userInfo models.UserInfo, cli }), nil } -func (q *queryImpl) AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error) { +func (q *clientQuery) AdminGetClientIDsInShadowByClientID(clientID string) ([]string, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var clients []*models.Client err := db.Where(&models.Client{ @@ -345,8 +376,8 @@ func (q *queryImpl) AdminGetClientIDsInShadowByClientID(clientID string) ([]stri }), nil } -func (q *queryImpl) AdminUpdateClientLastSeen(clientID string) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *clientMutation) AdminUpdateClientLastSeen(clientID string) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Model(&models.Client{ ClientEntity: &models.ClientEntity{ ClientID: clientID, diff --git a/services/dao/endpoint.go b/services/dao/endpoint.go index 612b3d7..b2e09db 100644 --- a/services/dao/endpoint.go +++ b/services/dao/endpoint.go @@ -7,7 +7,27 @@ import ( "gorm.io/gorm" ) -func (q *queryImpl) CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error { +type EndpointQuery interface { + GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error) + ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error) + CountEndpoints(userInfo models.UserInfo) (int64, error) + ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error) + CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error) +} + +type EndpointMutation interface { + CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error + UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error + DeleteEndpoint(userInfo models.UserInfo, id uint) error +} + +type endpointQuery struct{ *queryImpl } +type endpointMutation struct{ *mutationImpl } + +func newEndpointQuery(base *queryImpl) EndpointQuery { return &endpointQuery{base} } +func newEndpointMutation(base *mutationImpl) EndpointMutation { return &endpointMutation{base} } + +func (m *endpointMutation) CreateEndpoint(userInfo models.UserInfo, endpoint *models.EndpointEntity) error { if endpoint == nil { return fmt.Errorf("invalid endpoint entity") } @@ -15,29 +35,29 @@ func (q *queryImpl) CreateEndpoint(userInfo models.UserInfo, endpoint *models.En return fmt.Errorf("invalid endpoint host or port") } // scope via parent wireguard/client - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(&models.Endpoint{EndpointEntity: endpoint}).Error } -func (q *queryImpl) UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error { +func (m *endpointMutation) UpdateEndpoint(userInfo models.UserInfo, id uint, endpoint *models.EndpointEntity) error { if id == 0 || endpoint == nil { return fmt.Errorf("invalid endpoint id or entity") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Where(&models.Endpoint{ Model: gorm.Model{ID: id}, }).Save(&models.Endpoint{Model: gorm.Model{ID: id}, EndpointEntity: endpoint}).Error } -func (q *queryImpl) DeleteEndpoint(userInfo models.UserInfo, id uint) error { +func (m *endpointMutation) DeleteEndpoint(userInfo models.UserInfo, id uint) error { if id == 0 { return fmt.Errorf("invalid endpoint id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.Endpoint{Model: gorm.Model{ID: id}}).Delete(&models.Endpoint{}).Error } -func (q *queryImpl) GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error) { +func (q *endpointQuery) GetEndpointByID(userInfo models.UserInfo, id uint) (*models.Endpoint, error) { if id == 0 { return nil, fmt.Errorf("invalid endpoint id") } @@ -49,7 +69,7 @@ func (q *queryImpl) GetEndpointByID(userInfo models.UserInfo, id uint) (*models. return &e, nil } -func (q *queryImpl) ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error) { +func (q *endpointQuery) ListEndpoints(userInfo models.UserInfo, page, pageSize int) ([]*models.Endpoint, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -62,7 +82,7 @@ func (q *queryImpl) ListEndpoints(userInfo models.UserInfo, page, pageSize int) return list, nil } -func (q *queryImpl) CountEndpoints(userInfo models.UserInfo) (int64, error) { +func (q *endpointQuery) CountEndpoints(userInfo models.UserInfo) (int64, error) { var count int64 db := q.ctx.GetApp().GetDBManager().GetDefaultDB() if err := db.Model(&models.Endpoint{}).Count(&count).Error; err != nil { @@ -72,7 +92,7 @@ func (q *queryImpl) CountEndpoints(userInfo models.UserInfo) (int64, error) { } // ListEndpointsWithFilters 根据 clientID / wireguardID / keyword 过滤端点 -func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error) { +func (q *endpointQuery) ListEndpointsWithFilters(userInfo models.UserInfo, page, pageSize int, clientID string, wireguardID uint, keyword string) ([]*models.Endpoint, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -80,7 +100,7 @@ func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pag // 若指定 clientID,先校验归属 if len(clientID) > 0 { - if _, err := q.GetClientByClientID(userInfo, clientID); err != nil { + if _, err := newClientQuery(q.queryImpl).GetClientByClientID(userInfo, clientID); err != nil { return nil, err } } @@ -103,11 +123,11 @@ func (q *queryImpl) ListEndpointsWithFilters(userInfo models.UserInfo, page, pag return list, nil } -func (q *queryImpl) CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error) { +func (q *endpointQuery) CountEndpointsWithFilters(userInfo models.UserInfo, clientID string, wireguardID uint, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() if len(clientID) > 0 { - if _, err := q.GetClientByClientID(userInfo, clientID); err != nil { + if _, err := newClientQuery(q.queryImpl).GetClientByClientID(userInfo, clientID); err != nil { return 0, err } } diff --git a/services/dao/link.go b/services/dao/link.go index e3d377a..00a6efe 100644 --- a/services/dao/link.go +++ b/services/dao/link.go @@ -8,7 +8,29 @@ import ( "gorm.io/gorm" ) -func (q *queryImpl) CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error { +type LinkQuery interface { + ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error) + GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error) + GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error) + ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error) + CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error) + AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error) +} + +type LinkMutation interface { + CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error + CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error + UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error + DeleteWireGuardLink(userInfo models.UserInfo, id uint) error +} + +type linkQuery struct{ *queryImpl } +type linkMutation struct{ *mutationImpl } + +func newLinkQuery(base *queryImpl) LinkQuery { return &linkQuery{base} } +func newLinkMutation(base *mutationImpl) LinkMutation { return &linkMutation{base} } + +func (m *linkMutation) CreateWireGuardLink(userInfo models.UserInfo, link *models.WireGuardLink) error { if link == nil { return fmt.Errorf("invalid wg link") } @@ -17,11 +39,11 @@ func (q *queryImpl) CreateWireGuardLink(userInfo models.UserInfo, link *models.W } link.UserId = uint32(userInfo.GetUserID()) link.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(link).Error } -func (q *queryImpl) CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error { +func (m *linkMutation) CreateWireGuardLinks(userInfo models.UserInfo, links ...*models.WireGuardLink) error { if len(links) == 0 { return fmt.Errorf("invalid wg links") } @@ -29,11 +51,11 @@ func (q *queryImpl) CreateWireGuardLinks(userInfo models.UserInfo, links ...*mod link.UserId = uint32(userInfo.GetUserID()) link.TenantId = uint32(userInfo.GetTenantID()) } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(links).Error } -func (q *queryImpl) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error { +func (m *linkMutation) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link *models.WireGuardLink) error { if id == 0 || link == nil { return fmt.Errorf("invalid wg link id or entity") } @@ -43,18 +65,18 @@ func (q *queryImpl) UpdateWireGuardLink(userInfo models.UserInfo, id uint, link } link.UserId = uint32(userInfo.GetUserID()) link.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Where(&models.WireGuardLink{ Model: link.Model, WireGuardLinkEntity: &models.WireGuardLinkEntity{UserId: link.UserId, TenantId: link.TenantId}, }).Save(link).Error } -func (q *queryImpl) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error { +func (m *linkMutation) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error { if id == 0 { return fmt.Errorf("invalid wg link id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.WireGuardLink{ Model: gorm.Model{ID: id}, WireGuardLinkEntity: &models.WireGuardLinkEntity{ @@ -63,7 +85,7 @@ func (q *queryImpl) DeleteWireGuardLink(userInfo models.UserInfo, id uint) error }}).Delete(&models.WireGuardLink{}).Error } -func (q *queryImpl) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error) { +func (q *linkQuery) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networkID uint) ([]*models.WireGuardLink, error) { if networkID == 0 { return nil, fmt.Errorf("invalid network id") } @@ -81,7 +103,7 @@ func (q *queryImpl) ListWireGuardLinksByNetwork(userInfo models.UserInfo, networ } // GetWireGuardLinkByID 根据 ID 查询 Link(按租户隔离) -func (q *queryImpl) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error) { +func (q *linkQuery) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*models.WireGuardLink, error) { if id == 0 { return nil, fmt.Errorf("invalid wg link id") } @@ -99,7 +121,7 @@ func (q *queryImpl) GetWireGuardLinkByID(userInfo models.UserInfo, id uint) (*mo return &m, nil } -func (q *queryImpl) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error) { +func (q *linkQuery) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromClientId, toClientId uint) (*models.WireGuardLink, error) { if fromClientId == 0 || toClientId == 0 { return nil, fmt.Errorf("invalid from client id or to client id") } @@ -117,7 +139,7 @@ func (q *queryImpl) GetWireGuardLinkByClientIDs(userInfo models.UserInfo, fromCl } // ListWireGuardLinksWithFilters 分页查询 Link,支持按 networkID 过滤与关键字(数字时匹配 from/to id) -func (q *queryImpl) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error) { +func (q *linkQuery) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page, pageSize int, networkID uint, keyword string) ([]*models.WireGuardLink, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -146,7 +168,7 @@ func (q *queryImpl) ListWireGuardLinksWithFilters(userInfo models.UserInfo, page } // CountWireGuardLinksWithFilters 统计分页条件下的总数 -func (q *queryImpl) CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error) { +func (q *linkQuery) CountWireGuardLinksWithFilters(userInfo models.UserInfo, networkID uint, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 @@ -170,7 +192,7 @@ func (q *queryImpl) CountWireGuardLinksWithFilters(userInfo models.UserInfo, net return count, nil } -func (q *queryImpl) AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error) { +func (q *linkQuery) AdminListWireGuardLinksWithNetworkIDs(networkIDs []uint) ([]*models.WireGuardLink, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var list []*models.WireGuardLink if err := db.Where("network_id IN ?", networkIDs).Find(&list).Error; err != nil { diff --git a/services/dao/network.go b/services/dao/network.go index 1ac0287..51396d7 100644 --- a/services/dao/network.go +++ b/services/dao/network.go @@ -7,7 +7,27 @@ import ( "gorm.io/gorm" ) -func (q *queryImpl) CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error { +type NetworkQuery interface { + GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error) + ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error) + ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error) + CountNetworks(userInfo models.UserInfo) (int64, error) + CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) +} + +type NetworkMutation interface { + CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error + UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error + DeleteNetwork(userInfo models.UserInfo, id uint) error +} + +type networkQuery struct{ *queryImpl } +type networkMutation struct{ *mutationImpl } + +func newNetworkQuery(base *queryImpl) NetworkQuery { return &networkQuery{base} } +func newNetworkMutation(base *mutationImpl) NetworkMutation { return &networkMutation{base} } + +func (m *networkMutation) CreateNetwork(userInfo models.UserInfo, network *models.NetworkEntity) error { if network == nil { return fmt.Errorf("invalid network entity") } @@ -18,11 +38,11 @@ func (q *queryImpl) CreateNetwork(userInfo models.UserInfo, network *models.Netw network.UserId = uint32(userInfo.GetUserID()) network.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(&models.Network{NetworkEntity: network}).Error } -func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error { +func (m *networkMutation) UpdateNetwork(userInfo models.UserInfo, id uint, network *models.NetworkEntity) error { if id == 0 || network == nil { return fmt.Errorf("invalid network id or entity") } @@ -30,7 +50,7 @@ func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *mo network.UserId = uint32(userInfo.GetUserID()) network.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Where(&models.Network{ Model: gorm.Model{ID: id}, NetworkEntity: &models.NetworkEntity{ @@ -43,11 +63,11 @@ func (q *queryImpl) UpdateNetwork(userInfo models.UserInfo, id uint, network *mo }).Error } -func (q *queryImpl) DeleteNetwork(userInfo models.UserInfo, id uint) error { +func (m *networkMutation) DeleteNetwork(userInfo models.UserInfo, id uint) error { if id == 0 { return fmt.Errorf("invalid network id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.Network{ Model: gorm.Model{ID: id}, NetworkEntity: &models.NetworkEntity{ @@ -57,7 +77,7 @@ func (q *queryImpl) DeleteNetwork(userInfo models.UserInfo, id uint) error { }).Delete(&models.Network{}).Error } -func (q *queryImpl) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error) { +func (q *networkQuery) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.Network, error) { if id == 0 { return nil, fmt.Errorf("invalid network id") } @@ -75,7 +95,7 @@ func (q *queryImpl) GetNetworkByID(userInfo models.UserInfo, id uint) (*models.N return &n, nil } -func (q *queryImpl) ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error) { +func (q *networkQuery) ListNetworks(userInfo models.UserInfo, page, pageSize int) ([]*models.Network, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -91,7 +111,7 @@ func (q *queryImpl) ListNetworks(userInfo models.UserInfo, page, pageSize int) ( return list, nil } -func (q *queryImpl) ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error) { +func (q *networkQuery) ListNetworksWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Network, error) { if page < 1 || pageSize < 1 || len(keyword) == 0 { return nil, fmt.Errorf("invalid page or page size or keyword") } @@ -107,7 +127,7 @@ func (q *queryImpl) ListNetworksWithKeyword(userInfo models.UserInfo, page, page return list, nil } -func (q *queryImpl) CountNetworks(userInfo models.UserInfo) (int64, error) { +func (q *networkQuery) CountNetworks(userInfo models.UserInfo) (int64, error) { var count int64 db := q.ctx.GetApp().GetDBManager().GetDefaultDB() if err := db.Model(&models.Network{}).Where(&models.Network{NetworkEntity: &models.NetworkEntity{ @@ -119,7 +139,7 @@ func (q *queryImpl) CountNetworks(userInfo models.UserInfo) (int64, error) { return count, nil } -func (q *queryImpl) CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { +func (q *networkQuery) CountNetworksWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { var count int64 db := q.ctx.GetApp().GetDBManager().GetDefaultDB() if err := db.Model(&models.Network{}).Where(&models.Network{NetworkEntity: &models.NetworkEntity{ diff --git a/services/dao/proxy.go b/services/dao/proxy.go index c787218..75e3c8e 100644 --- a/services/dao/proxy.go +++ b/services/dao/proxy.go @@ -16,7 +16,44 @@ import ( "gorm.io/gorm/clause" ) -func (q *queryImpl) GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error) { +type ProxyQuery interface { + GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error) + GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error) + AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error) + AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error) + AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error) + GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error) + GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error) + ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) + AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) + ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error) + ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error) + ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error) + GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error) + CountProxyConfigs(userInfo models.UserInfo) (int64, error) + CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error) + CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error) + GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error) +} + +type ProxyMutation interface { + AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error + AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error + RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error + CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error + UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error + DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error + DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error + DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error +} + +type proxyQuery struct{ *queryImpl } +type proxyMutation struct{ *mutationImpl } + +func newProxyQuery(base *queryImpl) ProxyQuery { return &proxyQuery{base} } +func newProxyMutation(base *mutationImpl) ProxyMutation { return &proxyMutation{base} } + +func (q *proxyQuery) GetProxyStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyStatsEntity, error) { if clientID == "" { return nil, fmt.Errorf("invalid client id") } @@ -52,7 +89,7 @@ func (q *queryImpl) GetProxyStatsByClientID(userInfo models.UserInfo, clientID s }), nil } -func (q *queryImpl) GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error) { +func (q *proxyQuery) GetProxyStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.ProxyStatsEntity, error) { if serverID == "" { return nil, fmt.Errorf("invalid server id") } @@ -77,12 +114,12 @@ func (q *queryImpl) GetProxyStatsByServerID(userInfo models.UserInfo, serverID s }), nil } -func (q *queryImpl) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error { +func (m *proxyMutation) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb.ProxyInfo) error { if srv.ServerID == "" { return fmt.Errorf("invalid server id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Transaction(func(tx *gorm.DB) error { queryResults := make([]interface{}, 3) @@ -201,7 +238,7 @@ func (q *queryImpl) AdminUpdateProxyStats(srv *models.ServerEntity, inputs []*pb }) } -func (q *queryImpl) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error) { +func (q *proxyQuery) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStatsEntity, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() list := []*models.ProxyStats{} err := db. @@ -217,7 +254,7 @@ func (q *queryImpl) AdminGetTenantProxyStats(tenantID int) ([]*models.ProxyStats }), nil } -func (q *queryImpl) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error) { +func (q *proxyQuery) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEntity, error) { db := tx list := []*models.ProxyStats{} err := db.Clauses(clause.Locking{Strength: "UPDATE"}). @@ -230,15 +267,16 @@ func (q *queryImpl) AdminGetAllProxyStats(tx *gorm.DB) ([]*models.ProxyStatsEnti }), nil } -func (q *queryImpl) AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *proxyMutation) AdminCreateProxyConfig(proxyCfg *models.ProxyConfig) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(proxyCfg).Error } // RebuildProxyConfigFromClient rebuild proxy from client // skip stopped proxy -func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *proxyMutation) RebuildProxyConfigFromClient(userInfo models.UserInfo, client *models.Client) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() + query := NewQuery(m.ctx) pxyCfgs, err := utils.LoadProxiesFromContent(client.ConfigContent) if err != nil { @@ -251,7 +289,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien proxyCfg := &models.ProxyConfig{ ProxyConfigEntity: &models.ProxyConfigEntity{}, } - if oldProxyCfg, err := q.GetProxyConfigByOriginClientIDAndName(userInfo, client.ClientID, pxyCfg.GetBaseConfig().Name); err == nil { + if oldProxyCfg, err := query.GetProxyConfigByOriginClientIDAndName(userInfo, client.ClientID, pxyCfg.GetBaseConfig().Name); err == nil { logger.Logger(context.Background()).WithError(err).Warnf("proxy config already exist, will be override, clientID: [%s], name: [%s]", client.ClientID, pxyCfg.GetBaseConfig().Name) proxyCfg.Model = oldProxyCfg.Model @@ -268,7 +306,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien proxyConfigEntities = append(proxyConfigEntities, proxyCfg) } - if err := q.DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, client.ClientID); err != nil { + if err := m.DeleteProxyConfigsByClientIDOrOriginClientID(userInfo, client.ClientID); err != nil { return err } @@ -279,7 +317,7 @@ func (q *queryImpl) RebuildProxyConfigFromClient(userInfo models.UserInfo, clien return db.Save(proxyConfigEntities).Error } -func (q *queryImpl) AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error) { +func (q *proxyQuery) AdminGetProxyConfigByClientIDAndName(clientID string, name string) (*models.ProxyConfig, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() proxyCfg := &models.ProxyConfig{} err := db. @@ -294,7 +332,7 @@ func (q *queryImpl) AdminGetProxyConfigByClientIDAndName(clientID string, name s return proxyCfg, nil } -func (q *queryImpl) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error) { +func (q *proxyQuery) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID string) ([]*models.ProxyConfigEntity, error) { if clientID == "" { return nil, fmt.Errorf("invalid client id") } @@ -315,7 +353,7 @@ func (q *queryImpl) GetProxyConfigsByClientID(userInfo models.UserInfo, clientID }), nil } -func (q *queryImpl) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error) { +func (q *proxyQuery) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig *models.ProxyConfigEntity) (*models.ProxyConfig, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() filter := &models.ProxyConfigEntity{} @@ -348,7 +386,7 @@ func (q *queryImpl) GetProxyConfigByFilter(userInfo models.UserInfo, proxyConfig return respProxyCfg, nil } -func (q *queryImpl) ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) ListProxyConfigsWithFilters(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -370,7 +408,7 @@ func (q *queryImpl) ListProxyConfigsWithFilters(userInfo models.UserInfo, page, return proxyConfigs, nil } -func (q *queryImpl) AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) AdminListProxyConfigsWithFilters(filters *models.ProxyConfigEntity) ([]*models.ProxyConfig, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var proxyConfigs []*models.ProxyConfig @@ -384,7 +422,7 @@ func (q *queryImpl) AdminListProxyConfigsWithFilters(filters *models.ProxyConfig return proxyConfigs, nil } -func (q *queryImpl) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, page, pageSize int, filters *models.ProxyConfigEntity, keyword string) ([]*models.ProxyConfig, error) { if page < 1 || pageSize < 1 || len(keyword) == 0 { return nil, fmt.Errorf("invalid page or page size or keyword") } @@ -406,26 +444,26 @@ func (q *queryImpl) ListProxyConfigsWithFiltersAndKeyword(userInfo models.UserIn return proxyConfigs, nil } -func (q *queryImpl) ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) ListProxyConfigsWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ProxyConfig, error) { return q.ListProxyConfigsWithFiltersAndKeyword(userInfo, page, pageSize, &models.ProxyConfigEntity{}, keyword) } -func (q *queryImpl) ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) ListProxyConfigs(userInfo models.UserInfo, page, pageSize int) ([]*models.ProxyConfig, error) { return q.ListProxyConfigsWithFilters(userInfo, page, pageSize, &models.ProxyConfigEntity{}) } -func (q *queryImpl) CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *proxyMutation) CreateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfigEntity) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() proxyCfg.UserID = userInfo.GetUserID() proxyCfg.TenantID = userInfo.GetTenantID() return db.Create(&models.ProxyConfig{ProxyConfigEntity: proxyCfg}).Error } -func (q *queryImpl) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error { +func (m *proxyMutation) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models.ProxyConfig) error { if proxyCfg.ID == 0 { return fmt.Errorf("invalid proxy config id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() proxyCfg.UserID = userInfo.GetUserID() proxyCfg.TenantID = userInfo.GetTenantID() return db.Where(&models.ProxyConfig{ @@ -440,11 +478,11 @@ func (q *queryImpl) UpdateProxyConfig(userInfo models.UserInfo, proxyCfg *models }).Save(proxyCfg).Error } -func (q *queryImpl) DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error { +func (m *proxyMutation) DeleteProxyConfig(userInfo models.UserInfo, clientID, name string) error { if clientID == "" || name == "" { return fmt.Errorf("invalid client id or name") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped(). Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{ UserID: userInfo.GetUserID(), @@ -455,11 +493,11 @@ func (q *queryImpl) DeleteProxyConfig(userInfo models.UserInfo, clientID, name s Delete(&models.ProxyConfig{}).Error } -func (q *queryImpl) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error { +func (m *proxyMutation) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models.UserInfo, clientID string) error { if clientID == "" { return fmt.Errorf("invalid client id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped(). Where( db.Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{ @@ -477,11 +515,11 @@ func (q *queryImpl) DeleteProxyConfigsByClientIDOrOriginClientID(userInfo models Delete(&models.ProxyConfig{}).Error } -func (q *queryImpl) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error { +func (m *proxyMutation) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clientID string) error { if clientID == "" { return fmt.Errorf("invalid client id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped(). Where(&models.ProxyConfig{ProxyConfigEntity: &models.ProxyConfigEntity{ UserID: userInfo.GetUserID(), @@ -491,7 +529,7 @@ func (q *queryImpl) DeleteProxyConfigsByClientID(userInfo models.UserInfo, clien Delete(&models.ProxyConfig{}).Error } -func (q *queryImpl) GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error) { +func (q *proxyQuery) GetProxyConfigByOriginClientIDAndName(userInfo models.UserInfo, clientID string, name string) (*models.ProxyConfig, error) { if clientID == "" || name == "" { return nil, fmt.Errorf("invalid client id or name") } @@ -511,11 +549,11 @@ func (q *queryImpl) GetProxyConfigByOriginClientIDAndName(userInfo models.UserIn return item, nil } -func (q *queryImpl) CountProxyConfigs(userInfo models.UserInfo) (int64, error) { +func (q *proxyQuery) CountProxyConfigs(userInfo models.UserInfo) (int64, error) { return q.CountProxyConfigsWithFilters(userInfo, &models.ProxyConfigEntity{}) } -func (q *queryImpl) CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error) { +func (q *proxyQuery) CountProxyConfigsWithFilters(userInfo models.UserInfo, filters *models.ProxyConfigEntity) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() filters.UserID = userInfo.GetUserID() filters.TenantID = userInfo.GetTenantID() @@ -530,7 +568,7 @@ func (q *queryImpl) CountProxyConfigsWithFilters(userInfo models.UserInfo, filte return count, nil } -func (q *queryImpl) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error) { +func (q *proxyQuery) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserInfo, filters *models.ProxyConfigEntity, keyword string) (int64, error) { if len(keyword) == 0 { return q.CountProxyConfigsWithFilters(userInfo, filters) } @@ -549,7 +587,7 @@ func (q *queryImpl) CountProxyConfigsWithFiltersAndKeyword(userInfo models.UserI return count, nil } -func (q *queryImpl) GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error) { +func (q *proxyQuery) GetProxyConfigsByWorkerId(userInfo models.UserInfo, workerID string) ([]*models.ProxyConfig, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() items := []*models.ProxyConfig{} diff --git a/services/dao/query.go b/services/dao/query.go index 98878c1..de8e533 100644 --- a/services/dao/query.go +++ b/services/dao/query.go @@ -2,14 +2,110 @@ package dao import "github.com/VaalaCat/frp-panel/services/app" -type Query interface{} +type Query interface { + CertQuery + ClientQuery + EndpointQuery + LinkQuery + NetworkQuery + ProxyQuery + ServerQuery + StatsQuery + UserQuery + WireGuardQuery + WorkerQuery +} +type Mutation interface { + CertMutation + ClientMutation + EndpointMutation + LinkMutation + NetworkMutation + ProxyMutation + ServerMutation + StatsMutation + UserMutation + WireGuardMutation + WorkerMutation + UserGroupMutation +} + +// queryImpl/mutationImpl 是具体表级实现的基础结构(持有 ctx)。 type queryImpl struct { ctx *app.Context } -func NewQuery(ctx *app.Context) *queryImpl { - return &queryImpl{ - ctx: ctx, +type mutationImpl struct { + ctx *app.Context +} + +// compositeQuery / compositeMutation 组合各子领域实现,对外暴露统一入口。 +type compositeQuery struct { + CertQuery + ClientQuery + EndpointQuery + LinkQuery + NetworkQuery + ProxyQuery + ServerQuery + StatsQuery + UserQuery + WireGuardQuery + WorkerQuery +} + +type compositeMutation struct { + CertMutation + ClientMutation + EndpointMutation + LinkMutation + NetworkMutation + ProxyMutation + ServerMutation + StatsMutation + UserMutation + WireGuardMutation + WorkerMutation + UserGroupMutation +} + +func NewQuery(ctx *app.Context) Query { + base := &queryImpl{ctx: ctx} + return &compositeQuery{ + CertQuery: newCertQuery(base), + ClientQuery: newClientQuery(base), + EndpointQuery: newEndpointQuery(base), + LinkQuery: newLinkQuery(base), + NetworkQuery: newNetworkQuery(base), + ProxyQuery: newProxyQuery(base), + ServerQuery: newServerQuery(base), + StatsQuery: newStatsQuery(base), + UserQuery: newUserQuery(base), + WireGuardQuery: newWireGuardQuery(base), + WorkerQuery: newWorkerQuery(base), } } + +func NewMutation(ctx *app.Context) Mutation { + base := &mutationImpl{ctx: ctx} + return &compositeMutation{ + CertMutation: newCertMutation(base), + ClientMutation: newClientMutation(base), + EndpointMutation: newEndpointMutation(base), + LinkMutation: newLinkMutation(base), + NetworkMutation: newNetworkMutation(base), + ProxyMutation: newProxyMutation(base), + ServerMutation: newServerMutation(base), + StatsMutation: newStatsMutation(base), + UserMutation: newUserMutation(base), + WireGuardMutation: newWireGuardMutation(base), + WorkerMutation: newWorkerMutation(base), + UserGroupMutation: newUserGroupMutation(base), + } +} + +var ( + _ Query = (*compositeQuery)(nil) + _ Mutation = (*compositeMutation)(nil) +) diff --git a/services/dao/server.go b/services/dao/server.go index 09838fa..107ea72 100644 --- a/services/dao/server.go +++ b/services/dao/server.go @@ -9,8 +9,34 @@ import ( "github.com/samber/lo" ) -func (q *queryImpl) InitDefaultServer(serverIP string) { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +type ServerQuery interface { + GetDefaultServer() (*models.ServerEntity, error) + ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error) + AdminGetServerByServerID(serverID string) (*models.ServerEntity, error) + GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error) + ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error) + ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error) + CountServers(userInfo models.UserInfo) (int64, error) + CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) + CountConfiguredServers(userInfo models.UserInfo) (int64, error) +} + +type ServerMutation interface { + InitDefaultServer(serverIP string) + UpdateDefaultServer(c *models.Server) error + CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error + DeleteServer(userInfo models.UserInfo, serverID string) error + UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error +} + +type serverQuery struct{ *queryImpl } +type serverMutation struct{ *mutationImpl } + +func newServerQuery(base *queryImpl) ServerQuery { return &serverQuery{base} } +func newServerMutation(base *mutationImpl) ServerMutation { return &serverMutation{base} } + +func (m *serverMutation) InitDefaultServer(serverIP string) { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() db.Where(&models.Server{ ServerEntity: &models.ServerEntity{ ServerID: defs.DefaultServerID, @@ -24,7 +50,7 @@ func (q *queryImpl) InitDefaultServer(serverIP string) { }).FirstOrCreate(&models.Server{}) } -func (q *queryImpl) GetDefaultServer() (*models.ServerEntity, error) { +func (q *serverQuery) GetDefaultServer() (*models.ServerEntity, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() c := &models.Server{} err := db. @@ -38,8 +64,8 @@ func (q *queryImpl) GetDefaultServer() (*models.ServerEntity, error) { return c.ServerEntity, nil } -func (q *queryImpl) UpdateDefaultServer(c *models.Server) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *serverMutation) UpdateDefaultServer(c *models.Server) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() c.ServerID = defs.DefaultServerID err := db.Where(&models.Server{ ServerEntity: &models.ServerEntity{ @@ -51,7 +77,7 @@ func (q *queryImpl) UpdateDefaultServer(c *models.Server) error { return nil } -func (q *queryImpl) ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error) { +func (q *serverQuery) ValidateServerSecret(serverID string, secret string) (*models.ServerEntity, error) { if serverID == "" || secret == "" { return nil, fmt.Errorf("invalid request") } @@ -71,7 +97,7 @@ func (q *queryImpl) ValidateServerSecret(serverID string, secret string) (*model return c.ServerEntity, nil } -func (q *queryImpl) AdminGetServerByServerID(serverID string) (*models.ServerEntity, error) { +func (q *serverQuery) AdminGetServerByServerID(serverID string) (*models.ServerEntity, error) { if serverID == "" { return nil, fmt.Errorf("invalid server id") } @@ -88,7 +114,7 @@ func (q *queryImpl) AdminGetServerByServerID(serverID string) (*models.ServerEnt return c.ServerEntity, nil } -func (q *queryImpl) GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error) { +func (q *serverQuery) GetServerByServerID(userInfo models.UserInfo, serverID string) (*models.ServerEntity, error) { if serverID == "" { return nil, fmt.Errorf("invalid server id") } @@ -110,21 +136,21 @@ func (q *queryImpl) GetServerByServerID(userInfo models.UserInfo, serverID strin return c.ServerEntity, nil } -func (q *queryImpl) CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error { +func (m *serverMutation) CreateServer(userInfo models.UserInfo, server *models.ServerEntity) error { server.UserID = userInfo.GetUserID() server.TenantID = userInfo.GetTenantID() c := &models.Server{ ServerEntity: server, } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(c).Error } -func (q *queryImpl) DeleteServer(userInfo models.UserInfo, serverID string) error { +func (m *serverMutation) DeleteServer(userInfo models.UserInfo, serverID string) error { if serverID == "" { return fmt.Errorf("invalid server id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where( &models.Server{ ServerEntity: &models.ServerEntity{ @@ -139,14 +165,14 @@ func (q *queryImpl) DeleteServer(userInfo models.UserInfo, serverID string) erro }).Error } -func (q *queryImpl) UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error { +func (m *serverMutation) UpdateServer(userInfo models.UserInfo, server *models.ServerEntity) error { c := &models.Server{ ServerEntity: server, } if userInfo.GetUserID() == defs.DefaultAdminUserID && server.ServerID == defs.DefaultServerID { - return q.UpdateDefaultServer(c) + return m.UpdateDefaultServer(c) } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Where( &models.Server{ ServerEntity: &models.ServerEntity{ @@ -157,7 +183,7 @@ func (q *queryImpl) UpdateServer(userInfo models.UserInfo, server *models.Server ).Save(c).Error } -func (q *queryImpl) ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error) { +func (q *serverQuery) ListServers(userInfo models.UserInfo, page, pageSize int) ([]*models.ServerEntity, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -187,7 +213,7 @@ func (q *queryImpl) ListServers(userInfo models.UserInfo, page, pageSize int) ([ }), nil } -func (q *queryImpl) ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error) { +func (q *serverQuery) ListServersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.ServerEntity, error) { if page < 1 || pageSize < 1 || len(keyword) == 0 { return nil, fmt.Errorf("invalid page or page size or keyword") } @@ -214,7 +240,7 @@ func (q *queryImpl) ListServersWithKeyword(userInfo models.UserInfo, page, pageS }), nil } -func (q *queryImpl) CountServers(userInfo models.UserInfo) (int64, error) { +func (q *serverQuery) CountServers(userInfo models.UserInfo) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Server{}).Where( @@ -231,7 +257,7 @@ func (q *queryImpl) CountServers(userInfo models.UserInfo) (int64, error) { return count, nil } -func (q *queryImpl) CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { +func (q *serverQuery) CountServersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Server{}).Where( @@ -248,7 +274,7 @@ func (q *queryImpl) CountServersWithKeyword(userInfo models.UserInfo, keyword st return count, nil } -func (q *queryImpl) CountConfiguredServers(userInfo models.UserInfo) (int64, error) { +func (q *serverQuery) CountConfiguredServers(userInfo models.UserInfo) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Server{}).Where( diff --git a/services/dao/stats.go b/services/dao/stats.go index aa11e9e..3f84a72 100644 --- a/services/dao/stats.go +++ b/services/dao/stats.go @@ -11,12 +11,29 @@ const ( MSetBatchSize = 100 ) -func (q *queryImpl) AdminSaveTodyStats(s *models.HistoryProxyStats) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +type StatsQuery interface { + GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error) + GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error) + GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error) +} + +type StatsMutation interface { + AdminSaveTodyStats(s *models.HistoryProxyStats) error + AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error +} + +type statsQuery struct{ *queryImpl } +type statsMutation struct{ *mutationImpl } + +func newStatsQuery(base *queryImpl) StatsQuery { return &statsQuery{base} } +func newStatsMutation(base *mutationImpl) StatsMutation { return &statsMutation{base} } + +func (m *statsMutation) AdminSaveTodyStats(s *models.HistoryProxyStats) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Save(s).Error } -func (q *queryImpl) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error { +func (m *statsMutation) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxyStats) error { if len(s) == 0 { return nil } @@ -31,7 +48,7 @@ func (q *queryImpl) AdminMSaveTodyStats(tx *gorm.DB, s []*models.HistoryProxySta return nil } -func (q *queryImpl) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error) { +func (q *statsQuery) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID int) ([]*models.HistoryProxyStats, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var stats []*models.HistoryProxyStats err := db.Where(&models.HistoryProxyStats{ @@ -45,7 +62,7 @@ func (q *queryImpl) GetHistoryStatsByProxyID(userInfo models.UserInfo, proxyID i return stats, nil } -func (q *queryImpl) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error) { +func (q *statsQuery) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID string) ([]*models.HistoryProxyStats, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var stats []*models.HistoryProxyStats err := db.Where(&models.HistoryProxyStats{ @@ -59,7 +76,7 @@ func (q *queryImpl) GetHistoryStatsByClientID(userInfo models.UserInfo, clientID return stats, nil } -func (q *queryImpl) GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error) { +func (q *statsQuery) GetHistoryStatsByServerID(userInfo models.UserInfo, serverID string) ([]*models.HistoryProxyStats, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var stats []*models.HistoryProxyStats err := db.Where(&models.HistoryProxyStats{ diff --git a/services/dao/user.go b/services/dao/user.go index ca80905..cf49cb0 100644 --- a/services/dao/user.go +++ b/services/dao/user.go @@ -8,7 +8,28 @@ import ( "github.com/samber/lo" ) -func (q *queryImpl) AdminGetAllUsers() ([]*models.UserEntity, error) { +type UserQuery interface { + AdminGetAllUsers() ([]*models.UserEntity, error) + AdminCountUsers() (int64, error) + GetUserByUserID(userID int) (*models.UserEntity, error) + GetUserByUserName(userName string) (*models.UserEntity, error) + CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error) + CheckUserNameAndEmail(userName, email string) error +} + +type UserMutation interface { + UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error + AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error + CreateUser(user *models.UserEntity) error +} + +type userQuery struct{ *queryImpl } +type userMutation struct{ *mutationImpl } + +func newUserQuery(base *queryImpl) UserQuery { return &userQuery{base} } +func newUserMutation(base *mutationImpl) UserMutation { return &userMutation{base} } + +func (q *userQuery) AdminGetAllUsers() ([]*models.UserEntity, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() users := make([]*models.User, 0) err := db.Find(&users).Error @@ -21,7 +42,7 @@ func (q *queryImpl) AdminGetAllUsers() ([]*models.UserEntity, error) { }), nil } -func (q *queryImpl) AdminCountUsers() (int64, error) { +func (q *userQuery) AdminCountUsers() (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.User{}).Count(&count).Error @@ -31,7 +52,7 @@ func (q *queryImpl) AdminCountUsers() (int64, error) { return count, nil } -func (q *queryImpl) GetUserByUserID(userID int) (*models.UserEntity, error) { +func (q *userQuery) GetUserByUserID(userID int) (*models.UserEntity, error) { if userID == 0 { return nil, fmt.Errorf("invalid user id") } @@ -48,8 +69,8 @@ func (q *queryImpl) GetUserByUserID(userID int) (*models.UserEntity, error) { return u.UserEntity, nil } -func (q *queryImpl) UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *userMutation) UpdateUser(userInfo models.UserInfo, user *models.UserEntity) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() user.UserID = userInfo.GetUserID() return db.Model(&models.User{}).Where( &models.User{ @@ -62,8 +83,8 @@ func (q *queryImpl) UpdateUser(userInfo models.UserInfo, user *models.UserEntity }).Error } -func (q *queryImpl) AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *userMutation) AdminUpdateUser(userInfo models.UserInfo, user *models.UserEntity) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() user.UserID = userInfo.GetUserID() return db.Model(&models.User{}).Where( &models.User{ @@ -76,7 +97,7 @@ func (q *queryImpl) AdminUpdateUser(userInfo models.UserInfo, user *models.UserE }).Error } -func (q *queryImpl) GetUserByUserName(userName string) (*models.UserEntity, error) { +func (q *userQuery) GetUserByUserName(userName string) (*models.UserEntity, error) { if userName == "" { return nil, fmt.Errorf("invalid user name") } @@ -93,7 +114,7 @@ func (q *queryImpl) GetUserByUserName(userName string) (*models.UserEntity, erro return u.UserEntity, nil } -func (q *queryImpl) CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error) { +func (q *userQuery) CheckUserPassword(userNameOrEmail, password string) (bool, models.UserInfo, error) { var user models.User db := q.ctx.GetApp().GetDBManager().GetDefaultDB() @@ -111,7 +132,7 @@ func (q *queryImpl) CheckUserPassword(userNameOrEmail, password string) (bool, m return utils.CheckPasswordHash(password, user.Password), user, nil } -func (q *queryImpl) CheckUserNameAndEmail(userName, email string) error { +func (q *userQuery) CheckUserNameAndEmail(userName, email string) error { var user models.User db := q.ctx.GetApp().GetDBManager().GetDefaultDB() @@ -129,10 +150,10 @@ func (q *queryImpl) CheckUserNameAndEmail(userName, email string) error { return nil } -func (q *queryImpl) CreateUser(user *models.UserEntity) error { +func (m *userMutation) CreateUser(user *models.UserEntity) error { u := &models.User{ UserEntity: user, } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(u).Error } diff --git a/services/dao/user_group.go b/services/dao/user_group.go index a588194..d94ce9a 100644 --- a/services/dao/user_group.go +++ b/services/dao/user_group.go @@ -7,7 +7,16 @@ import ( "github.com/VaalaCat/frp-panel/models" ) -func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error) { +type UserGroupMutation interface { + CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error) + DeleteGroup(userInfo models.UserInfo, groupID string) error +} + +type userGroupMutation struct{ *mutationImpl } + +func newUserGroupMutation(base *mutationImpl) UserGroupMutation { return &userGroupMutation{base} } + +func (m *userGroupMutation) CreateGroup(userInfo models.UserInfo, groupID, groupName, comment string) (*models.UserGroup, error) { if groupID == "" || groupName == "" { return nil, fmt.Errorf("invalid group id or group name") } @@ -16,7 +25,7 @@ func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, co return nil, fmt.Errorf("only admin can create group") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() g := &models.UserGroup{ TenantID: userInfo.GetTenantID(), @@ -32,12 +41,12 @@ func (q *queryImpl) CreateGroup(userInfo models.UserInfo, groupID, groupName, co return g, nil } -func (q *queryImpl) DeleteGroup(userInfo models.UserInfo, groupID string) error { +func (m *userGroupMutation) DeleteGroup(userInfo models.UserInfo, groupID string) error { if userInfo.GetRole() != defs.UserRole_Admin { return fmt.Errorf("only admin can delete group") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.UserGroup{ TenantID: userInfo.GetTenantID(), GroupID: groupID, diff --git a/services/dao/wireguard.go b/services/dao/wireguard.go index e7ddfce..c2bfae7 100644 --- a/services/dao/wireguard.go +++ b/services/dao/wireguard.go @@ -7,7 +7,30 @@ import ( "gorm.io/gorm" ) -func (q *queryImpl) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error { +type WireGuardQuery interface { + GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error) + AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error) + GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error) + GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error) + ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error) + AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error) + AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error) + CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error) +} + +type WireGuardMutation interface { + CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error + UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error + DeleteWireGuard(userInfo models.UserInfo, id uint) error +} + +type wireGuardQuery struct{ *queryImpl } +type wireGuardMutation struct{ *mutationImpl } + +func newWireGuardQuery(base *queryImpl) WireGuardQuery { return &wireGuardQuery{base} } +func newWireGuardMutation(base *mutationImpl) WireGuardMutation { return &wireGuardMutation{base} } + +func (m *wireGuardMutation) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGuard) error { if wg == nil || wg.WireGuardEntity == nil { return fmt.Errorf("invalid wireguard entity") } @@ -18,11 +41,11 @@ func (q *queryImpl) CreateWireGuard(userInfo models.UserInfo, wg *models.WireGua wg.UserId = uint32(userInfo.GetUserID()) wg.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Create(wg).Error } -func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error { +func (m *wireGuardMutation) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *models.WireGuard) error { if id == 0 || wg == nil || wg.WireGuardEntity == nil { return fmt.Errorf("invalid wireguard id or entity") } @@ -30,7 +53,7 @@ func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *model wg.UserId = uint32(userInfo.GetUserID()) wg.TenantId = uint32(userInfo.GetTenantID()) - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() // clear endpoints and resave if provided if wg.AdvertisedEndpoints != nil { @@ -49,11 +72,11 @@ func (q *queryImpl) UpdateWireGuard(userInfo models.UserInfo, id uint, wg *model }}).Save(wg).Error } -func (q *queryImpl) DeleteWireGuard(userInfo models.UserInfo, id uint) error { +func (m *wireGuardMutation) DeleteWireGuard(userInfo models.UserInfo, id uint) error { if id == 0 { return fmt.Errorf("invalid wireguard id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.WireGuard{ Model: gorm.Model{ID: id}, WireGuardEntity: &models.WireGuardEntity{ @@ -63,7 +86,7 @@ func (q *queryImpl) DeleteWireGuard(userInfo models.UserInfo, id uint) error { }).Delete(&models.WireGuard{}).Error } -func (q *queryImpl) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error) { +func (q *wireGuardQuery) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models.WireGuard, error) { if id == 0 { return nil, fmt.Errorf("invalid wireguard id") } @@ -84,7 +107,7 @@ func (q *queryImpl) GetWireGuardByID(userInfo models.UserInfo, id uint) (*models return &m, nil } -func (q *queryImpl) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error) { +func (q *wireGuardQuery) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interfaceName string) (*models.WireGuard, error) { if clientID == "" || interfaceName == "" { return nil, fmt.Errorf("invalid client id or interface name") } @@ -100,7 +123,7 @@ func (q *queryImpl) AdminGetWireGuardByClientIDAndInterfaceName(clientID, interf return &m, nil } -func (q *queryImpl) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error) { +func (q *wireGuardQuery) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID uint) ([]*models.WireGuard, error) { if networkID == 0 { return nil, fmt.Errorf("invalid network id") } @@ -119,7 +142,7 @@ func (q *queryImpl) GetWireGuardsByNetworkID(userInfo models.UserInfo, networkID return list, nil } -func (q *queryImpl) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error) { +func (q *wireGuardQuery) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserInfo, networkID uint) ([]string, error) { if networkID == 0 { return nil, fmt.Errorf("invalid network id") } @@ -135,7 +158,7 @@ func (q *queryImpl) GetWireGuardLocalAddressesByNetworkID(userInfo models.UserIn return list, nil } -func (q *queryImpl) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error) { +func (q *wireGuardQuery) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pageSize int, filter *models.WireGuardEntity, keyword string) ([]*models.WireGuard, error) { if page < 1 || pageSize < 1 { return nil, fmt.Errorf("invalid page or page size") } @@ -170,7 +193,7 @@ func (q *queryImpl) ListWireGuardsWithFilters(userInfo models.UserInfo, page, pa return list, nil } -func (q *queryImpl) AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error) { +func (q *wireGuardQuery) AdminListWireGuardsWithClientID(clientID string) ([]*models.WireGuard, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var list []*models.WireGuard if err := db.Where(&models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ClientID: clientID}}). @@ -180,7 +203,7 @@ func (q *queryImpl) AdminListWireGuardsWithClientID(clientID string) ([]*models. return list, nil } -func (q *queryImpl) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error) { +func (q *wireGuardQuery) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*models.WireGuard, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var list []*models.WireGuard if err := db.Where("network_id IN ?", networkIDs). @@ -190,7 +213,7 @@ func (q *queryImpl) AdminListWireGuardsWithNetworkIDs(networkIDs []uint) ([]*mod return list, nil } -func (q *queryImpl) CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error) { +func (q *wireGuardQuery) CountWireGuardsWithFilters(userInfo models.UserInfo, filter *models.WireGuardEntity, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 base := db.Model(&models.WireGuard{}).Where(&models.WireGuard{WireGuardEntity: &models.WireGuardEntity{ diff --git a/services/dao/worker.go b/services/dao/worker.go index 2f30602..2c044bf 100644 --- a/services/dao/worker.go +++ b/services/dao/worker.go @@ -6,8 +6,29 @@ import ( "github.com/VaalaCat/frp-panel/models" ) -func (q *queryImpl) CreateWorker(userInfo models.UserInfo, worker *models.Worker) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +type WorkerQuery interface { + GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error) + ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error) + AdminListWorkersByClientID(clientID string) ([]*models.Worker, error) + ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error) + CountWorkers(userInfo models.UserInfo) (int64, error) + CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) +} + +type WorkerMutation interface { + CreateWorker(userInfo models.UserInfo, worker *models.Worker) error + DeleteWorker(userInfo models.UserInfo, workerID string) error + UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error +} + +type workerQuery struct{ *queryImpl } +type workerMutation struct{ *mutationImpl } + +func newWorkerQuery(base *queryImpl) WorkerQuery { return &workerQuery{base} } +func newWorkerMutation(base *mutationImpl) WorkerMutation { return &workerMutation{base} } + +func (m *workerMutation) CreateWorker(userInfo models.UserInfo, worker *models.Worker) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() worker.UserId = uint32(userInfo.GetUserID()) worker.TenantId = uint32(userInfo.GetTenantID()) @@ -17,8 +38,8 @@ func (q *queryImpl) CreateWorker(userInfo models.UserInfo, worker *models.Worker return db.Create(worker).Error } -func (q *queryImpl) DeleteWorker(userInfo models.UserInfo, workerID string) error { - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() +func (m *workerMutation) DeleteWorker(userInfo models.UserInfo, workerID string) error { + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() return db.Unscoped().Where(&models.Worker{ WorkerEntity: &models.WorkerEntity{ @@ -29,7 +50,7 @@ func (q *queryImpl) DeleteWorker(userInfo models.UserInfo, workerID string) erro }).Delete(&models.Worker{}).Error } -func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error { +func (m *workerMutation) UpdateWorker(userInfo models.UserInfo, worker *models.Worker) error { if worker.WorkerEntity == nil { return fmt.Errorf("invalid worker entity") } @@ -37,7 +58,7 @@ func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker return fmt.Errorf("invalid worker id") } - db := q.ctx.GetApp().GetDBManager().GetDefaultDB() + db := m.ctx.GetApp().GetDBManager().GetDefaultDB() if err := db.Unscoped().Model(&models.Worker{ WorkerEntity: &models.WorkerEntity{ @@ -58,7 +79,7 @@ func (q *queryImpl) UpdateWorker(userInfo models.UserInfo, worker *models.Worker }).Save(worker).Error } -func (q *queryImpl) GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error) { +func (q *workerQuery) GetWorkerByWorkerID(userInfo models.UserInfo, workerID string) (*models.Worker, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() w := &models.Worker{} err := db.Where(&models.Worker{ @@ -74,7 +95,7 @@ func (q *queryImpl) GetWorkerByWorkerID(userInfo models.UserInfo, workerID strin return w, nil } -func (q *queryImpl) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error) { +func (q *workerQuery) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([]*models.Worker, error) { if page < 1 || pageSize < 1 || pageSize > 100 { return nil, fmt.Errorf("invalid page or page size") } @@ -96,9 +117,9 @@ func (q *queryImpl) ListWorkers(userInfo models.UserInfo, page, pageSize int) ([ return workers, nil } -func (q *queryImpl) AdminListWorkersByClientID(clientID string) ([]*models.Worker, error) { +func (q *workerQuery) AdminListWorkersByClientID(clientID string) ([]*models.Worker, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() - client, err := q.AdminGetClientByClientID(clientID) + client, err := newClientQuery(q.queryImpl).AdminGetClientByClientID(clientID) if err != nil { return nil, err } @@ -111,7 +132,7 @@ func (q *queryImpl) AdminListWorkersByClientID(clientID string) ([]*models.Worke return client.Workers, nil } -func (q *queryImpl) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error) { +func (q *workerQuery) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageSize int, keyword string) ([]*models.Worker, error) { if page < 1 || pageSize < 1 || len(keyword) == 0 || pageSize > 100 { return nil, fmt.Errorf("invalid page or page size or keyword") } @@ -134,7 +155,7 @@ func (q *queryImpl) ListWorkersWithKeyword(userInfo models.UserInfo, page, pageS return workers, nil } -func (q *queryImpl) CountWorkers(userInfo models.UserInfo) (int64, error) { +func (q *workerQuery) CountWorkers(userInfo models.UserInfo) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Worker{}).Where(&models.Worker{ @@ -149,7 +170,7 @@ func (q *queryImpl) CountWorkers(userInfo models.UserInfo) (int64, error) { return count, nil } -func (q *queryImpl) CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { +func (q *workerQuery) CountWorkersWithKeyword(userInfo models.UserInfo, keyword string) (int64, error) { db := q.ctx.GetApp().GetDBManager().GetDefaultDB() var count int64 err := db.Model(&models.Worker{}).Where("name like ?", "%"+keyword+"%"). diff --git a/services/master/grpc_server.go b/services/master/grpc_server.go index 44a4ba8..390db54 100644 --- a/services/master/grpc_server.go +++ b/services/master/grpc_server.go @@ -208,7 +208,7 @@ func (s *server) ServerSend(sender pb.Master_ServerSendServer) error { } if cliType == defs.CliTypeClient { - if err := dao.NewQuery(ctx).AdminUpdateClientLastSeen(req.GetClientId()); err != nil { + if err := dao.NewMutation(ctx).AdminUpdateClientLastSeen(req.GetClientId()); err != nil { logger.Logger(ctx).Errorf("cannot update client last seen, %s id: [%s]", req.GetEvent().String(), req.GetClientId()) } }