lint: fix golangci-lint tips

This commit is contained in:
qloog
2020-04-10 14:23:19 +08:00
parent eb5eab12b8
commit f314c00dfd
36 changed files with 254 additions and 169 deletions

View File

@@ -27,7 +27,7 @@ func SendResponse(c *gin.Context, err error, data interface{}) {
})
}
// 返回用户id
// GetUserID 返回用户id
func GetUserID(c *gin.Context) uint64 {
if c == nil {
return 0

View File

@@ -1,7 +1,7 @@
package user
import (
. "github.com/1024casts/snake/handler"
"github.com/1024casts/snake/handler"
"github.com/gin-gonic/gin"
)
@@ -22,5 +22,5 @@ func Delete(c *gin.Context) {
// return
//}
SendResponse(c, nil, nil)
handler.SendResponse(c, nil, nil)
}

View File

@@ -3,7 +3,8 @@ package user
import (
"strconv"
. "github.com/1024casts/snake/handler"
"github.com/1024casts/snake/handler"
"github.com/1024casts/snake/pkg/errno"
"github.com/1024casts/snake/service/user"
@@ -25,7 +26,7 @@ func Get(c *gin.Context) {
userIDStr := c.Param("id")
if userIDStr == "" {
SendResponse(c, errno.ErrParam, nil)
handler.SendResponse(c, errno.ErrParam, nil)
return
}
userID, _ := strconv.Atoi(userIDStr)
@@ -33,9 +34,9 @@ func Get(c *gin.Context) {
// Get the user by the `user_id` from the database.
u, err := user.UserService.GetUserByID(uint64(userID))
if err != nil {
SendResponse(c, errno.ErrUserNotFound, nil)
handler.SendResponse(c, errno.ErrUserNotFound, nil)
return
}
SendResponse(c, nil, u)
handler.SendResponse(c, nil, u)
}

View File

@@ -23,7 +23,7 @@ func TestGet(t *testing.T) {
userTests := []model.UserModel{
{
BaseModel: model.BaseModel{
Id: 12,
ID: 12,
},
Username: "user001",
Password: "123456",
@@ -33,7 +33,7 @@ func TestGet(t *testing.T) {
},
{
BaseModel: model.BaseModel{
Id: 13,
ID: 13,
},
Username: "user002",
Password: "123456",
@@ -48,7 +48,7 @@ func TestGet(t *testing.T) {
app.DB.Create(user)
// Set up a new request.
req, err := http.NewRequest("GET", fmt.Sprintf("/v1/users/%d", user.Id), nil)
req, err := http.NewRequest("GET", fmt.Sprintf("/v1/users/%d", user.ID), nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -3,7 +3,8 @@ package user
import (
"strconv"
. "github.com/1024casts/snake/handler"
"github.com/1024casts/snake/handler"
"github.com/1024casts/snake/model"
"github.com/1024casts/snake/pkg/errno"
"github.com/1024casts/snake/pkg/token"
@@ -28,20 +29,20 @@ func Login(c *gin.Context) {
// Binding the data with the u struct.
var req PhoneLoginCredentials
if err := c.Bind(&req); err != nil {
SendResponse(c, errno.ErrBind, nil)
handler.SendResponse(c, errno.ErrBind, nil)
return
}
log.Infof("req %#v", req)
// check param
if req.Phone == 0 || req.VerifyCode == 0 {
SendResponse(c, errno.ErrParam, nil)
handler.SendResponse(c, errno.ErrParam, nil)
return
}
// 验证校验码
if !vcode.VCodeService.CheckLoginVCode(req.Phone, req.VerifyCode) {
SendResponse(c, errno.ErrVerifyCode, nil)
handler.SendResponse(c, errno.ErrVerifyCode, nil)
return
}
@@ -52,28 +53,28 @@ func Login(c *gin.Context) {
}
// 否则新建用户信息, 并取得用户信息
if u.Id == 0 {
if u.ID == 0 {
u := model.UserModel{
Phone: req.Phone,
Username: strconv.Itoa(req.Phone),
}
u.Id, err = user.UserService.CreateUser(u)
u.ID, err = user.UserService.CreateUser(u)
if err != nil {
log.Warnf("[login] create u err, %v", err)
SendResponse(c, errno.InternalServerError, nil)
handler.SendResponse(c, errno.InternalServerError, nil)
return
}
}
// 签发签名 Sign the json web token.
t, err := token.Sign(c, token.Context{UserID: u.Id, Username: u.Username}, "")
t, err := token.Sign(c, token.Context{UserID: u.ID, Username: u.Username}, "")
if err != nil {
log.Warnf("[login] gen token sign err:, %v", err)
SendResponse(c, errno.ErrToken, nil)
handler.SendResponse(c, errno.ErrToken, nil)
return
}
SendResponse(c, nil, model.Token{
handler.SendResponse(c, nil, model.Token{
Token: t,
})
}

View File

@@ -3,7 +3,8 @@ package user
import (
"strconv"
. "github.com/1024casts/snake/handler"
"github.com/1024casts/snake/handler"
"github.com/1024casts/snake/pkg/errno"
"github.com/1024casts/snake/service/user"
@@ -29,7 +30,7 @@ func Update(c *gin.Context) {
// Binding the user data.
var req UpdateRequest
if err := c.Bind(&req); err != nil {
SendResponse(c, errno.ErrBind, nil)
handler.SendResponse(c, errno.ErrBind, nil)
return
}
@@ -39,9 +40,9 @@ func Update(c *gin.Context) {
err := user.UserService.UpdateUser(userMap, uint64(userID))
if err != nil {
log.Warnf("[user] update user err, %v", err)
SendResponse(c, errno.InternalServerError, nil)
handler.SendResponse(c, errno.InternalServerError, nil)
return
}
SendResponse(c, nil, userID)
handler.SendResponse(c, nil, userID)
}

View File

@@ -1,10 +1,11 @@
package user
import (
. "github.com/1024casts/snake/handler"
"github.com/1024casts/snake/handler"
"github.com/1024casts/snake/pkg/errno"
"github.com/1024casts/snake/service/sms"
"github.com/1024casts/snake/service/vcode"
"github.com/gin-gonic/gin"
"github.com/lexkong/log"
"github.com/pkg/errors"
@@ -25,13 +26,13 @@ func VCode(c *gin.Context) {
// 验证区号和手机号是否为空
if c.Query("area_code") == "" {
SendResponse(c, errno.ErrAreaCodeEmpty, nil)
handler.SendResponse(c, errno.ErrAreaCodeEmpty, nil)
return
}
phone := c.Query("phone")
if phone == "" {
SendResponse(c, errno.ErrPhoneEmpty, nil)
handler.SendResponse(c, errno.ErrPhoneEmpty, nil)
return
}
@@ -41,7 +42,7 @@ func VCode(c *gin.Context) {
verifyCode, err := vcode.VCodeService.GenLoginVCode(phone)
if err != nil {
log.Warnf("gen login verify code err, %v", errors.WithStack(err))
SendResponse(c, errno.ErrGenVCode, nil)
handler.SendResponse(c, errno.ErrGenVCode, nil)
return
}
@@ -49,9 +50,9 @@ func VCode(c *gin.Context) {
err = sms.ServiceSms.Send(phone, verifyCode)
if err != nil {
log.Warnf("send phone sms err, %v", errors.WithStack(err))
SendResponse(c, errno.ErrSendSMS, nil)
handler.SendResponse(c, errno.ErrSendSMS, nil)
return
}
SendResponse(c, nil, nil)
handler.SendResponse(c, nil, nil)
}

View File

@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"net/http"
// http pprof
_ "net/http/pprof"
"os"
"os/signal"
@@ -80,7 +82,7 @@ func main() {
// Middlwares.
middleware.Logging(),
middleware.RequestId(),
middleware.RequestID(),
)
// Ping the server to make sure the router is working.
@@ -133,6 +135,7 @@ func main() {
select {
case <-ctx.Done():
log.Info("timeout of 5 seconds.")
default:
}
log.Info("Server exiting")
}

View File

@@ -9,16 +9,20 @@ import (
// MySQL driver.
"github.com/jinzhu/gorm"
// GORM MySQL
_ "github.com/jinzhu/gorm/dialects/mysql"
)
// Database 定义现有的数据库
type Database struct {
Self *gorm.DB
Docker *gorm.DB
}
// DB 数据库全局变量
var DB *Database
// openDB 链接数据库,生成数据库实例
func openDB(username, password, addr, name string) *gorm.DB {
config := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=%t&loc=%s",
username,
@@ -42,6 +46,7 @@ func openDB(username, password, addr, name string) *gorm.DB {
return db
}
// setupDB 配置数据库
func setupDB(db *gorm.DB) {
db.LogMode(viper.GetBool("gorm.show_log"))
// 用于设置最大打开的连接数默认值为0表示不限制.设置最大的连接数可以避免并发太高导致连接mysql出现too many connections的错误。
@@ -51,7 +56,7 @@ func setupDB(db *gorm.DB) {
db.DB().SetConnMaxLifetime(time.Minute * viper.GetDuration("grom.conn_max_lift_time"))
}
// used for cli
// InitSelfDB used for cli
func InitSelfDB() *gorm.DB {
return openDB(viper.GetString("db.username"),
viper.GetString("db.password"),
@@ -59,10 +64,12 @@ func InitSelfDB() *gorm.DB {
viper.GetString("db.name"))
}
// GetSelfDB 获取self数据库示例
func GetSelfDB() *gorm.DB {
return InitSelfDB()
}
// InitDockerDB 初始化一个docker数据库
func InitDockerDB() *gorm.DB {
return openDB(viper.GetString("docker_db.username"),
viper.GetString("docker_db.password"),
@@ -70,10 +77,12 @@ func InitDockerDB() *gorm.DB {
viper.GetString("docker_db.name"))
}
// GetDockerDB 获取docker数据库
func GetDockerDB() *gorm.DB {
return InitDockerDB()
}
// Init 初始化数据库
func (db *Database) Init() {
DB = &Database{
Self: GetSelfDB(),
@@ -81,10 +90,12 @@ func (db *Database) Init() {
}
}
// GetDB 返回默认的数据库
func GetDB() *gorm.DB {
return DB.Self
}
// Close 关闭数据库链接
func (db *Database) Close() {
err := DB.Self.Close()
if err != nil {

View File

@@ -6,13 +6,15 @@ import (
"time"
)
// BaseModel 公共model
type BaseModel struct {
Id uint64 `gorm:"primary_key;AUTO_INCREMENT;column:id" json:"-"`
ID uint64 `gorm:"primary_key;AUTO_INCREMENT;column:id" json:"-"`
CreatedAt time.Time `gorm:"column:created_at" json:"-"`
UpdatedAt time.Time `gorm:"column:updated_at" json:"-"`
DeletedAt *time.Time `gorm:"column:deleted_at" sql:"index" json:"-"`
}
// NullType 空字节类型
type NullType byte
const (
@@ -23,7 +25,7 @@ const (
IsNotNull
)
// sql build where
// WhereBuild sql build where
// see: https://github.com/jinzhu/gorm/issues/2055
func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface{}, err error) {
for k, v := range where {
@@ -52,7 +54,6 @@ func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface
whereSQL += fmt.Sprint(k, "=?")
vals = append(vals, v)
}
break
case 2:
k = ks[0]
switch ks[1] {
@@ -84,7 +85,6 @@ func WhereBuild(where map[string]interface{}) (whereSQL string, vals []interface
whereSQL += fmt.Sprint(k, " like ?")
vals = append(vals, v)
}
break
}
}
return

View File

@@ -9,7 +9,7 @@ import (
validator "github.com/go-playground/validator/v10"
)
// User represents a registered user.
// UserModel User represents a registered user.
type UserModel struct {
BaseModel
Username string `json:"username" gorm:"column:username;not null" binding:"required" validate:"min=1,max=32"`
@@ -27,21 +27,24 @@ func (u *UserModel) Validate() error {
return validate.Struct(u)
}
// UserInfo 对外暴露的结构体
type UserInfo struct {
Id uint64 `json:"id" example:"1"`
ID uint64 `json:"id" example:"1"`
Username string `json:"username" example:"张三"`
Password string `json:"password" example:"9dXd13#k$1123!kln"`
CreatedAt string `json:"createdAt" example:"2020-03-23 20:00:00"`
UpdatedAt string `json:"updatedAt" example:"2020-03-23 20:00:00"`
}
func (c *UserModel) TableName() string {
// TableName 表名
func (u *UserModel) TableName() string {
return "tb_users"
}
// UserList 用户列表结构体
type UserList struct {
Lock *sync.Mutex
IdMap map[uint64]*UserInfo
IDMap map[uint64]*UserInfo
}
// Token represents a JSON web token.

16
pkg/cache/driver.go vendored
View File

@@ -8,17 +8,18 @@ import (
redis2 "github.com/1024casts/snake/pkg/redis"
)
// keyPrefix 一般为业务前缀
var Cache Driver = NewMemoryCache("snake:", JsonEncoding{})
// Cache 生成一个缓存客户端,其中keyPrefix 一般为业务前缀
var Cache Driver = NewMemoryCache("snake:", JSONEncoding{})
// 初始化缓存在main.go里调用
// Init 初始化缓存在main.go里调用
// 默认是redis这里也可以改为其他缓存
func Init() {
if gin.Mode() == gin.ReleaseMode {
Cache = NewRedisCache(redis2.Client, "snake:", JsonEncoding{})
Cache = NewRedisCache(redis2.Client, "snake:", JSONEncoding{})
}
}
// Driver 定义cache驱动接口
type Driver interface {
Set(key string, val interface{}, expiration time.Duration) error
Get(key string) (interface{}, error)
@@ -29,30 +30,37 @@ type Driver interface {
Decr(key string, step int64) (int64, error)
}
// Set 数据
func Set(key string, val interface{}, expiration time.Duration) error {
return Cache.Set(key, val, expiration)
}
// Get 数据
func Get(key string) (interface{}, error) {
return Cache.Get(key)
}
// MultiSet 批量set
func MultiSet(valMap map[string]interface{}, expiration time.Duration) error {
return Cache.MultiSet(valMap, expiration)
}
// MultiGet 批量获取
func MultiGet(keys ...string) (interface{}, error) {
return Cache.MultiGet(keys...)
}
// Del 批量删除
func Del(keys ...string) error {
return Cache.Del(keys...)
}
// Incr 自增
func Incr(key string, step int64) (int64, error) {
return Cache.Incr(key, step)
}
// Decr 自减
func Decr(key string, step int64) (int64, error) {
return Cache.Decr(key, step)
}

69
pkg/cache/encoding.go vendored
View File

@@ -16,53 +16,59 @@ import (
"github.com/vmihailenco/msgpack"
)
// Encoding 编码接口定义
type Encoding interface {
Marshal(v interface{}) ([]byte, error)
Unmarshal(data []byte, v interface{}) error
}
// Marshal encode data
func Marshal(e Encoding, v interface{}) (data []byte, err error) {
bm, ok := v.(encoding.BinaryMarshaler)
if ok && e == nil {
data, err = bm.MarshalBinary()
return
} else {
data, err = e.Marshal(v)
if err == nil {
return
}
if ok {
data, err = bm.MarshalBinary()
}
}
data, err = e.Marshal(v)
if err == nil {
return
}
if ok {
data, err = bm.MarshalBinary()
}
return
}
// Unmarshal decode data
func Unmarshal(e Encoding, data []byte, v interface{}) (err error) {
bm, ok := v.(encoding.BinaryUnmarshaler)
if ok && e == nil {
err = bm.UnmarshalBinary(data)
return err
} else {
err = e.Unmarshal(data, v)
if err == nil {
return
}
if ok {
return bm.UnmarshalBinary(data)
}
}
err = e.Unmarshal(data, v)
if err == nil {
return
}
if ok {
return bm.UnmarshalBinary(data)
}
return
}
type JsonEncoding struct{}
// JSONEncoding json格式
type JSONEncoding struct{}
func (this JsonEncoding) Marshal(v interface{}) ([]byte, error) {
// Marshal json encode
func (j JSONEncoding) Marshal(v interface{}) ([]byte, error) {
buf, err := json.Marshal(v)
return buf, err
}
func (this JsonEncoding) Unmarshal(data []byte, value interface{}) error {
// Unmarshal json decode
func (j JSONEncoding) Unmarshal(data []byte, value interface{}) error {
err := json.Unmarshal(data, value)
if err != nil {
return err
@@ -70,9 +76,11 @@ func (this JsonEncoding) Unmarshal(data []byte, value interface{}) error {
return nil
}
// GobEncoding gob encode
type GobEncoding struct{}
func (this GobEncoding) Marshal(v interface{}) ([]byte, error) {
// Marshal gob encode
func (g GobEncoding) Marshal(v interface{}) ([]byte, error) {
var (
buffer bytes.Buffer
)
@@ -81,7 +89,8 @@ func (this GobEncoding) Marshal(v interface{}) ([]byte, error) {
return buffer.Bytes(), err
}
func (this GobEncoding) Unmarshal(data []byte, value interface{}) error {
// Unmarshal gob encode
func (g GobEncoding) Unmarshal(data []byte, value interface{}) error {
err := gob.NewDecoder(bytes.NewReader(data)).Decode(value)
if err != nil {
return err
@@ -89,9 +98,11 @@ func (this GobEncoding) Unmarshal(data []byte, value interface{}) error {
return nil
}
type JsonGzipEncoding struct{}
// JSONGzipEncoding json and gzip
type JSONGzipEncoding struct{}
func (this JsonGzipEncoding) Marshal(v interface{}) ([]byte, error) {
// Marshal json encode and gzip
func (jz JSONGzipEncoding) Marshal(v interface{}) ([]byte, error) {
buf, err := json.Marshal(v)
if err != nil {
return nil, err
@@ -103,7 +114,8 @@ func (this JsonGzipEncoding) Marshal(v interface{}) ([]byte, error) {
return buf, err
}
func (this JsonGzipEncoding) Unmarshal(data []byte, value interface{}) error {
// Unmarshal json encode and gzip
func (jz JSONGzipEncoding) Unmarshal(data []byte, value interface{}) error {
jsonData, err := GzipDecode(data)
if err != nil {
return err
@@ -116,6 +128,7 @@ func (this JsonGzipEncoding) Unmarshal(data []byte, value interface{}) error {
return nil
}
// GzipEncode 编码
func GzipEncode(in []byte) ([]byte, error) {
var (
buffer bytes.Buffer
@@ -143,6 +156,7 @@ func GzipEncode(in []byte) ([]byte, error) {
return buffer.Bytes(), nil
}
// GzipDecode 解码
func GzipDecode(in []byte) ([]byte, error) {
reader, err := gzip.NewReader(bytes.NewReader(in))
if err != nil {
@@ -182,14 +196,17 @@ func (s JSONSnappyEncoding) Unmarshal(data []byte, value interface{}) error {
return json.Unmarshal(b, value)
}
// MsgPackEncoding msgpack 格式
type MsgPackEncoding struct{}
func (this MsgPackEncoding) Marshal(v interface{}) ([]byte, error) {
// Marshal msgpack encode
func (mp MsgPackEncoding) Marshal(v interface{}) ([]byte, error) {
buf, err := msgpack.Marshal(v)
return buf, err
}
func (this MsgPackEncoding) Unmarshal(data []byte, value interface{}) error {
// Unmarshal msgpack decode
func (mp MsgPackEncoding) Unmarshal(data []byte, value interface{}) error {
err := msgpack.Unmarshal(data, value)
if err != nil {
return err

View File

@@ -7,7 +7,7 @@ func BenchmarkJsonMarshal(b *testing.B) {
for i := 0; i < 400; i++ {
a = append(a, i)
}
jsonEncoding := JsonEncoding{}
jsonEncoding := JSONEncoding{}
for n := 0; n < b.N; n++ {
_, err := jsonEncoding.Marshal(a)
if err != nil {
@@ -21,7 +21,7 @@ func BenchmarkJsonUnmarshal(b *testing.B) {
for i := 0; i < 400; i++ {
a = append(a, i)
}
jsonEncoding := JsonEncoding{}
jsonEncoding := JSONEncoding{}
data, err := jsonEncoding.Marshal(a)
if err != nil {
b.Error(err)

1
pkg/cache/key.go vendored
View File

@@ -5,6 +5,7 @@ import (
"strings"
)
// BuildCacheKey 构建一个带有前缀的缓存key
func BuildCacheKey(keyPrefix string, key string) (cacheKey string, err error) {
if key == "" {
return "", errors.New("[cache] key should not be empty")

24
pkg/cache/memory.go vendored
View File

@@ -1,8 +1,6 @@
package cache
import (
"bytes"
"encoding/gob"
"sync"
"time"
@@ -17,7 +15,8 @@ type memoryCache struct {
encoding Encoding
}
func NewMemoryCache(keyPrefix string, encoding Encoding) *memoryCache {
// NewMemoryCache 实例化一个内存cache
func NewMemoryCache(keyPrefix string, encoding Encoding) Driver {
return &memoryCache{
client: &sync.Map{},
KeyPrefix: keyPrefix,
@@ -31,6 +30,7 @@ type itemWithTTL struct {
value interface{}
}
// newItem 返回带有效期的value
func newItem(value interface{}, expires time.Duration) itemWithTTL {
expires64 := int64(expires)
if expires > 0 {
@@ -60,17 +60,7 @@ func getValue(item interface{}, ok bool) (interface{}, bool) {
return itemObj.value, true
}
// interface 转 byte
func GetBytes(key interface{}) ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(key)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// Set data
func (m memoryCache) Set(key string, val interface{}, expiration time.Duration) error {
cacheKey, err := BuildCacheKey(m.KeyPrefix, key)
if err != nil {
@@ -80,6 +70,7 @@ func (m memoryCache) Set(key string, val interface{}, expiration time.Duration)
return nil
}
// Get data
func (m memoryCache) Get(key string) (interface{}, error) {
cacheKey, err := BuildCacheKey(m.KeyPrefix, key)
if err != nil {
@@ -92,14 +83,17 @@ func (m memoryCache) Get(key string) (interface{}, error) {
return val, nil
}
// MultiSet 批量set
func (m memoryCache) MultiSet(valMap map[string]interface{}, expiration time.Duration) error {
panic("implement me")
}
// MultiGet 批量获取
func (m memoryCache) MultiGet(keys ...string) (interface{}, error) {
panic("implement me")
}
// Del 批量删除
func (m memoryCache) Del(keys ...string) error {
if len(keys) == 0 {
return nil
@@ -117,10 +111,12 @@ func (m memoryCache) Del(keys ...string) error {
return nil
}
// Incr 自增
func (m memoryCache) Incr(key string, step int64) (int64, error) {
panic("implement me")
}
// Decr 自减
func (m memoryCache) Decr(key string, step int64) (int64, error) {
panic("implement me")
}

View File

@@ -8,7 +8,7 @@ import (
func Test_memoryCache_SetGet(t *testing.T) {
// 实例化memory cache
cache := NewMemoryCache("memory-unit-test", JsonEncoding{})
cache := NewMemoryCache("memory-unit-test", JSONEncoding{})
// test set
type setArgs struct {
@@ -19,7 +19,7 @@ func Test_memoryCache_SetGet(t *testing.T) {
setTests := []struct {
name string
cache *memoryCache
cache Driver
args setArgs
wantErr bool
}{
@@ -47,7 +47,7 @@ func Test_memoryCache_SetGet(t *testing.T) {
tests := []struct {
name string
cache *memoryCache
cache Driver
args args
wantVal interface{}
wantErr bool

6
pkg/cache/redis.go vendored
View File

@@ -10,6 +10,7 @@ import (
"github.com/pkg/errors"
)
// redisCache redis cache结构体
type redisCache struct {
client *redis.Client
KeyPrefix string
@@ -19,11 +20,12 @@ type redisCache struct {
}
const (
// DefaultExpireTime 默认过期时间
DefaultExpireTime = 60 * time.Second
)
// client 参数是可传入的,这样方便进行单元测试
func NewRedisCache(client *redis.Client, keyPrefix string, encoding Encoding) *redisCache {
// NewRedisCache new一个redis cache, client 参数是可传入的,这样方便进行单元测试
func NewRedisCache(client *redis.Client, keyPrefix string, encoding Encoding) Driver {
return &redisCache{
client: client,
KeyPrefix: keyPrefix,

View File

@@ -15,7 +15,7 @@ func Test_redisCache_SetGet(t *testing.T) {
// 获取redis客户端
redisClient := redis2.Client
// 实例化redis cache
cache := NewRedisCache(redisClient, "unit-test", JsonEncoding{})
cache := NewRedisCache(redisClient, "unit-test", JSONEncoding{})
// test set
type setArgs struct {
@@ -26,7 +26,7 @@ func Test_redisCache_SetGet(t *testing.T) {
setTests := []struct {
name string
cache *redisCache
cache Driver
args setArgs
wantErr bool
}{
@@ -54,7 +54,7 @@ func Test_redisCache_SetGet(t *testing.T) {
tests := []struct {
name string
cache *redisCache
cache Driver
args args
wantVal interface{}
wantErr bool

View File

@@ -1,5 +1,6 @@
package errno
//nolint: golint
var (
// Common errors
OK = &Errno{Code: 0, Message: "OK"}

View File

@@ -23,6 +23,7 @@ func (err *Err) Error() string {
return fmt.Sprintf("Err - code: %d, message: %s, error: %s", err.Code, err.Message, err.Err)
}
// DecodeErr 对错误进行解码返回错误code和错误提示
func DecodeErr(err error) (int, string) {
if err == nil {
return OK.Code, OK.Message

View File

@@ -8,7 +8,7 @@ import (
// see: https://github.com/iiinsomnia/gochat/blob/master/utils/http.go
// 禁止直接调用resty统一使用HttpClient
// HTTPClient 禁止直接调用resty统一使用HttpClient
var HTTPClient = New("resty")
// New 实例化一个client

View File

@@ -2,7 +2,7 @@ package http
import "time"
// http client 接口
// Client 定义 http client 接口
type Client interface {
Get(url string, params map[string]string, duration time.Duration) ([]byte, error)
Post(url string, requestBody string, duration time.Duration) ([]byte, error)

View File

@@ -18,8 +18,9 @@ import (
var logger *zap.Logger
// InitLogger 初始化logger
func InitLogger() *zap.Logger {
encoder := getJsonEncoder()
encoder := getJSONEncoder()
// 注意:如果多个文件,最后一个会是全的,前两个可能会丢日志
infoFilename := viper.GetString("log.logger_file")
@@ -58,7 +59,7 @@ func InitLogger() *zap.Logger {
return logger
}
func getJsonEncoder() zapcore.Encoder {
func getJSONEncoder() zapcore.Encoder {
encoderConfig := zapcore.EncoderConfig{
MessageKey: "msg",
LevelKey: "level",
@@ -94,26 +95,32 @@ func getLogWriterWithTime(filename string) io.Writer {
return hook
}
// Debug log
func Debug(msg string, args ...zap.Field) {
logger.Debug(msg, args...)
}
// Info log
func Info(msg string, args ...zap.Field) {
logger.Info(msg, args...)
}
// Warn log
func Warn(msg string, args ...zap.Field) {
logger.Warn(msg, args...)
}
// Error log
func Error(msg string, args ...zap.Field) {
logger.Error(msg, args...)
}
// Fatal log
func Fatal(msg string, args ...zap.Field) {
logger.Fatal(msg, args...)
}
// Infof log
func Infof(format string, args ...interface{}) {
message := fmt.Sprintf(format, args...)
logger.Info(message)

View File

@@ -10,10 +10,10 @@ import (
"github.com/spf13/viper"
)
// redis 客户端
// Client redis 客户端
var Client *redis.Client
// redis 返回为空
// Nil redis 返回为空
const Nil = redis.Nil
// Init 实例化一个redis client

View File

@@ -1,6 +1,8 @@
package util
import (
"bytes"
"encoding/gob"
"sync"
"github.com/gin-gonic/gin"
@@ -9,22 +11,25 @@ import (
tnet "github.com/toolkits/net"
)
func GenShortId() (string, error) {
// GenShortID 生成一个id
func GenShortID() (string, error) {
return shortid.Generate()
}
// GenUUID 生成随机字符串
func GenUUID() string {
u, _ := uuid.NewRandom()
return u.String()
}
// GetReqID 获取请求中的request_id
func GetReqID(c *gin.Context) string {
v, ok := c.Get("X-Request-Id")
v, ok := c.Get("X-Request-ID")
if !ok {
return ""
}
if requestId, ok := v.(string); ok {
return requestId
if requestID, ok := v.(string); ok {
return requestID
}
return ""
}
@@ -46,3 +51,14 @@ func GetLocalIP() string {
})
return clientIP
}
// GetBytes interface 转 byte
func GetBytes(key interface{}) ([]byte, error) {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
err := enc.Encode(key)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

View File

@@ -4,32 +4,32 @@ import (
"testing"
)
func TestGenShortId(t *testing.T) {
shortId, err := GenShortId()
if shortId == "" || err != nil {
t.Error("GenShortId failed!")
func TestGenShortID(t *testing.T) {
shortID, err := GenShortID()
if shortID == "" || err != nil {
t.Error("GenShortID failed!")
}
t.Log("GenShortId test pass")
t.Log("GenShortID test pass")
}
func BenchmarkGenShortId(b *testing.B) {
func BenchmarkGenShortID(b *testing.B) {
for i := 0; i < b.N; i++ {
GenShortId()
GenShortID()
}
}
func BenchmarkGenShortIdTimeConsuming(b *testing.B) {
func BenchmarkGenShortIDTimeConsuming(b *testing.B) {
b.StopTimer() //调用该函数停止压力测试的时间计数
shortId, err := GenShortId()
if shortId == "" || err != nil {
shortID, err := GenShortID()
if shortID == "" || err != nil {
b.Error(err)
}
b.StartTimer() //重新开始时间
for i := 0; i < b.N; i++ {
GenShortId()
GenShortID()
}
}

View File

@@ -21,6 +21,7 @@ func (info *Info) String() string {
return info.GitTag
}
// Get 返回详细的版本信息
func Get() Info {
return Info{
GitTag: gitTag,

View File

@@ -7,34 +7,36 @@ import (
"github.com/lexkong/log"
)
// IUserRepo 定义用户仓库接口
type IUserRepo interface {
CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error)
GetUserById(id uint64) (*model.UserModel, error)
GetUserByID(id uint64) (*model.UserModel, error)
GetUserByPhone(phone int) (*model.UserModel, error)
GetUserByEmail(email string) (*model.UserModel, error)
GetUsersByIds(ids []uint64) ([]*model.UserModel, error)
Update(userMap map[string]interface{}, id uint64) error
}
type UserRepo struct {
}
// userRepo 用户仓库
type userRepo struct{}
// NewUserRepo 实例化用户仓库
func NewUserRepo() IUserRepo {
return &UserRepo{}
return &userRepo{}
}
// CreateUser 创建用户
func (repo *UserRepo) CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error) {
func (repo *userRepo) CreateUser(db *gorm.DB, user model.UserModel) (id uint64, err error) {
err = db.Create(&user).Error
if err != nil {
return 0, err
}
return user.Id, nil
return user.ID, nil
}
// GetUserByID 获取用户
func (repo *UserRepo) GetUserById(id uint64) (*model.UserModel, error) {
func (repo *userRepo) GetUserByID(id uint64) (*model.UserModel, error) {
user := &model.UserModel{}
result := model.GetDB().Where("id = ?", id).First(user)
@@ -42,7 +44,7 @@ func (repo *UserRepo) GetUserById(id uint64) (*model.UserModel, error) {
}
// GetUserByPhone 根据手机号获取用户
func (repo *UserRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
func (repo *userRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
user := model.UserModel{}
result := model.GetDB().Where("phone = ?", phone).First(&user)
@@ -52,7 +54,7 @@ func (repo *UserRepo) GetUserByPhone(phone int) (*model.UserModel, error) {
}
// GetUserByEmail 根据邮箱获取手机号
func (repo *UserRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
func (repo *userRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
user := model.UserModel{}
result := model.GetDB().Where("email = ?", phone).First(&user)
@@ -62,7 +64,7 @@ func (repo *UserRepo) GetUserByEmail(phone string) (*model.UserModel, error) {
}
// GetUsersByIds 批量获取用户
func (repo *UserRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
func (repo *userRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
users := make([]*model.UserModel, 0)
result := model.GetDB().Where("id in (?)", ids).Find(&users)
@@ -70,8 +72,8 @@ func (repo *UserRepo) GetUsersByIds(ids []uint64) ([]*model.UserModel, error) {
}
// Update 更新用户信息
func (repo *UserRepo) Update(userMap map[string]interface{}, id uint64) error {
user, err := repo.GetUserById(id)
func (repo *userRepo) Update(userMap map[string]interface{}, id uint64) error {
user, err := repo.GetUserByID(id)
if err != nil {
return err
}

View File

@@ -4,11 +4,12 @@ import (
"github.com/1024casts/snake/handler"
"github.com/1024casts/snake/pkg/errno"
"github.com/1024casts/snake/pkg/token"
"github.com/lexkong/log"
"github.com/gin-gonic/gin"
"github.com/lexkong/log"
)
// AuthMiddleware 认证中间件
func AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
// Parse the json web token.

View File

@@ -24,7 +24,8 @@ package middleware
// Tracer, Closer, Error = jaeger_trace.NewJaegerTracer(config.AppName, config.JaegerHostPort)
// defer Closer.Close()
//
// spCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(c.Request.Header))
// spCtx, err := opentracing.GlobalTracer().Extract(opentracing.HTTPHeaders,
// opentracing.HTTPHeadersCarrier(c.Request.Header))
// if err != nil {
// ParentSpan = Tracer.StartSpan(c.Request.URL.Path)
// defer ParentSpan.Finish()

View File

@@ -2,24 +2,26 @@ package middleware
import (
"github.com/1024casts/snake/pkg/util"
"github.com/gin-gonic/gin"
)
func RequestId() gin.HandlerFunc {
// RequestID 透传Request-ID如果没有则生成一个
func RequestID() gin.HandlerFunc {
return func(c *gin.Context) {
// Check for incoming header, use it if exists
requestId := c.Request.Header.Get("X-Request-Id")
requestID := c.Request.Header.Get("X-Request-ID")
// Create request id with UUID4
if requestId == "" {
requestId = util.GenUUID()
if requestID == "" {
requestID = util.GenUUID()
}
// Expose it for use in the application
c.Set("X-Request-Id", requestId)
c.Set("X-Request-ID", requestID)
// Set X-Request-Id header
c.Writer.Header().Set("X-Request-Id", requestId)
// Set X-Request-ID header
c.Writer.Header().Set("X-Request-ID", requestID)
c.Next()
}
}

View File

@@ -3,15 +3,16 @@ package router
import (
"net/http"
_ "github.com/1024casts/snake/docs" // docs is generated by Swag CLI, you have to import it.
"github.com/1024casts/snake/handler/sd"
"github.com/1024casts/snake/handler/user"
"github.com/1024casts/snake/router/middleware"
"github.com/gin-contrib/pprof"
"github.com/gin-gonic/gin"
"github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger" //nolint: goimports
"github.com/swaggo/gin-swagger/swaggerFiles"
// import swagger handler
_ "github.com/1024casts/snake/docs" // docs is generated by Swag CLI, you have to import it.
"github.com/1024casts/snake/handler/user"
"github.com/1024casts/snake/router/middleware"
)
// Load loads the middlewares, routes, handlers.
@@ -48,14 +49,5 @@ func Load(g *gin.Engine, mw ...gin.HandlerFunc) *gin.Engine {
u.PUT("/:id", user.Update)
}
// The health check handlers
svcd := g.Group("/sd")
{
svcd.GET("/health", sd.HealthCheck)
svcd.GET("/disk", sd.DiskCheck)
svcd.GET("/cpu", sd.CPUCheck)
svcd.GET("/ram", sd.RAMCheck)
}
return g
}

View File

@@ -7,21 +7,26 @@ import (
"github.com/spf13/viper"
)
// 短信服务
// ServiceSms 短信服务
// 使用七牛云
// 直接初始化,可以避免在使用时再实例化
var ServiceSms = NewSmsService()
// 校验码服务,生成校验码和获得校验码
type smsService struct {
// ISmsService 短信服务接口定义
type ISmsService interface {
Send(phoneNumber string, verifyCode int) error
_sendViaQiNiu(phoneNumber string, verifyCode int) error
}
// smsService 校验码服务,生成校验码和获得校验码
type smsService struct{}
// NewSmsService 实例化一个sms
func NewSmsService() *smsService {
func NewSmsService() ISmsService {
return &smsService{}
}
// 发送短信
// Send 发送短信
func (srv *smsService) Send(phoneNumber string, verifyCode int) error {
// 校验参数的正确性
if phoneNumber == "" || verifyCode == 0 {
@@ -32,9 +37,8 @@ func (srv *smsService) Send(phoneNumber string, verifyCode int) error {
return srv._sendViaQiNiu(phoneNumber, verifyCode)
}
// 调用七牛短信服务
// _sendViaQiNiu 调用七牛短信服务
func (srv *smsService) _sendViaQiNiu(phoneNumber string, verifyCode int) error {
accessKey := viper.GetString("qiniu.access_key")
secretKey := viper.GetString("qiniu.secret_key")

View File

@@ -8,6 +8,7 @@ import (
"github.com/pkg/errors"
)
// IUserService 用户服务接口定义
type IUserService interface {
CreateUser(user model.UserModel) (id uint64, err error)
UpdateUser(userMap map[string]interface{}, id uint64) error
@@ -17,7 +18,7 @@ type IUserService interface {
GetUserByEmail(email string) (*model.UserModel, error)
}
// 直接初始化,可以避免在使用时再实例化
// UserService 直接初始化,可以避免在使用时再实例化
var UserService = NewUserService()
type userService struct {
@@ -51,7 +52,7 @@ func (srv *userService) UpdateUser(userMap map[string]interface{}, id uint64) er
}
func (srv *userService) GetUserByID(id uint64) (*model.UserModel, error) {
userModel, err := srv.userRepo.GetUserById(id)
userModel, err := srv.userRepo.GetUserByID(id)
if err != nil {
return userModel, errors.Wrapf(err, "get user info err from db by id: %d", id)
}
@@ -69,7 +70,7 @@ func (srv *userService) GetUserListByIds(id []uint64) (map[uint64]*model.UserMod
}
for _, v := range userModels {
retMap[v.Id] = v
retMap[v.ID] = v
}
return retMap, nil

View File

@@ -12,7 +12,7 @@ import (
"github.com/pkg/errors"
)
// 验证码服务,主要提供生成验证码和获取验证码
// VCodeService 验证码服务,主要提供生成验证码和获取验证码
// 直接初始化,可以避免在使用时再实例化
var VCodeService = NewVCodeService()
@@ -21,15 +21,26 @@ const (
maxDurationTime = 10 * time.Minute // 验证码有效期
)
// 校验码服务,生成校验码和获得校验码
type vcodeService struct {
// IVerifyCodeService 校验码服务接口定义
type IVerifyCodeService interface {
// public func
GenLoginVCode(phone string) (int, error)
CheckLoginVCode(phone, vCode int) bool
GetLoginVCode(phone int) (int, error)
// private func
isTestPhone(phone int) bool
}
func NewVCodeService() *vcodeService {
// vcodeService 校验码服务,生成校验码和获得校验码
type vcodeService struct{}
// NewVCodeService 实例化一个验证码服务
func NewVCodeService() IVerifyCodeService {
return &vcodeService{}
}
// 生成校验码
// GenLoginVCode 生成校验码
func (srv *vcodeService) GenLoginVCode(phone string) (int, error) {
// step1: 生成随机数
vCodeStr := fmt.Sprintf("%06v", rand.New(rand.NewSource(time.Now().UnixNano())).Int31n(1000000))
@@ -55,7 +66,7 @@ var phoneWhiteLit = []int{
13010102020,
}
// 这里可以添加测试号,直接通过
// isTestPhone 这里可以添加测试号,直接通过
func (srv *vcodeService) isTestPhone(phone int) bool {
for _, val := range phoneWhiteLit {
if val == phone {
@@ -65,7 +76,7 @@ func (srv *vcodeService) isTestPhone(phone int) bool {
return false
}
// 验证校验码是否正确
// CheckLoginVCode 验证校验码是否正确
func (srv *vcodeService) CheckLoginVCode(phone, vCode int) bool {
if srv.isTestPhone(phone) {
return true
@@ -84,7 +95,7 @@ func (srv *vcodeService) CheckLoginVCode(phone, vCode int) bool {
return true
}
// 获得校验码
// GetLoginVCode 获得校验码
func (srv *vcodeService) GetLoginVCode(phone int) (int, error) {
// 直接从redis里获取
key := fmt.Sprintf(verifyCodeRedisKey, phone)