pyihe
2022-04-28 15:09:19 +08:00
parent 32327dff31
commit 6c88f85bed
6 changed files with 51 additions and 50 deletions


@@ -9,7 +9,7 @@ import (
 type ByteBuffer = bytebufferpool.ByteBuffer
 
 var (
-	Get = bytebufferpool.Get()
+	Get = bytebufferpool.Get
 	Put = func(b *ByteBuffer) {
 		if b != nil {
 			bytebufferpool.Put(b)
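Before this change Get was assigned the result of a single bytebufferpool.Get() call, so every caller shared one pooled buffer; now Get is the function value itself and each call borrows a fresh buffer. A minimal usage sketch, with the aliases redeclared locally since the wrapper's import path is not shown in this diff:

package main

import (
	"fmt"

	"github.com/valyala/bytebufferpool"
)

type ByteBuffer = bytebufferpool.ByteBuffer

var (
	Get = bytebufferpool.Get // function value, not a single pooled buffer
	Put = func(b *ByteBuffer) {
		if b != nil {
			bytebufferpool.Put(b)
		}
	}
)

func main() {
	buf := Get() // each call borrows its own buffer from the pool
	defer Put(buf)
	buf.WriteString("hello")
	fmt.Println(buf.String())
}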

go.mod

@@ -13,5 +13,4 @@ require (
 	github.com/vmihailenco/msgpack/v5 v5.3.4 // indirect
 	go.uber.org/zap v1.18.1 // indirect
 	golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e // indirect
-	gopkg.in/ini.v1 v1.66.4 // indirect
 )


@@ -10,8 +10,10 @@ import (
 )
 
 type Packet interface {
-	Packet(message Message) (data []byte, err error)
-	UnPacket(reader io.Reader, message Message) error
+	HeaderLen() int
+	MaxMessageLen() int
+	Packet(message []byte) (data []byte, err error)
+	UnPacket(reader io.Reader) ([]byte, error)
 }
 
 type Message interface {
@@ -25,44 +27,48 @@ type packet struct {
 }
 
 func NewPacket(headerLen, maxDataLen int) Packet {
-	if headerLen <= 0 {
+	if maths.MaxInt(0, headerLen) == 0 {
 		headerLen = 4
 	}
+	if maths.MaxInt(0, maxDataLen) == 0 {
+		maxDataLen = 2046
+	}
 	return &packet{
 		headerLen:  headerLen,
 		maxDataLen: maths.MaxInt(0, maxDataLen),
 	}
 }
 
+func (p *packet) HeaderLen() int {
+	if p != nil {
+		return p.headerLen
+	}
+	return -1
+}
+
+func (p *packet) MaxMessageLen() int {
+	if p != nil {
+		return p.maxDataLen
+	}
+	return -1
+}
+
 // Packet packs the message
-func (p *packet) Packet(message Message) (data []byte, err error) {
-	if message == nil {
-		err = errors.New("nil Message")
-		return
-	}
-	mBytes, err := message.Marshal()
-	if err != nil {
-		return
-	}
-	if p.maxDataLen > 0 && len(mBytes) > p.maxDataLen {
+func (p *packet) Packet(message []byte) (data []byte, err error) {
+	if p.maxDataLen > 0 && len(message) > p.maxDataLen {
 		err = errors.New("packet: message is too large")
 		return
 	}
-	data = make([]byte, p.headerLen+len(mBytes))
+	data = make([]byte, p.headerLen+len(message))
 	// the first headerLen bytes hold the data length
-	binary.LittleEndian.PutUint32(data[:4], uint32(len(mBytes)))
+	binary.LittleEndian.PutUint32(data[:4], uint32(len(message)))
 	// write the data into the remaining bytes
-	copy(data[4:], mBytes)
+	copy(data[4:], message)
 	return
 }
 
 // UnPacket unpacks the message
-func (p *packet) UnPacket(reader io.Reader, message Message) (err error) {
-	if message == nil {
-		err = errors.New("nil Message")
-		return
-	}
+func (p *packet) UnPacket(reader io.Reader) (b []byte, err error) {
 	// first read the data length from the header
 	header := make([]byte, p.headerLen)
 	n, err := io.ReadFull(reader, header)
@@ -84,8 +90,7 @@ func (p *packet) UnPacket(reader io.Reader, message Message) (err error) {
 	data := make([]byte, dataLen)
 	n, err = io.ReadFull(reader, data)
 	if err == nil {
-		// unmarshal the data into the corresponding struct
-		err = message.Unmarshal(data[:n])
+		b = data[:n]
 	}
 	return
 }
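With this change Packet frames a raw []byte and UnPacket returns the raw payload instead of unmarshalling into a Message. A self-contained round-trip sketch of the same framing (a 4-byte little-endian length header followed by the payload, matching the default headerLen); it mirrors the functions above rather than importing the package, whose path is not part of this diff:

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// pack mirrors packet.Packet: a 4-byte little-endian length header, then the payload.
func pack(message []byte) []byte {
	data := make([]byte, 4+len(message))
	binary.LittleEndian.PutUint32(data[:4], uint32(len(message)))
	copy(data[4:], message)
	return data
}

// unpack mirrors packet.UnPacket: read the header, then read exactly that many bytes.
func unpack(r io.Reader) ([]byte, error) {
	header := make([]byte, 4)
	if _, err := io.ReadFull(r, header); err != nil {
		return nil, err
	}
	data := make([]byte, binary.LittleEndian.Uint32(header))
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, err
	}
	return data, nil
}

func main() {
	var conn bytes.Buffer // stands in for a net.Conn
	conn.Write(pack([]byte("hello")))
	payload, err := unpack(&conn)
	fmt.Println(string(payload), err) // hello <nil>
}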


@@ -111,7 +111,7 @@ func NewPool(opts ...InitOptions) (RedisPool, error) {
 		defaultPool.db = 1
 	}
 	if defaultPool.net == "" {
-		defaultPool.net = "tcp"
+		defaultPool.net = "tcps"
 	}
 	defaultPool.p = &redis.Pool{
 		Dial: func() (conn redis.Conn, e error) {


@@ -1,7 +1,6 @@
 package snowflakes
 
 import (
-	"errors"
 	"strconv"
 	"sync"
 	"time"
@@ -36,38 +35,26 @@ type builder struct {
 	number    int64 // sequence number generated within the current millisecond (counted from 0); at most 4096 IDs per millisecond
 }
 
-// instantiate a worker node
-func NewWorker(opts ...Option) Worker {
+func NewWorker(workerId int64) Worker {
+	assertWorkId(workerId)
 	b := &builder{
 		epoch: time.Now().Unix() * 1000,
+		workerId: workerId,
 	}
-	for _, opt := range opts {
-		if err := opt(b); err != nil {
-			panic(err)
-		}
-	}
 	return b
 }
 
-func WithWorkerId(workerId int64) Option {
-	return func(b *builder) error {
-		if workerId < 0 || workerId > nodeMax {
-			return errors.New("work id cannot more than 1024")
-		}
-		b.workerId = workerId
-		return nil
-	}
-}
+func assertWorkId(workerId int64) {
+	if workerId < 0 || workerId > nodeMax {
+		panic("work id cannot more than 1024")
+	}
+}
 
-func (w *builder) GetInt64() int64 {
+func (w *builder) GetInt64() (id int64) {
 	w.mu.Lock()
-	defer w.mu.Unlock()
 	now := time.Now().UnixNano() / 1e6
 	if w.timestamp == now {
 		w.number++
 		if w.number > numberMax {
 			for now <= w.timestamp {
 				now = time.Now().UnixNano() / 1e6
@@ -77,8 +64,9 @@ func (w *builder) GetInt64() int64 {
 		w.number = 0
 		w.timestamp = now
 	}
-	return (now-w.epoch)<<timeShift | (w.workerId << workerShift) | (w.number)
+	id = (now-w.epoch)<<timeShift | (w.workerId << workerShift) | (w.number)
+	w.mu.Unlock()
+	return
 }
 
 func (w *builder) GetString() string {
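The limits referenced in this file (worker id no higher than 1024, at most 4096 IDs per millisecond) imply the usual snowflake layout: millisecond timestamp, 10-bit worker id, 12-bit sequence. A standalone sketch of the packing expression used in GetInt64, with timeShift and workerShift assumed to be 22 and 12 since their definitions are outside this diff:

package main

import (
	"fmt"
	"time"
)

// Assumed values: the constants themselves are not part of this diff.
const (
	workerShift = 12 // low 12 bits: per-millisecond sequence (up to 4096 values)
	timeShift   = 22 // next 10 bits: worker id (up to 1024 values)
)

func main() {
	epoch := time.Now().Unix() * 1000       // custom epoch in milliseconds, as in builder.epoch
	now := time.Now().UnixNano() / 1e6      // current millisecond
	workerId, number := int64(7), int64(42) // sample worker id and sequence number

	id := (now-epoch)<<timeShift | (workerId << workerShift) | number
	fmt.Printf("id=%d worker=%d seq=%d\n", id, (id>>workerShift)&1023, id&4095)
}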


@@ -6,6 +6,7 @@ import (
 
 type (
 	AtomicInt32 int32
+	AtomicInt64 int64
 )
 
 func (ai *AtomicInt32) Inc(delta int32) {
@@ -15,3 +16,11 @@ func (ai *AtomicInt32) Inc(delta int32) {
 func (ai *AtomicInt32) Value() int32 {
 	return atomic.LoadInt32((*int32)(ai))
 }
+
+func (ai *AtomicInt64) Inc(delta int64) {
+	atomic.AddInt64((*int64)(ai), delta)
+}
+
+func (ai *AtomicInt64) Value() int64 {
+	return atomic.LoadInt64((*int64)(ai))
+}
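AtomicInt64 follows the same pattern as the existing AtomicInt32 wrapper. A short concurrent-use sketch, with the type redeclared locally because the package path is not shown in this diff:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// AtomicInt64 redeclared for illustration, matching the methods added above.
type AtomicInt64 int64

func (ai *AtomicInt64) Inc(delta int64) { atomic.AddInt64((*int64)(ai), delta) }
func (ai *AtomicInt64) Value() int64    { return atomic.LoadInt64((*int64)(ai)) }

func main() {
	var counter AtomicInt64
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			counter.Inc(1) // safe to call from many goroutines
		}()
	}
	wg.Wait()
	fmt.Println(counter.Value()) // 100
}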