mirror of
https://github.com/go-eagle/eagle.git
synced 2025-09-26 20:41:26 +08:00
lint: fix golangci-lint tips
This commit is contained in:
@@ -27,7 +27,7 @@ func SendResponse(c *gin.Context, err error, data interface{}) {
|
||||
})
|
||||
}
|
||||
|
||||
// 返回用户id
|
||||
// GetUserID 返回用户id
|
||||
func GetUserID(c *gin.Context) uint64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
|
@@ -1,7 +1,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
. "github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/handler"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -22,5 +22,5 @@ func Delete(c *gin.Context) {
|
||||
// return
|
||||
//}
|
||||
|
||||
SendResponse(c, nil, nil)
|
||||
handler.SendResponse(c, nil, nil)
|
||||
}
|
||||
|
@@ -3,7 +3,8 @@ package user
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
. "github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/handler"
|
||||
|
||||
"github.com/1024casts/snake/pkg/errno"
|
||||
"github.com/1024casts/snake/service/user"
|
||||
|
||||
@@ -25,7 +26,7 @@ func Get(c *gin.Context) {
|
||||
|
||||
userIDStr := c.Param("id")
|
||||
if userIDStr == "" {
|
||||
SendResponse(c, errno.ErrParam, nil)
|
||||
handler.SendResponse(c, errno.ErrParam, nil)
|
||||
return
|
||||
}
|
||||
userID, _ := strconv.Atoi(userIDStr)
|
||||
@@ -33,9 +34,9 @@ func Get(c *gin.Context) {
|
||||
// Get the user by the `user_id` from the database.
|
||||
u, err := user.UserService.GetUserByID(uint64(userID))
|
||||
if err != nil {
|
||||
SendResponse(c, errno.ErrUserNotFound, nil)
|
||||
handler.SendResponse(c, errno.ErrUserNotFound, nil)
|
||||
return
|
||||
}
|
||||
|
||||
SendResponse(c, nil, u)
|
||||
handler.SendResponse(c, nil, u)
|
||||
}
|
||||
|
@@ -23,7 +23,7 @@ func TestGet(t *testing.T) {
|
||||
userTests := []model.UserModel{
|
||||
{
|
||||
BaseModel: model.BaseModel{
|
||||
Id: 12,
|
||||
ID: 12,
|
||||
},
|
||||
Username: "user001",
|
||||
Password: "123456",
|
||||
@@ -33,7 +33,7 @@ func TestGet(t *testing.T) {
|
||||
},
|
||||
{
|
||||
BaseModel: model.BaseModel{
|
||||
Id: 13,
|
||||
ID: 13,
|
||||
},
|
||||
Username: "user002",
|
||||
Password: "123456",
|
||||
@@ -48,7 +48,7 @@ func TestGet(t *testing.T) {
|
||||
app.DB.Create(user)
|
||||
|
||||
// Set up a new request.
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("/v1/users/%d", user.Id), nil)
|
||||
req, err := http.NewRequest("GET", fmt.Sprintf("/v1/users/%d", user.ID), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@@ -3,7 +3,8 @@ package user
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
. "github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/handler"
|
||||
|
||||
"github.com/1024casts/snake/model"
|
||||
"github.com/1024casts/snake/pkg/errno"
|
||||
"github.com/1024casts/snake/pkg/token"
|
||||
@@ -28,20 +29,20 @@ func Login(c *gin.Context) {
|
||||
// Binding the data with the u struct.
|
||||
var req PhoneLoginCredentials
|
||||
if err := c.Bind(&req); err != nil {
|
||||
SendResponse(c, errno.ErrBind, nil)
|
||||
handler.SendResponse(c, errno.ErrBind, nil)
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("req %#v", req)
|
||||
// check param
|
||||
if req.Phone == 0 || req.VerifyCode == 0 {
|
||||
SendResponse(c, errno.ErrParam, nil)
|
||||
handler.SendResponse(c, errno.ErrParam, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证校验码
|
||||
if !vcode.VCodeService.CheckLoginVCode(req.Phone, req.VerifyCode) {
|
||||
SendResponse(c, errno.ErrVerifyCode, nil)
|
||||
handler.SendResponse(c, errno.ErrVerifyCode, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -52,28 +53,28 @@ func Login(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 否则新建用户信息, 并取得用户信息
|
||||
if u.Id == 0 {
|
||||
if u.ID == 0 {
|
||||
u := model.UserModel{
|
||||
Phone: req.Phone,
|
||||
Username: strconv.Itoa(req.Phone),
|
||||
}
|
||||
u.Id, err = user.UserService.CreateUser(u)
|
||||
u.ID, err = user.UserService.CreateUser(u)
|
||||
if err != nil {
|
||||
log.Warnf("[login] create u err, %v", err)
|
||||
SendResponse(c, errno.InternalServerError, nil)
|
||||
handler.SendResponse(c, errno.InternalServerError, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 签发签名 Sign the json web token.
|
||||
t, err := token.Sign(c, token.Context{UserID: u.Id, Username: u.Username}, "")
|
||||
t, err := token.Sign(c, token.Context{UserID: u.ID, Username: u.Username}, "")
|
||||
if err != nil {
|
||||
log.Warnf("[login] gen token sign err:, %v", err)
|
||||
SendResponse(c, errno.ErrToken, nil)
|
||||
handler.SendResponse(c, errno.ErrToken, nil)
|
||||
return
|
||||
}
|
||||
|
||||
SendResponse(c, nil, model.Token{
|
||||
handler.SendResponse(c, nil, model.Token{
|
||||
Token: t,
|
||||
})
|
||||
}
|
||||
|
@@ -3,7 +3,8 @@ package user
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
. "github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/handler"
|
||||
|
||||
"github.com/1024casts/snake/pkg/errno"
|
||||
"github.com/1024casts/snake/service/user"
|
||||
|
||||
@@ -29,7 +30,7 @@ func Update(c *gin.Context) {
|
||||
// Binding the user data.
|
||||
var req UpdateRequest
|
||||
if err := c.Bind(&req); err != nil {
|
||||
SendResponse(c, errno.ErrBind, nil)
|
||||
handler.SendResponse(c, errno.ErrBind, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -39,9 +40,9 @@ func Update(c *gin.Context) {
|
||||
err := user.UserService.UpdateUser(userMap, uint64(userID))
|
||||
if err != nil {
|
||||
log.Warnf("[user] update user err, %v", err)
|
||||
SendResponse(c, errno.InternalServerError, nil)
|
||||
handler.SendResponse(c, errno.InternalServerError, nil)
|
||||
return
|
||||
}
|
||||
|
||||
SendResponse(c, nil, userID)
|
||||
handler.SendResponse(c, nil, userID)
|
||||
}
|
||||
|
@@ -1,10 +1,11 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
. "github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/pkg/errno"
|
||||
"github.com/1024casts/snake/service/sms"
|
||||
"github.com/1024casts/snake/service/vcode"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lexkong/log"
|
||||
"github.com/pkg/errors"
|
||||
@@ -25,13 +26,13 @@ func VCode(c *gin.Context) {
|
||||
|
||||
// 验证区号和手机号是否为空
|
||||
if c.Query("area_code") == "" {
|
||||
SendResponse(c, errno.ErrAreaCodeEmpty, nil)
|
||||
handler.SendResponse(c, errno.ErrAreaCodeEmpty, nil)
|
||||
return
|
||||
}
|
||||
|
||||
phone := c.Query("phone")
|
||||
if phone == "" {
|
||||
SendResponse(c, errno.ErrPhoneEmpty, nil)
|
||||
handler.SendResponse(c, errno.ErrPhoneEmpty, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -41,7 +42,7 @@ func VCode(c *gin.Context) {
|
||||
verifyCode, err := vcode.VCodeService.GenLoginVCode(phone)
|
||||
if err != nil {
|
||||
log.Warnf("gen login verify code err, %v", errors.WithStack(err))
|
||||
SendResponse(c, errno.ErrGenVCode, nil)
|
||||
handler.SendResponse(c, errno.ErrGenVCode, nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -49,9 +50,9 @@ func VCode(c *gin.Context) {
|
||||
err = sms.ServiceSms.Send(phone, verifyCode)
|
||||
if err != nil {
|
||||
log.Warnf("send phone sms err, %v", errors.WithStack(err))
|
||||
SendResponse(c, errno.ErrSendSMS, nil)
|
||||
handler.SendResponse(c, errno.ErrSendSMS, nil)
|
||||
return
|
||||
}
|
||||
|
||||
SendResponse(c, nil, nil)
|
||||
handler.SendResponse(c, nil, nil)
|
||||
}
|
||||
|
5
main.go
5
main.go
@@ -6,6 +6,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
// http pprof
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
"os/signal"
|
||||
@@ -80,7 +82,7 @@ func main() {
|
||||
|
||||
// Middlwares.
|
||||
middleware.Logging(),
|
||||
middleware.RequestId(),
|
||||
middleware.RequestID(),
|
||||
)
|
||||
|
||||
// Ping the server to make sure the router is working.
|
||||
@@ -133,6 +135,7 @@ func main() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("timeout of 5 seconds.")
|
||||
default:
|
||||
}
|
||||
log.Info("Server exiting")
|
||||
}
|
||||
|
@@ -9,16 +9,20 @@ import (
|
||||
|
||||
// MySQL driver.
|
||||
"github.com/jinzhu/gorm"
|
||||
// GORM MySQL
|
||||
_ "github.com/jinzhu/gorm/dialects/mysql"
|
||||
)
|
||||
|
||||
// Database 定义现有的数据库
|
||||
type Database struct {
|
||||
Self *gorm.DB
|
||||
Docker *gorm.DB
|
||||
}
|
||||
|
||||
// DB 数据库全局变量
|
||||
var DB *Database
|
||||
|
||||
// openDB 链接数据库,生成数据库实例
|
||||
func openDB(username, password, addr, name string) *gorm.DB {
|
||||
config := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
|
||||
username,
|
||||
@@ -42,6 +46,7 @@ func openDB(username, password, addr, name string) *gorm.DB {
|
||||
return db
|
||||
}
|
||||
|
||||
// setupDB 配置数据库
|
||||
func setupDB(db *gorm.DB) {
|
||||
db.LogMode(viper.GetBool("gorm.show_log"))
|
||||
// 用于设置最大打开的连接数,默认值为0表示不限制.设置最大的连接数,可以避免并发太高导致连接mysql出现too many connections的错误。
|
||||
@@ -51,7 +56,7 @@ func setupDB(db *gorm.DB) {
|
||||
db.DB().SetConnMaxLifetime(time.Minute * viper.GetDuration("grom.conn_max_lift_time"))
|
||||
}
|
||||
|
||||
// used for cli
|
||||
// InitSelfDB used for cli
|
||||
func InitSelfDB() *gorm.DB {
|
||||
return openDB(viper.GetString("db.username"),
|
||||
viper.GetString("db.password"),
|
||||
@@ -59,10 +64,12 @@ func InitSelfDB() *gorm.DB {
|
||||
viper.GetString("db.name"))
|
||||
}
|
||||
|
||||
// GetSelfDB 获取self数据库示例
|
||||
func GetSelfDB() *gorm.DB {
|
||||
return InitSelfDB()
|
||||
}
|
||||
|
||||
// InitDockerDB 初始化一个docker数据库
|
||||
func InitDockerDB() *gorm.DB {
|
||||
return openDB(viper.GetString("docker_db.username"),
|
||||
viper.GetString("docker_db.password"),
|
||||
@@ -70,10 +77,12 @@ func InitDockerDB() *gorm.DB {
|
||||
viper.GetString("docker_db.name"))
|
||||
}
|
||||
|
||||
// GetDockerDB 获取docker数据库
|
||||
func GetDockerDB() *gorm.DB {
|
||||
return InitDockerDB()
|
||||
}
|
||||
|
||||
// Init 初始化数据库
|
||||
func (db *Database) Init() {
|
||||
DB = &Database{
|
||||
Self: GetSelfDB(),
|
||||
@@ -81,10 +90,12 @@ func (db *Database) Init() {
|
||||
}
|
||||
}
|
||||
|
||||
// GetDB 返回默认的数据库
|
||||
func GetDB() *gorm.DB {
|
||||
return DB.Self
|
||||
}
|
||||
|
||||
// Close 关闭数据库链接
|
||||
func (db *Database) Close() {
|
||||
err := DB.Self.Close()
|
||||
if err != nil {
|
||||
|
@@ -6,13 +6,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// BaseModel 公共model
|
||||
type BaseModel struct {
|
||||
Id uint64 `gorm:"primary_key;AUTO_INCREMENT;column:id" json:"-"`
|
||||
ID uint64 `gorm:"primary_key;AUTO_INCREMENT;column:id" json:"-"`
|
||||
CreatedAt time.Time `gorm:"column:created_at" json:"-"`
|
||||
UpdatedAt time.Time `gorm:"column:updated_at" json:"-"`
|
||||
DeletedAt *time.Time `gorm:"column:deleted_at" sql:"index" json:"-"`
|
||||
}
|
||||
|
||||
// NullType 空字节类型
|
||||
type NullType byte
|
||||
|
||||
const (
|
||||
@@ -23,7 +25,7 @@ const (
|
||||
IsNotNull
|
||||
)
|
||||
|
||||
// sql build where
|
||||
// WhereBuild sql build where
|
||||
// see: https://github.com/jinzhu/gorm/issues/2055
|
||||
func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface{}, err error) {
|
||||
for k, v := range where {
|
||||
@@ -52,7 +54,6 @@ func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface
|
||||
whereSQL += fmt.Sprint(k, "=?")
|
||||
vals = append(vals, v)
|
||||
}
|
||||
break
|
||||
case 2:
|
||||
k = ks[0]
|
||||
switch ks[1] {
|
||||
@@ -84,7 +85,6 @@ func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface
|
||||
whereSQL += fmt.Sprint(k, " like ?")
|
||||
vals = append(vals, v)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
|
@@ -9,7 +9,7 @@ import (
|
||||
validator "github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// User represents a registered user.
|
||||
// UserModel User represents a registered user.
|
||||
type UserModel struct {
|
||||
BaseModel
|
||||
Username string `json:"username" gorm:"column:username;not null" binding:"required" validate:"min=1,max=32"`
|
||||
@@ -27,21 +27,24 @@ func (u *UserModel) Validate() error {
|
||||
return validate.Struct(u)
|
||||
}
|
||||
|
||||
// UserInfo 对外暴露的结构体
|
||||
type UserInfo struct {
|
||||
Id uint64 `json:"id" example:"1"`
|
||||
ID uint64 `json:"id" example:"1"`
|
||||
Username string `json:"username" example:"张三"`
|
||||
Password string `json:"password" example:"9dXd13#k$1123!kln"`
|
||||
CreatedAt string `json:"createdAt" example:"2020-03-23 20:00:00"`
|
||||
UpdatedAt string `json:"updatedAt" example:"2020-03-23 20:00:00"`
|
||||
}
|
||||
|
||||
func (c *UserModel) TableName() string {
|
||||
// TableName 表名
|
||||
func (u *UserModel) TableName() string {
|
||||
return "tb_users"
|
||||
}
|
||||
|
||||
// UserList 用户列表结构体
|
||||
type UserList struct {
|
||||
Lock *sync.Mutex
|
||||
IdMap map[uint64]*UserInfo
|
||||
IDMap map[uint64]*UserInfo
|
||||
}
|
||||
|
||||
// Token represents a JSON web token.
|
||||
|
16
pkg/cache/driver.go
vendored
16
pkg/cache/driver.go
vendored
@@ -8,17 +8,18 @@ import (
|
||||
redis2 "github.com/1024casts/snake/pkg/redis"
|
||||
)
|
||||
|
||||
// keyPrefix 一般为业务前缀
|
||||
var Cache Driver = NewMemoryCache("snake:", JsonEncoding{})
|
||||
// Cache 生成一个缓存客户端,其中keyPrefix 一般为业务前缀
|
||||
var Cache Driver = NewMemoryCache("snake:", JSONEncoding{})
|
||||
|
||||
// 初始化缓存,在main.go里调用
|
||||
// Init 初始化缓存,在main.go里调用
|
||||
// 默认是redis,这里也可以改为其他缓存
|
||||
func Init() {
|
||||
if gin.Mode() == gin.ReleaseMode {
|
||||
Cache = NewRedisCache(redis2.Client, "snake:", JsonEncoding{})
|
||||
Cache = NewRedisCache(redis2.Client, "snake:", JSONEncoding{})
|
||||
}
|
||||
}
|
||||
|
||||
// Driver 定义cache驱动接口
|
||||
type Driver interface {
|
||||
Set(key string, val interface{}, expiration time.Duration) error
|
||||
Get(key string) (interface{}, error)
|
||||
@@ -29,30 +30,37 @@ type Driver interface {
|
||||
Decr(key string, step int64) (int64, error)
|
||||
}
|
||||
|
||||
// Set 数据
|
||||
func Set(key string, val interface{}, expiration time.Duration) error {
|
||||
return Cache.Set(key, val, expiration)
|
||||
}
|
||||
|
||||
// Get 数据
|
||||
func Get(key string) (interface{}, error) {
|
||||
return Cache.Get(key)
|
||||
}
|
||||
|
||||
// MultiSet 批量set
|
||||
func MultiSet(valMap map[string]interface{}, expiration time.Duration) error {
|
||||
return Cache.MultiSet(valMap, expiration)
|
||||
}
|
||||
|
||||
// MultiGet 批量获取
|
||||
func MultiGet(keys ...string) (interface{}, error) {
|
||||
return Cache.MultiGet(keys...)
|
||||
}
|
||||
|
||||
// Del 批量删除
|
||||
func Del(keys ...string) error {
|
||||
return Cache.Del(keys...)
|
||||
}
|
||||
|
||||
// Incr 自增
|
||||
func Incr(key string, step int64) (int64, error) {
|
||||
return Cache.Incr(key, step)
|
||||
}
|
||||
|
||||
// Decr 自减
|
||||
func Decr(key string, step int64) (int64, error) {
|
||||
return Cache.Decr(key, step)
|
||||
}
|
||||
|
69
pkg/cache/encoding.go
vendored
69
pkg/cache/encoding.go
vendored
@@ -16,53 +16,59 @@ import (
|
||||
"github.com/vmihailenco/msgpack"
|
||||
)
|
||||
|
||||
// Encoding 编码接口定义
|
||||
type Encoding interface {
|
||||
Marshal(v interface{}) ([]byte, error)
|
||||
Unmarshal(data []byte, v interface{}) error
|
||||
}
|
||||
|
||||
// Marshal encode data
|
||||
func Marshal(e Encoding, v interface{}) (data []byte, err error) {
|
||||
bm, ok := v.(encoding.BinaryMarshaler)
|
||||
if ok && e == nil {
|
||||
data, err = bm.MarshalBinary()
|
||||
return
|
||||
} else {
|
||||
data, err = e.Marshal(v)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
data, err = bm.MarshalBinary()
|
||||
}
|
||||
}
|
||||
|
||||
data, err = e.Marshal(v)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
data, err = bm.MarshalBinary()
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Unmarshal decode data
|
||||
func Unmarshal(e Encoding, data []byte, v interface{}) (err error) {
|
||||
bm, ok := v.(encoding.BinaryUnmarshaler)
|
||||
if ok && e == nil {
|
||||
err = bm.UnmarshalBinary(data)
|
||||
return err
|
||||
} else {
|
||||
err = e.Unmarshal(data, v)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
return bm.UnmarshalBinary(data)
|
||||
}
|
||||
}
|
||||
err = e.Unmarshal(data, v)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if ok {
|
||||
return bm.UnmarshalBinary(data)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type JsonEncoding struct{}
|
||||
// JSONEncoding json格式
|
||||
type JSONEncoding struct{}
|
||||
|
||||
func (this JsonEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
// Marshal json encode
|
||||
func (j JSONEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
buf, err := json.Marshal(v)
|
||||
return buf, err
|
||||
}
|
||||
|
||||
func (this JsonEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
// Unmarshal json decode
|
||||
func (j JSONEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
err := json.Unmarshal(data, value)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -70,9 +76,11 @@ func (this JsonEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GobEncoding gob encode
|
||||
type GobEncoding struct{}
|
||||
|
||||
func (this GobEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
// Marshal gob encode
|
||||
func (g GobEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
var (
|
||||
buffer bytes.Buffer
|
||||
)
|
||||
@@ -81,7 +89,8 @@ func (this GobEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
return buffer.Bytes(), err
|
||||
}
|
||||
|
||||
func (this GobEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
// Unmarshal gob encode
|
||||
func (g GobEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
err := gob.NewDecoder(bytes.NewReader(data)).Decode(value)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -89,9 +98,11 @@ func (this GobEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type JsonGzipEncoding struct{}
|
||||
// JSONGzipEncoding json and gzip
|
||||
type JSONGzipEncoding struct{}
|
||||
|
||||
func (this JsonGzipEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
// Marshal json encode and gzip
|
||||
func (jz JSONGzipEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
buf, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -103,7 +114,8 @@ func (this JsonGzipEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
return buf, err
|
||||
}
|
||||
|
||||
func (this JsonGzipEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
// Unmarshal json encode and gzip
|
||||
func (jz JSONGzipEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
jsonData, err := GzipDecode(data)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -116,6 +128,7 @@ func (this JsonGzipEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GzipEncode 编码
|
||||
func GzipEncode(in []byte) ([]byte, error) {
|
||||
var (
|
||||
buffer bytes.Buffer
|
||||
@@ -143,6 +156,7 @@ func GzipEncode(in []byte) ([]byte, error) {
|
||||
return buffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// GzipDecode 解码
|
||||
func GzipDecode(in []byte) ([]byte, error) {
|
||||
reader, err := gzip.NewReader(bytes.NewReader(in))
|
||||
if err != nil {
|
||||
@@ -182,14 +196,17 @@ func (s JSONSnappyEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
return json.Unmarshal(b, value)
|
||||
}
|
||||
|
||||
// MsgPackEncoding msgpack 格式
|
||||
type MsgPackEncoding struct{}
|
||||
|
||||
func (this MsgPackEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
// Marshal msgpack encode
|
||||
func (mp MsgPackEncoding) Marshal(v interface{}) ([]byte, error) {
|
||||
buf, err := msgpack.Marshal(v)
|
||||
return buf, err
|
||||
}
|
||||
|
||||
func (this MsgPackEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
// Unmarshal msgpack decode
|
||||
func (mp MsgPackEncoding) Unmarshal(data []byte, value interface{}) error {
|
||||
err := msgpack.Unmarshal(data, value)
|
||||
if err != nil {
|
||||
return err
|
||||
|
4
pkg/cache/encoding_test.go
vendored
4
pkg/cache/encoding_test.go
vendored
@@ -7,7 +7,7 @@ func BenchmarkJsonMarshal(b *testing.B) {
|
||||
for i := 0; i < 400; i++ {
|
||||
a = append(a, i)
|
||||
}
|
||||
jsonEncoding := JsonEncoding{}
|
||||
jsonEncoding := JSONEncoding{}
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, err := jsonEncoding.Marshal(a)
|
||||
if err != nil {
|
||||
@@ -21,7 +21,7 @@ func BenchmarkJsonUnmarshal(b *testing.B) {
|
||||
for i := 0; i < 400; i++ {
|
||||
a = append(a, i)
|
||||
}
|
||||
jsonEncoding := JsonEncoding{}
|
||||
jsonEncoding := JSONEncoding{}
|
||||
data, err := jsonEncoding.Marshal(a)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
|
1
pkg/cache/key.go
vendored
1
pkg/cache/key.go
vendored
@@ -5,6 +5,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// BuildCacheKey 构建一个带有前缀的缓存key
|
||||
func BuildCacheKey(keyPrefix string, key string) (cacheKey string, err error) {
|
||||
if key == "" {
|
||||
return "", errors.New("[cache] key should not be empty")
|
||||
|
24
pkg/cache/memory.go
vendored
24
pkg/cache/memory.go
vendored
@@ -1,8 +1,6 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -17,7 +15,8 @@ type memoryCache struct {
|
||||
encoding Encoding
|
||||
}
|
||||
|
||||
func NewMemoryCache(keyPrefix string, encoding Encoding) *memoryCache {
|
||||
// NewMemoryCache 实例化一个内存cache
|
||||
func NewMemoryCache(keyPrefix string, encoding Encoding) Driver {
|
||||
return &memoryCache{
|
||||
client: &sync.Map{},
|
||||
KeyPrefix: keyPrefix,
|
||||
@@ -31,6 +30,7 @@ type itemWithTTL struct {
|
||||
value interface{}
|
||||
}
|
||||
|
||||
// newItem 返回带有效期的value
|
||||
func newItem(value interface{}, expires time.Duration) itemWithTTL {
|
||||
expires64 := int64(expires)
|
||||
if expires > 0 {
|
||||
@@ -60,17 +60,7 @@ func getValue(item interface{}, ok bool) (interface{}, bool) {
|
||||
return itemObj.value, true
|
||||
}
|
||||
|
||||
// interface 转 byte
|
||||
func GetBytes(key interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// Set data
|
||||
func (m memoryCache) Set(key string, val interface{}, expiration time.Duration) error {
|
||||
cacheKey, err := BuildCacheKey(m.KeyPrefix, key)
|
||||
if err != nil {
|
||||
@@ -80,6 +70,7 @@ func (m memoryCache) Set(key string, val interface{}, expiration time.Duration)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get data
|
||||
func (m memoryCache) Get(key string) (interface{}, error) {
|
||||
cacheKey, err := BuildCacheKey(m.KeyPrefix, key)
|
||||
if err != nil {
|
||||
@@ -92,14 +83,17 @@ func (m memoryCache) Get(key string) (interface{}, error) {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// MultiSet 批量set
|
||||
func (m memoryCache) MultiSet(valMap map[string]interface{}, expiration time.Duration) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// MultiGet 批量获取
|
||||
func (m memoryCache) MultiGet(keys ...string) (interface{}, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// Del 批量删除
|
||||
func (m memoryCache) Del(keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
@@ -117,10 +111,12 @@ func (m memoryCache) Del(keys ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Incr 自增
|
||||
func (m memoryCache) Incr(key string, step int64) (int64, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// Decr 自减
|
||||
func (m memoryCache) Decr(key string, step int64) (int64, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
6
pkg/cache/memory_test.go
vendored
6
pkg/cache/memory_test.go
vendored
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
func Test_memoryCache_SetGet(t *testing.T) {
|
||||
// 实例化memory cache
|
||||
cache := NewMemoryCache("memory-unit-test", JsonEncoding{})
|
||||
cache := NewMemoryCache("memory-unit-test", JSONEncoding{})
|
||||
|
||||
// test set
|
||||
type setArgs struct {
|
||||
@@ -19,7 +19,7 @@ func Test_memoryCache_SetGet(t *testing.T) {
|
||||
|
||||
setTests := []struct {
|
||||
name string
|
||||
cache *memoryCache
|
||||
cache Driver
|
||||
args setArgs
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -47,7 +47,7 @@ func Test_memoryCache_SetGet(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache *memoryCache
|
||||
cache Driver
|
||||
args args
|
||||
wantVal interface{}
|
||||
wantErr bool
|
||||
|
6
pkg/cache/redis.go
vendored
6
pkg/cache/redis.go
vendored
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// redisCache redis cache结构体
|
||||
type redisCache struct {
|
||||
client *redis.Client
|
||||
KeyPrefix string
|
||||
@@ -19,11 +20,12 @@ type redisCache struct {
|
||||
}
|
||||
|
||||
const (
|
||||
// DefaultExpireTime 默认过期时间
|
||||
DefaultExpireTime = 60 * time.Second
|
||||
)
|
||||
|
||||
// client 参数是可传入的,这样方便进行单元测试
|
||||
func NewRedisCache(client *redis.Client, keyPrefix string, encoding Encoding) *redisCache {
|
||||
// NewRedisCache new一个redis cache, client 参数是可传入的,这样方便进行单元测试
|
||||
func NewRedisCache(client *redis.Client, keyPrefix string, encoding Encoding) Driver {
|
||||
return &redisCache{
|
||||
client: client,
|
||||
KeyPrefix: keyPrefix,
|
||||
|
6
pkg/cache/redis_test.go
vendored
6
pkg/cache/redis_test.go
vendored
@@ -15,7 +15,7 @@ func Test_redisCache_SetGet(t *testing.T) {
|
||||
// 获取redis客户端
|
||||
redisClient := redis2.Client
|
||||
// 实例化redis cache
|
||||
cache := NewRedisCache(redisClient, "unit-test", JsonEncoding{})
|
||||
cache := NewRedisCache(redisClient, "unit-test", JSONEncoding{})
|
||||
|
||||
// test set
|
||||
type setArgs struct {
|
||||
@@ -26,7 +26,7 @@ func Test_redisCache_SetGet(t *testing.T) {
|
||||
|
||||
setTests := []struct {
|
||||
name string
|
||||
cache *redisCache
|
||||
cache Driver
|
||||
args setArgs
|
||||
wantErr bool
|
||||
}{
|
||||
@@ -54,7 +54,7 @@ func Test_redisCache_SetGet(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache *redisCache
|
||||
cache Driver
|
||||
args args
|
||||
wantVal interface{}
|
||||
wantErr bool
|
||||
|
@@ -1,5 +1,6 @@
|
||||
package errno
|
||||
|
||||
//nolint: golint
|
||||
var (
|
||||
// Common errors
|
||||
OK = &Errno{Code: 0, Message: "OK"}
|
||||
|
@@ -23,6 +23,7 @@ func (err *Err) Error() string {
|
||||
return fmt.Sprintf("Err - code: %d, message: %s, error: %s", err.Code, err.Message, err.Err)
|
||||
}
|
||||
|
||||
// DecodeErr 对错误进行解码,返回错误code和错误提示
|
||||
func DecodeErr(err error) (int, string) {
|
||||
if err == nil {
|
||||
return OK.Code, OK.Message
|
||||
|
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
// see: https://github.com/iiinsomnia/gochat/blob/master/utils/http.go
|
||||
|
||||
// 禁止直接调用resty,统一使用HttpClient
|
||||
// HTTPClient 禁止直接调用resty,统一使用HttpClient
|
||||
var HTTPClient = New("resty")
|
||||
|
||||
// New 实例化一个client
|
||||
|
@@ -2,7 +2,7 @@ package http
|
||||
|
||||
import "time"
|
||||
|
||||
// http client 接口
|
||||
// Client 定义 http client 接口
|
||||
type Client interface {
|
||||
Get(url string, params map[string]string, duration time.Duration) ([]byte, error)
|
||||
Post(url string, requestBody string, duration time.Duration) ([]byte, error)
|
||||
|
@@ -18,8 +18,9 @@ import (
|
||||
|
||||
var logger *zap.Logger
|
||||
|
||||
// InitLogger 初始化logger
|
||||
func InitLogger() *zap.Logger {
|
||||
encoder := getJsonEncoder()
|
||||
encoder := getJSONEncoder()
|
||||
|
||||
// 注意:如果多个文件,最后一个会是全的,前两个可能会丢日志
|
||||
infoFilename := viper.GetString("log.logger_file")
|
||||
@@ -58,7 +59,7 @@ func InitLogger() *zap.Logger {
|
||||
return logger
|
||||
}
|
||||
|
||||
func getJsonEncoder() zapcore.Encoder {
|
||||
func getJSONEncoder() zapcore.Encoder {
|
||||
encoderConfig := zapcore.EncoderConfig{
|
||||
MessageKey: "msg",
|
||||
LevelKey: "level",
|
||||
@@ -94,26 +95,32 @@ func getLogWriterWithTime(filename string) io.Writer {
|
||||
return hook
|
||||
}
|
||||
|
||||
// Debug log
|
||||
func Debug(msg string, args ...zap.Field) {
|
||||
logger.Debug(msg, args...)
|
||||
}
|
||||
|
||||
// Info log
|
||||
func Info(msg string, args ...zap.Field) {
|
||||
logger.Info(msg, args...)
|
||||
}
|
||||
|
||||
// Warn log
|
||||
func Warn(msg string, args ...zap.Field) {
|
||||
logger.Warn(msg, args...)
|
||||
}
|
||||
|
||||
// Error log
|
||||
func Error(msg string, args ...zap.Field) {
|
||||
logger.Error(msg, args...)
|
||||
}
|
||||
|
||||
// Fatal log
|
||||
func Fatal(msg string, args ...zap.Field) {
|
||||
logger.Fatal(msg, args...)
|
||||
}
|
||||
|
||||
// Infof log
|
||||
func Infof(format string, args ...interface{}) {
|
||||
message := fmt.Sprintf(format, args...)
|
||||
logger.Info(message)
|
||||
|
@@ -10,10 +10,10 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// redis 客户端
|
||||
// Client redis 客户端
|
||||
var Client *redis.Client
|
||||
|
||||
// redis 返回为空
|
||||
// Nil redis 返回为空
|
||||
const Nil = redis.Nil
|
||||
|
||||
// Init 实例化一个redis client
|
||||
|
@@ -1,6 +1,8 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -9,22 +11,25 @@ import (
|
||||
tnet "github.com/toolkits/net"
|
||||
)
|
||||
|
||||
func GenShortId() (string, error) {
|
||||
// GenShortID 生成一个id
|
||||
func GenShortID() (string, error) {
|
||||
return shortid.Generate()
|
||||
}
|
||||
|
||||
// GenUUID 生成随机字符串
|
||||
func GenUUID() string {
|
||||
u, _ := uuid.NewRandom()
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// GetReqID 获取请求中的request_id
|
||||
func GetReqID(c *gin.Context) string {
|
||||
v, ok := c.Get("X-Request-Id")
|
||||
v, ok := c.Get("X-Request-ID")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
if requestId, ok := v.(string); ok {
|
||||
return requestId
|
||||
if requestID, ok := v.(string); ok {
|
||||
return requestID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -46,3 +51,14 @@ func GetLocalIP() string {
|
||||
})
|
||||
return clientIP
|
||||
}
|
||||
|
||||
// GetBytes interface 转 byte
|
||||
func GetBytes(key interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
@@ -4,32 +4,32 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenShortId(t *testing.T) {
|
||||
shortId, err := GenShortId()
|
||||
if shortId == "" || err != nil {
|
||||
t.Error("GenShortId failed!")
|
||||
func TestGenShortID(t *testing.T) {
|
||||
shortID, err := GenShortID()
|
||||
if shortID == "" || err != nil {
|
||||
t.Error("GenShortID failed!")
|
||||
}
|
||||
|
||||
t.Log("GenShortId test pass")
|
||||
t.Log("GenShortID test pass")
|
||||
}
|
||||
|
||||
func BenchmarkGenShortId(b *testing.B) {
|
||||
func BenchmarkGenShortID(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenShortId()
|
||||
GenShortID()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGenShortIdTimeConsuming(b *testing.B) {
|
||||
func BenchmarkGenShortIDTimeConsuming(b *testing.B) {
|
||||
b.StopTimer() //调用该函数停止压力测试的时间计数
|
||||
|
||||
shortId, err := GenShortId()
|
||||
if shortId == "" || err != nil {
|
||||
shortID, err := GenShortID()
|
||||
if shortID == "" || err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
|
||||
b.StartTimer() //重新开始时间
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GenShortId()
|
||||
GenShortID()
|
||||
}
|
||||
}
|
||||
|
@@ -21,6 +21,7 @@ func (info *Info) String() string {
|
||||
return info.GitTag
|
||||
}
|
||||
|
||||
// Get 返回详细的版本信息
|
||||
func Get() Info {
|
||||
return Info{
|
||||
GitTag: gitTag,
|
||||
|
@@ -7,34 +7,36 @@ import (
|
||||
"github.com/lexkong/log"
|
||||
)
|
||||
|
||||
// IUserRepo 定义用户仓库接口
|
||||
type IUserRepo interface {
|
||||
CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error)
|
||||
GetUserById(id uint64) (*model.UserModel, error)
|
||||
GetUserByID(id uint64) (*model.UserModel, error)
|
||||
GetUserByPhone(phone int) (*model.UserModel, error)
|
||||
GetUserByEmail(email string) (*model.UserModel, error)
|
||||
GetUsersByIds(ids []uint64) ([]*model.UserModel, error)
|
||||
Update(userMap map[string]interface{}, id uint64) error
|
||||
}
|
||||
|
||||
type UserRepo struct {
|
||||
}
|
||||
// userRepo 用户仓库
|
||||
type userRepo struct{}
|
||||
|
||||
// NewUserRepo 实例化用户仓库
|
||||
func NewUserRepo() IUserRepo {
|
||||
return &UserRepo{}
|
||||
return &userRepo{}
|
||||
}
|
||||
|
||||
// CreateUser 创建用户
|
||||
func (repo *UserRepo) CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error) {
|
||||
func (repo *userRepo) CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error) {
|
||||
err = db.Create(&user).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return user.Id, nil
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// GetUserByID 获取用户
|
||||
func (repo *UserRepo) GetUserById(id uint64) (*model.UserModel, error) {
|
||||
func (repo *userRepo) GetUserByID(id uint64) (*model.UserModel, error) {
|
||||
user := &model.UserModel{}
|
||||
result := model.GetDB().Where("id = ?", id).First(user)
|
||||
|
||||
@@ -42,7 +44,7 @@ func (repo *UserRepo) GetUserById(id uint64) (*model.UserModel, error) {
|
||||
}
|
||||
|
||||
// GetUserByPhone 根据手机号获取用户
|
||||
func (repo *UserRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
|
||||
func (repo *userRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
|
||||
user := model.UserModel{}
|
||||
result := model.GetDB().Where("phone = ?", phone).First(&user)
|
||||
|
||||
@@ -52,7 +54,7 @@ func (repo *UserRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
|
||||
}
|
||||
|
||||
// GetUserByEmail 根据邮箱获取手机号
|
||||
func (repo *UserRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
|
||||
func (repo *userRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
|
||||
user := model.UserModel{}
|
||||
result := model.GetDB().Where("email = ?", phone).First(&user)
|
||||
|
||||
@@ -62,7 +64,7 @@ func (repo *UserRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
|
||||
}
|
||||
|
||||
// GetUsersByIds 批量获取用户
|
||||
func (repo *UserRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
|
||||
func (repo *userRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
|
||||
users := make([]*model.UserModel, 0)
|
||||
result := model.GetDB().Where("id in (?)", ids).Find(&users)
|
||||
|
||||
@@ -70,8 +72,8 @@ func (repo *UserRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
|
||||
}
|
||||
|
||||
// Update 更新用户信息
|
||||
func (repo *UserRepo) Update(userMap map[string]interface{}, id uint64) error {
|
||||
user, err := repo.GetUserById(id)
|
||||
func (repo *userRepo) Update(userMap map[string]interface{}, id uint64) error {
|
||||
user, err := repo.GetUserByID(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -4,11 +4,12 @@ import (
|
||||
"github.com/1024casts/snake/handler"
|
||||
"github.com/1024casts/snake/pkg/errno"
|
||||
"github.com/1024casts/snake/pkg/token"
|
||||
"github.com/lexkong/log"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/lexkong/log"
|
||||
)
|
||||
|
||||
// AuthMiddleware 认证中间件
|
||||
func AuthMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Parse the json web token.
|
||||
|
@@ -24,7 +24,8 @@ package middleware
|
||||
// Tracer, Closer, Error = jaeger_trace.NewJaegerTracer(config.AppName, config.JaegerHostPort)
|
||||
// defer Closer.Close()
|
||||
//
|
||||
// spCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(c.Request.Header))
|
||||
// spCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders,
|
||||
// opentracing.HTTPHeadersCarrier(c.Request.Header))
|
||||
// if err != nil {
|
||||
// ParentSpan = Tracer.StartSpan(c.Request.URL.Path)
|
||||
// defer ParentSpan.Finish()
|
||||
|
@@ -2,24 +2,26 @@ package middleware
|
||||
|
||||
import (
|
||||
"github.com/1024casts/snake/pkg/util"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func RequestId() gin.HandlerFunc {
|
||||
// RequestID 透传Request-ID,如果没有则生成一个
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Check for incoming header, use it if exists
|
||||
requestId := c.Request.Header.Get("X-Request-Id")
|
||||
requestID := c.Request.Header.Get("X-Request-ID")
|
||||
|
||||
// Create request id with UUID4
|
||||
if requestId == "" {
|
||||
requestId = util.GenUUID()
|
||||
if requestID == "" {
|
||||
requestID = util.GenUUID()
|
||||
}
|
||||
|
||||
// Expose it for use in the application
|
||||
c.Set("X-Request-Id", requestId)
|
||||
c.Set("X-Request-ID", requestID)
|
||||
|
||||
// Set X-Request-Id header
|
||||
c.Writer.Header().Set("X-Request-Id", requestId)
|
||||
// Set X-Request-ID header
|
||||
c.Writer.Header().Set("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
@@ -3,15 +3,16 @@ package router
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
_ "github.com/1024casts/snake/docs" // docs is generated by Swag CLI, you have to import it.
|
||||
"github.com/1024casts/snake/handler/sd"
|
||||
"github.com/1024casts/snake/handler/user"
|
||||
"github.com/1024casts/snake/router/middleware"
|
||||
|
||||
"github.com/gin-contrib/pprof"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/swaggo/gin-swagger"
|
||||
|
||||
"github.com/swaggo/gin-swagger" //nolint: goimports
|
||||
"github.com/swaggo/gin-swagger/swaggerFiles"
|
||||
|
||||
// import swagger handler
|
||||
_ "github.com/1024casts/snake/docs" // docs is generated by Swag CLI, you have to import it.
|
||||
"github.com/1024casts/snake/handler/user"
|
||||
"github.com/1024casts/snake/router/middleware"
|
||||
)
|
||||
|
||||
// Load loads the middlewares, routes, handlers.
|
||||
@@ -48,14 +49,5 @@ func Load(g *gin.Engine, mw ...gin.HandlerFunc) *gin.Engine {
|
||||
u.PUT("/:id", user.Update)
|
||||
}
|
||||
|
||||
// The health check handlers
|
||||
svcd := g.Group("/sd")
|
||||
{
|
||||
svcd.GET("/health", sd.HealthCheck)
|
||||
svcd.GET("/disk", sd.DiskCheck)
|
||||
svcd.GET("/cpu", sd.CPUCheck)
|
||||
svcd.GET("/ram", sd.RAMCheck)
|
||||
}
|
||||
|
||||
return g
|
||||
}
|
||||
|
@@ -7,21 +7,26 @@ import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// 短信服务
|
||||
// ServiceSms 短信服务
|
||||
// 使用七牛云
|
||||
// 直接初始化,可以避免在使用时再实例化
|
||||
var ServiceSms = NewSmsService()
|
||||
|
||||
// 校验码服务,生成校验码和获得校验码
|
||||
type smsService struct {
|
||||
// ISmsService 短信服务接口定义
|
||||
type ISmsService interface {
|
||||
Send(phoneNumber string, verifyCode int) error
|
||||
_sendViaQiNiu(phoneNumber string, verifyCode int) error
|
||||
}
|
||||
|
||||
// smsService 校验码服务,生成校验码和获得校验码
|
||||
type smsService struct{}
|
||||
|
||||
// NewSmsService 实例化一个sms
|
||||
func NewSmsService() *smsService {
|
||||
func NewSmsService() ISmsService {
|
||||
return &smsService{}
|
||||
}
|
||||
|
||||
// 发送短信
|
||||
// Send 发送短信
|
||||
func (srv *smsService) Send(phoneNumber string, verifyCode int) error {
|
||||
// 校验参数的正确性
|
||||
if phoneNumber == "" || verifyCode == 0 {
|
||||
@@ -32,9 +37,8 @@ func (srv *smsService) Send(phoneNumber string, verifyCode int) error {
|
||||
return srv._sendViaQiNiu(phoneNumber, verifyCode)
|
||||
}
|
||||
|
||||
// 调用七牛短信服务
|
||||
// _sendViaQiNiu 调用七牛短信服务
|
||||
func (srv *smsService) _sendViaQiNiu(phoneNumber string, verifyCode int) error {
|
||||
|
||||
accessKey := viper.GetString("qiniu.access_key")
|
||||
secretKey := viper.GetString("qiniu.secret_key")
|
||||
|
||||
|
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// IUserService 用户服务接口定义
|
||||
type IUserService interface {
|
||||
CreateUser(user model.UserModel) (id uint64, err error)
|
||||
UpdateUser(userMap map[string]interface{}, id uint64) error
|
||||
@@ -17,7 +18,7 @@ type IUserService interface {
|
||||
GetUserByEmail(email string) (*model.UserModel, error)
|
||||
}
|
||||
|
||||
// 直接初始化,可以避免在使用时再实例化
|
||||
// UserService 直接初始化,可以避免在使用时再实例化
|
||||
var UserService = NewUserService()
|
||||
|
||||
type userService struct {
|
||||
@@ -51,7 +52,7 @@ func (srv *userService) UpdateUser(userMap map[string]interface{}, id uint64) er
|
||||
}
|
||||
|
||||
func (srv *userService) GetUserByID(id uint64) (*model.UserModel, error) {
|
||||
userModel, err := srv.userRepo.GetUserById(id)
|
||||
userModel, err := srv.userRepo.GetUserByID(id)
|
||||
if err != nil {
|
||||
return userModel, errors.Wrapf(err, "get user info err from db by id: %d", id)
|
||||
}
|
||||
@@ -69,7 +70,7 @@ func (srv *userService) GetUserListByIds(id []uint64) (map[uint64]*model.UserMod
|
||||
}
|
||||
|
||||
for _, v := range userModels {
|
||||
retMap[v.Id] = v
|
||||
retMap[v.ID] = v
|
||||
}
|
||||
|
||||
return retMap, nil
|
||||
|
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// 验证码服务,主要提供生成验证码和获取验证码
|
||||
// VCodeService 验证码服务,主要提供生成验证码和获取验证码
|
||||
// 直接初始化,可以避免在使用时再实例化
|
||||
var VCodeService = NewVCodeService()
|
||||
|
||||
@@ -21,15 +21,26 @@ const (
|
||||
maxDurationTime = 10 * time.Minute // 验证码有效期
|
||||
)
|
||||
|
||||
// 校验码服务,生成校验码和获得校验码
|
||||
type vcodeService struct {
|
||||
// IVerifyCodeService 校验码服务接口定义
|
||||
type IVerifyCodeService interface {
|
||||
// public func
|
||||
GenLoginVCode(phone string) (int, error)
|
||||
CheckLoginVCode(phone, vCode int) bool
|
||||
GetLoginVCode(phone int) (int, error)
|
||||
|
||||
// private func
|
||||
isTestPhone(phone int) bool
|
||||
}
|
||||
|
||||
func NewVCodeService() *vcodeService {
|
||||
// vcodeService 校验码服务,生成校验码和获得校验码
|
||||
type vcodeService struct{}
|
||||
|
||||
// NewVCodeService 实例化一个验证码服务
|
||||
func NewVCodeService() IVerifyCodeService {
|
||||
return &vcodeService{}
|
||||
}
|
||||
|
||||
// 生成校验码
|
||||
// GenLoginVCode 生成校验码
|
||||
func (srv *vcodeService) GenLoginVCode(phone string) (int, error) {
|
||||
// step1: 生成随机数
|
||||
vCodeStr := fmt.Sprintf("%06v", rand.New(rand.NewSource(time.Now().UnixNano())).Int31n(1000000))
|
||||
@@ -55,7 +66,7 @@ var phoneWhiteLit = []int{
|
||||
13010102020,
|
||||
}
|
||||
|
||||
// 这里可以添加测试号,直接通过
|
||||
// isTestPhone 这里可以添加测试号,直接通过
|
||||
func (srv *vcodeService) isTestPhone(phone int) bool {
|
||||
for _, val := range phoneWhiteLit {
|
||||
if val == phone {
|
||||
@@ -65,7 +76,7 @@ func (srv *vcodeService) isTestPhone(phone int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证校验码是否正确
|
||||
// CheckLoginVCode 验证校验码是否正确
|
||||
func (srv *vcodeService) CheckLoginVCode(phone, vCode int) bool {
|
||||
if srv.isTestPhone(phone) {
|
||||
return true
|
||||
@@ -84,7 +95,7 @@ func (srv *vcodeService) CheckLoginVCode(phone, vCode int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// 获得校验码
|
||||
// GetLoginVCode 获得校验码
|
||||
func (srv *vcodeService) GetLoginVCode(phone int) (int, error) {
|
||||
// 直接从redis里获取
|
||||
key := fmt.Sprintf(verifyCodeRedisKey, phone)
|
||||
|
Reference in New Issue
Block a user