Mirror of https://github.com/XZB-1248/Spark, synced 2025-09-26 20:21:11 +08:00

Commit: optimize: update concurrent_map
.gitignore (vendored): 2 lines changed
@@ -3,6 +3,6 @@
 /tools
 /logs
 /.idea
-/Config.json
+/config.json
 dist/
 node_modules/
@@ -62,7 +62,7 @@ const imageQuality = 70
 
 var lock = &sync.Mutex{}
 var working = false
-var sessions = cmap.New()
+var sessions = cmap.New[*session]()
 var prevDesktop *image.RGBA
 var displayBounds image.Rectangle
 var errNoImage = errors.New(`DESKTOP.NO_IMAGE_YET`)
@@ -127,8 +127,7 @@ func worker() {
 }
 
 func sendImageDiff(diff []*[]byte) {
-    sessions.IterCb(func(uuid string, t any) bool {
-        desktop := t.(*session)
+    sessions.IterCb(func(uuid string, desktop *session) bool {
         desktop.lock.Lock()
         if !desktop.escape {
             if len(desktop.channel) >= frameBuffer {
@@ -146,9 +145,8 @@ func sendImageDiff(diff []*[]byte) {
 
 func quitAllDesktop(info string) {
     keys := make([]string, 0)
-    sessions.IterCb(func(uuid string, t any) bool {
+    sessions.IterCb(func(uuid string, desktop *session) bool {
         keys = append(keys, uuid)
-        desktop := t.(*session)
         desktop.escape = true
         desktop.channel <- message{t: 1, info: info}
         return true
@@ -346,24 +344,23 @@ func PingDesktop(pack modules.Packet) {
     } else {
         uuid = val.(string)
     }
-    if val, ok := sessions.Get(uuid); ok {
-        desktop = val.(*session)
-        desktop.lastPack = utils.Unix
+    desktop, ok := sessions.Get(uuid)
+    if !ok {
+        return
     }
+    desktop.lastPack = utils.Unix
 }
 
 func KillDesktop(pack modules.Packet) {
     var uuid string
-    var desktop *session
     if val, ok := pack.GetData(`desktop`, reflect.String); !ok {
         return
     } else {
         uuid = val.(string)
     }
-    if val, ok := sessions.Get(uuid); !ok {
+    desktop, ok := sessions.Get(uuid)
+    if !ok {
         return
-    } else {
-        desktop = val.(*session)
     }
     sessions.Remove(uuid)
     data, _ := utils.JSON.Marshal(modules.Packet{Act: `DESKTOP_QUIT`, Msg: `${i18n|DESKTOP.SESSION_CLOSED}`})
@@ -383,10 +380,9 @@ func GetDesktop(pack modules.Packet) {
     } else {
         uuid = val.(string)
     }
-    if val, ok := sessions.Get(uuid); !ok {
+    desktop, ok := sessions.Get(uuid)
+    if !ok {
         return
-    } else {
-        desktop = val.(*session)
     }
     if !desktop.escape {
         lock.Lock()
@@ -450,8 +446,7 @@ func healthCheck() {
         timestamp := now.Unix()
         // stores sessions to be disconnected
         keys := make([]string, 0)
-        sessions.IterCb(func(uuid string, t any) bool {
-            desktop := t.(*session)
+        sessions.IterCb(func(uuid string, desktop *session) bool {
             if timestamp-desktop.lastPack > MaxInterval {
                 keys = append(keys, uuid)
             }
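The desktop hunks above show the call-site pattern this commit repeats across the client and server: values come out of the map already typed, so the val.(*session) assertions and the any-typed IterCb callbacks disappear. A minimal usage sketch of that pattern, written against the vendored Spark/utils/cmap API as it appears later in this diff; the session struct below is a simplified stand-in, not the real struct from the client:

package main

import (
    "fmt"

    "Spark/utils/cmap"
)

// session is a simplified stand-in for the desktop/terminal session structs.
type session struct{ lastPack int64 }

func main() {
    // The value type is fixed when the map is created.
    sessions := cmap.New[*session]()
    sessions.Set(`uuid-1`, &session{lastPack: 100})

    // Get returns *session directly, with no val.(*session) assertion.
    if desktop, ok := sessions.Get(`uuid-1`); ok {
        desktop.lastPack = 200
    }

    // IterCb callbacks receive the typed value as well.
    sessions.IterCb(func(uuid string, desktop *session) bool {
        fmt.Println(uuid, desktop.lastPack)
        return true
    })
}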
@@ -1,11 +1,9 @@
 package terminal
 
 import (
-    "Spark/utils/cmap"
     "errors"
 )
 
-var terminals = cmap.New()
 var (
     errDataNotFound = errors.New(`no input found in packet`)
     errDataInvalid  = errors.New(`can not parse data in packet`)
@@ -6,6 +6,7 @@ import (
     "Spark/client/common"
     "Spark/modules"
     "Spark/utils"
+    "Spark/utils/cmap"
     "encoding/hex"
     "github.com/creack/pty"
     "os"
@@ -23,6 +24,7 @@ type terminal struct {
     cmd *exec.Cmd
 }
 
+var terminals = cmap.New[*terminal]()
 var defaultShell = ``
 
 func init() {
@@ -88,11 +90,8 @@ func InitTerminal(pack modules.Packet) error {
 }
 
 func InputRawTerminal(input []byte, uuid string) {
-    var session *terminal
-
-    if val, ok := terminals.Get(uuid); ok {
-        session = val.(*terminal)
-    } else {
+    session, ok := terminals.Get(uuid)
+    if !ok {
         return
     }
     session.pty.Write(input)
@@ -165,10 +164,9 @@ func KillTerminal(pack modules.Packet) {
     } else {
         uuid = val.(string)
     }
-    if val, ok := terminals.Get(uuid); !ok {
+    session, ok := terminals.Get(uuid)
+    if !ok {
         return
-    } else {
-        session = val.(*terminal)
     }
     terminals.Remove(uuid)
     data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`})
@@ -180,18 +178,16 @@ func KillTerminal(pack modules.Packet) {
 
 func PingTerminal(pack modules.Packet) {
     var termUUID string
-    var termSession *terminal
     if val, ok := pack.GetData(`terminal`, reflect.String); !ok {
         return
     } else {
         termUUID = val.(string)
     }
-    if val, ok := terminals.Get(termUUID); !ok {
+    session, ok := terminals.Get(termUUID)
+    if !ok {
         return
-    } else {
-        termSession = val.(*terminal)
-        termSession.lastPack = utils.Unix
     }
+    session.lastPack = utils.Unix
 }
 
 func doKillTerminal(terminal *terminal) {
@@ -234,11 +230,10 @@ func healthCheck() {
         timestamp := now.Unix()
         // stores sessions to be disconnected
         queue := make([]string, 0)
-        terminals.IterCb(func(uuid string, t any) bool {
-            termSession := t.(*terminal)
-            if timestamp-termSession.lastPack > MaxInterval {
+        terminals.IterCb(func(uuid string, session *terminal) bool {
+            if timestamp-session.lastPack > MaxInterval {
                 queue = append(queue, uuid)
-                doKillTerminal(termSession)
+                doKillTerminal(session)
             }
             return true
         })
@@ -4,6 +4,7 @@ import (
     "Spark/client/common"
     "Spark/modules"
     "Spark/utils"
+    "Spark/utils/cmap"
    "encoding/hex"
     "io"
     "os/exec"
@@ -23,6 +24,7 @@ type terminal struct {
     stdin *io.WriteCloser
 }
 
+var terminals = cmap.New[*terminal]()
 var defaultCmd = ``
 
 func init() {
@@ -108,11 +110,8 @@ func InitTerminal(pack modules.Packet) error {
 }
 
 func InputRawTerminal(input []byte, uuid string) {
-    var session *terminal
-
-    if val, ok := terminals.Get(uuid); ok {
-        session = val.(*terminal)
-    } else {
+    session, ok := terminals.Get(uuid)
+    if !ok {
        return
     }
     (*session.stdin).Write(input)
@@ -152,16 +151,14 @@ func ResizeTerminal(pack modules.Packet) error {
 
 func KillTerminal(pack modules.Packet) {
     var uuid string
-    var session *terminal
     if val, ok := pack.GetData(`terminal`, reflect.String); !ok {
         return
     } else {
         uuid = val.(string)
     }
-    if val, ok := terminals.Get(uuid); !ok {
+    session, ok := terminals.Get(uuid)
+    if !ok {
         return
-    } else {
-        session = val.(*terminal)
     }
     terminals.Remove(uuid)
     data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`})
@@ -179,12 +176,11 @@ func PingTerminal(pack modules.Packet) {
     } else {
         uuid = val.(string)
     }
-    if val, ok := terminals.Get(uuid); !ok {
+    session, ok := terminals.Get(uuid)
+    if !ok {
         return
-    } else {
-        session = val.(*terminal)
-        session.lastPack = utils.Unix
     }
+    session.lastPack = utils.Unix
 }
 
 func doKillTerminal(terminal *terminal) {
@@ -221,11 +217,10 @@ func healthCheck() {
         timestamp := now.Unix()
         // stores sessions to be disconnected
         keys := make([]string, 0)
-        terminals.IterCb(func(uuid string, t any) bool {
-            termSession := t.(*terminal)
-            if timestamp-termSession.lastPack > MaxInterval {
+        terminals.IterCb(func(uuid string, session *terminal) bool {
+            if timestamp-session.lastPack > MaxInterval {
                 keys = append(keys, uuid)
-                doKillTerminal(termSession)
+                doKillTerminal(session)
             }
             return true
         })
@@ -17,7 +17,7 @@ import (
 const MaxMessageSize = (2 << 15) + 1024
 
 var Melody = melody.New()
-var Devices = cmap.New()
+var Devices = cmap.New[*modules.Device]()
 
 func SendPackByUUID(pack modules.Packet, uuid string) bool {
     session, ok := Melody.GetSessionByUUID(uuid)
@@ -164,8 +164,7 @@ func CheckDevice(deviceID, connUUID string) (string, bool) {
         }
     } else {
         tempConnUUID := ``
-        Devices.IterCb(func(uuid string, v any) bool {
-            device := v.(*modules.Device)
+        Devices.IterCb(func(uuid string, device *modules.Device) bool {
             if device.ID == deviceID {
                 tempConnUUID = uuid
                 return false
@@ -15,7 +15,7 @@ type event struct {
     remove chan bool
 }
 
-var events = cmap.New()
+var events = cmap.New[*event]()
 
 // CallEvent tries to call the callback with the given uuid
 // after that, it will notify the caller via the channel
@@ -23,11 +23,10 @@ func CallEvent(pack modules.Packet, session *melody.Session) {
     if len(pack.Event) == 0 {
         return
     }
-    v, ok := events.Get(pack.Event)
+    ev, ok := events.Get(pack.Event)
     if !ok {
         return
     }
-    ev := v.(*event)
     if session != nil && session.UUID != ev.connection {
         return
     }
@@ -76,12 +75,11 @@ func AddEvent(fn EventCallback, connUUID, trigger string) {
 // RemoveEvent deletes the event with the given event trigger.
 // The ok will be returned to caller if the event is temp (only once).
 func RemoveEvent(trigger string, ok ...bool) {
-    v, found := events.Get(trigger)
+    ev, found := events.Get(trigger)
     if !found {
         return
     }
     events.Remove(trigger)
-    ev := v.(*event)
     if ev.remove != nil {
         if len(ok) > 0 {
             ev.remove <- ok[0]
@@ -89,7 +87,6 @@ func RemoveEvent(trigger string, ok ...bool) {
             ev.remove <- false
         }
     }
-    v = nil
     ev = nil
 }
 
@@ -1,7 +1,6 @@
 package common
 
 import (
-    "Spark/modules"
     "Spark/server/config"
     "Spark/utils"
     "Spark/utils/melody"
@@ -81,9 +80,8 @@ func getLog(ctx any, event, status, msg string, args map[string]any) string {
         }
     }
     if targetInfo {
-        val, ok := Devices.Get(connUUID)
+        device, ok := Devices.Get(connUUID)
         if ok {
-            device := val.(*modules.Device)
             args[`target`] = map[string]any{
                 `name`: device.Hostname,
                 `ip`:   device.WAN,
@@ -28,15 +28,14 @@ type Bridge struct {
     OnFinish func(bridge *Bridge)
 }
 
-var bridges = cmap.New()
+var bridges = cmap.New[*Bridge]()
 
 func init() {
     go func() {
         for now := range time.NewTicker(15 * time.Second).C {
             var queue []string
             timestamp := now.Unix()
-            bridges.IterCb(func(k string, v any) bool {
-                b := v.(*Bridge)
+            bridges.IterCb(func(k string, b *Bridge) bool {
                 if timestamp-b.creation > 60 && !b.using {
                     b.lock.Lock()
                     if b.Src != nil && b.Src.Request.Body != nil {
@@ -63,12 +62,12 @@ func CheckBridge(ctx *gin.Context) *Bridge {
         ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_PARAMETER}`})
         return nil
     }
-    val, ok := bridges.Get(form.Bridge)
+    b, ok := bridges.Get(form.Bridge)
     if !ok {
         ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_BRIDGE_ID}`})
         return nil
     }
-    return val.(*Bridge)
+    return b
 }
 
 func BridgePush(ctx *gin.Context) {
@@ -218,12 +217,11 @@ func AddBridgeWithDst(ext any, uuid string, Dst *gin.Context) *Bridge {
 }
 
 func RemoveBridge(uuid string) {
-    val, ok := bridges.Get(uuid)
+    b, ok := bridges.Get(uuid)
     if !ok {
         return
     }
     bridges.Remove(uuid)
-    b := val.(*Bridge)
     if b.Src != nil && b.Src.Request.Body != nil {
         b.Src.Request.Body.Close()
     }
@@ -70,8 +70,7 @@ func OnDevicePack(data []byte, session *melody.Session) error {
         // If so, then find the session and let client quit.
         // This will keep only one connection remained per device.
         exSession := ``
-        common.Devices.IterCb(func(uuid string, v any) bool {
-            device := v.(*modules.Device)
+        common.Devices.IterCb(func(uuid string, device *modules.Device) bool {
             if device.ID == pack.Device.ID {
                 exSession = uuid
                 target, ok := common.Melody.GetSessionByUUID(uuid)
@@ -94,14 +93,13 @@ func OnDevicePack(data []byte, session *melody.Session) error {
                 },
             })
         } else {
-            val, ok := common.Devices.Get(session.UUID)
+            device, ok := common.Devices.Get(session.UUID)
             if ok {
-                deviceInfo := val.(*modules.Device)
-                deviceInfo.CPU = pack.Device.CPU
-                deviceInfo.RAM = pack.Device.RAM
-                deviceInfo.Net = pack.Device.Net
-                deviceInfo.Disk = pack.Device.Disk
-                deviceInfo.Uptime = pack.Device.Uptime
+                device.CPU = pack.Device.CPU
+                device.RAM = pack.Device.RAM
+                device.Net = pack.Device.Net
+                device.Disk = pack.Device.Disk
+                device.Uptime = pack.Device.Uptime
             }
         }
         common.SendPack(modules.Packet{Code: 0}, session)
@@ -268,8 +266,7 @@ func ExecDeviceCmd(ctx *gin.Context) {
 // GetDevices will return all info about all clients.
 func GetDevices(ctx *gin.Context) {
     devices := map[string]any{}
-    common.Devices.IterCb(func(uuid string, v any) bool {
-        device := v.(*modules.Device)
+    common.Devices.IterCb(func(uuid string, device *modules.Device) bool {
         devices[uuid] = *device
         return true
     })
@@ -32,7 +32,7 @@ import (
     "github.com/gin-gonic/gin"
 )
 
-var blocked = cmap.New()
+var blocked = cmap.New[int64]()
 var lastRequest = time.Now().Unix()
 
 func main() {
@@ -212,14 +212,13 @@ func wsOnMessageBinary(session *melody.Session, data []byte) {
 }
 
 func wsOnDisconnect(session *melody.Session) {
-    if val, ok := common.Devices.Get(session.UUID); ok {
-        deviceInfo := val.(*modules.Device)
-        terminal.CloseSessionsByDevice(deviceInfo.ID)
-        desktop.CloseSessionsByDevice(deviceInfo.ID)
+    if device, ok := common.Devices.Get(session.UUID); ok {
+        terminal.CloseSessionsByDevice(device.ID)
+        desktop.CloseSessionsByDevice(device.ID)
         common.Info(nil, `CLIENT_OFFLINE`, ``, ``, map[string]any{
             `device`: map[string]any{
-                `name`: deviceInfo.Hostname,
-                `ip`:   deviceInfo.WAN,
+                `name`: device.Hostname,
+                `ip`:   device.WAN,
             },
         })
     } else {
@@ -289,10 +288,9 @@ func pingDevice(s *melody.Session) {
     trigger := utils.GetStrUUID()
     common.SendPack(modules.Packet{Act: `PING`, Event: trigger}, s)
     common.AddEventOnce(func(packet modules.Packet, session *melody.Session) {
-        val, ok := common.Devices.Get(s.UUID)
+        device, ok := common.Devices.Get(s.UUID)
         if ok {
-            deviceInfo := val.(*modules.Device)
-            deviceInfo.Latency = uint(time.Now().UnixMilli()-t) / 2
+            device.Latency = uint(time.Now().UnixMilli()-t) / 2
         }
     }, s.UUID, trigger, 3*time.Second)
 }
@@ -300,12 +298,12 @@ func pingDevice(s *melody.Session) {
 func checkAuth() gin.HandlerFunc {
     // Token as key and update timestamp as value.
     // Stores authenticated tokens.
-    tokens := cmap.New()
+    tokens := cmap.New[int64]()
     go func() {
         for now := range time.NewTicker(60 * time.Second).C {
             var queue []string
-            tokens.IterCb(func(key string, v any) bool {
-                if now.Unix()-v.(int64) > 1800 {
+            tokens.IterCb(func(key string, t int64) bool {
+                if now.Unix()-t > 1800 {
                     queue = append(queue, key)
                 }
                 return true
@@ -313,8 +311,8 @@ func checkAuth() gin.HandlerFunc {
             tokens.Remove(queue...)
             queue = nil
 
-            blocked.IterCb(func(addr string, v any) bool {
-                if now.Unix() > v.(int64) {
+            blocked.IterCb(func(addr string, t int64) bool {
+                if now.Unix() > t {
                     queue = append(queue, addr)
                 }
                 return true
@@ -347,7 +345,7 @@ func checkAuth() gin.HandlerFunc {
         if !passed {
             addr := common.GetRealIP(ctx)
             if expire, ok := blocked.Get(addr); ok {
-                if now < expire.(int64) {
+                if now < expire {
                     ctx.AbortWithStatusJSON(http.StatusTooManyRequests, modules.Packet{Code: 1})
                     return
                 }
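The server-side maps in these hunks hold plain int64 timestamps rather than pointers, which the generic map handles without any v.(int64) assertions. A small sketch against the same vendored Spark/utils/cmap API; the tokenTTL constant is illustrative, not taken from the commit:

package main

import (
    "time"

    "Spark/utils/cmap"
)

func main() {
    // Token -> last-used unix timestamp, no interface{} boxing or assertions.
    tokens := cmap.New[int64]()
    tokens.Set(`token-abc`, time.Now().Unix())

    // tokenTTL is an illustrative value, not one defined by the commit.
    const tokenTTL = 1800

    var expired []string
    tokens.IterCb(func(key string, t int64) bool {
        if time.Now().Unix()-t > tokenTTL {
            expired = append(expired, key)
        }
        return true
    })
    tokens.Remove(expired...)
}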
@@ -2,36 +2,62 @@ package cmap
 
 import (
     "encoding/json"
+    "fmt"
     "sync"
 )
 
-const SHARD_COUNT = 32
+var SHARD_COUNT = 32
 
-// ConcurrentMap is a "thread" safe map of type string:Anything.
+type Stringer interface {
+    fmt.Stringer
+    comparable
+}
+
+// A "thread" safe map of type string:Anything.
 // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
-type ConcurrentMap []*ConcurrentMapShared
+type ConcurrentMap[K comparable, V any] struct {
+    shards   []*ConcurrentMapShared[K, V]
+    sharding func(key K) uint32
+}
 
-// ConcurrentMapShared is a "thread" safe string to anything map.
-type ConcurrentMapShared struct {
-    items map[string]interface{}
+// A "thread" safe string to anything map.
+type ConcurrentMapShared[K comparable, V any] struct {
+    items        map[K]V
     sync.RWMutex // Read Write mutex, guards access to internal map.
 }
 
-// New creates a new concurrent map.
-func New() ConcurrentMap {
-    m := make(ConcurrentMap, SHARD_COUNT)
+func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
+    m := ConcurrentMap[K, V]{
+        sharding: sharding,
+        shards:   make([]*ConcurrentMapShared[K, V], SHARD_COUNT),
+    }
     for i := 0; i < SHARD_COUNT; i++ {
-        m[i] = &ConcurrentMapShared{items: make(map[string]interface{})}
+        m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)}
     }
     return m
 }
 
-// GetShard returns shard under given key
-func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared {
-    return m[uint(fnv32(key))%uint(SHARD_COUNT)]
+// Creates a new concurrent map.
+func New[V any]() ConcurrentMap[string, V] {
+    return create[string, V](fnv32)
 }
 
-func (m ConcurrentMap) MSet(data map[string]interface{}) {
+// Creates a new concurrent map.
+func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] {
+    return create[K, V](strfnv32[K])
+}
+
+// Creates a new concurrent map.
+func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
+    return create[K, V](sharding)
+}
+
+// GetShard returns shard under given key
+func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] {
+    return m.shards[uint(m.sharding(key))%uint(SHARD_COUNT)]
+}
+
+func (m ConcurrentMap[K, V]) MSet(data map[K]V) {
     for key, value := range data {
         shard := m.GetShard(key)
         shard.Lock()
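The hunk above introduces three constructors that differ only in how keys are sharded. A brief sketch of all three against the API defined here; the deviceKey type and the uint16 sharding function are invented purely for illustration:

package main

import (
    "hash/fnv"

    "Spark/utils/cmap"
)

// deviceKey is comparable and implements fmt.Stringer, so it satisfies cmap.Stringer.
type deviceKey struct{ id string }

func (k deviceKey) String() string { return k.id }

func main() {
    // String keys use the built-in fnv32 sharding.
    byUUID := cmap.New[int]()
    byUUID.Set(`uuid-1`, 1)

    // Stringer keys are sharded by hashing key.String().
    byDevice := cmap.NewStringer[deviceKey, int]()
    byDevice.Set(deviceKey{id: `dev-1`}, 2)

    // Arbitrary comparable keys need a caller-supplied sharding function.
    byPort := cmap.NewWithCustomShardingFunction[uint16, string](func(key uint16) uint32 {
        h := fnv.New32a()
        h.Write([]byte{byte(key >> 8), byte(key)})
        return h.Sum32()
    })
    byPort.Set(8000, `gin`)
}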
@@ -40,8 +66,8 @@ func (m ConcurrentMap) MSet(data map[string]interface{}) {
     }
 }
 
-// Set sets the given value under the specified key.
-func (m ConcurrentMap) Set(key string, value interface{}) {
+// Sets the given value under the specified key.
+func (m ConcurrentMap[K, V]) Set(key K, value V) {
     // Get map shard.
     shard := m.GetShard(key)
     shard.Lock()
@@ -49,14 +75,14 @@ func (m ConcurrentMap) Set(key string, value interface{}) {
     shard.Unlock()
 }
 
-// UpsertCb is callback to return new element to be inserted into the map
+// Callback to return new element to be inserted into the map
 // It is called while lock is held, therefore it MUST NOT
 // try to access other keys in same map, as it can lead to deadlock since
 // Go sync.RWLock is not reentrant
-type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{}
+type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V
 
-// Upsert means Insert or Update - updates existing element or inserts a new one using UpsertCb
-func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) {
+// Insert or Update - updates existing element or inserts a new one using UpsertCb
+func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) {
     shard := m.GetShard(key)
     shard.Lock()
     v, ok := shard.items[key]
@@ -66,8 +92,8 @@ func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res i
     return res
 }
 
-// SetIfAbsent sets the given value under the specified key if no value was associated with it.
-func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
+// Sets the given value under the specified key if no value was associated with it.
+func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool {
     // Get map shard.
     shard := m.GetShard(key)
     shard.Lock()
@@ -80,7 +106,7 @@ func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
 }
 
 // Get retrieves an element from map under given key.
-func (m ConcurrentMap) Get(key string) (interface{}, bool) {
+func (m ConcurrentMap[K, V]) Get(key K) (V, bool) {
     // Get shard
     shard := m.GetShard(key)
     shard.RLock()
@@ -91,10 +117,10 @@ func (m ConcurrentMap) Get(key string) (interface{}, bool) {
 }
 
 // Count returns the number of elements within the map.
-func (m ConcurrentMap) Count() int {
+func (m ConcurrentMap[K, V]) Count() int {
     count := 0
     for i := 0; i < SHARD_COUNT; i++ {
-        shard := m[i]
+        shard := m.shards[i]
         shard.RLock()
         count += len(shard.items)
         shard.RUnlock()
@@ -102,8 +128,8 @@ func (m ConcurrentMap) Count() int {
     return count
 }
 
-// Has looks up an item under specified key
-func (m ConcurrentMap) Has(key string) bool {
+// Looks up an item under specified key
+func (m ConcurrentMap[K, V]) Has(key K) bool {
     // Get shard
     shard := m.GetShard(key)
     shard.RLock()
@@ -114,9 +140,9 @@ func (m ConcurrentMap) Has(key string) bool {
 }
 
 // Remove removes an element from the map.
-func (m ConcurrentMap) Remove(key ...string) {
+func (m ConcurrentMap[K, V]) Remove(keys ...K) {
     // Try to get shard.
-    for _, k := range key {
+    for _, k := range keys {
         shard := m.GetShard(k)
         shard.Lock()
         delete(shard.items, k)
@@ -126,12 +152,12 @@ func (m ConcurrentMap) Remove(key ...string) {
 
 // RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held
 // If returns true, the element will be removed from the map
-type RemoveCb func(key string, v interface{}, exists bool) bool
+type RemoveCb[K any, V any] func(key K, v V, exists bool) bool
 
 // RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params
 // If callback returns true and element exists, it will remove it from the map
 // Returns the value returned by the callback (even if element was not present in the map)
-func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool {
+func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool {
     // Try to get shard.
     shard := m.GetShard(key)
     shard.Lock()
@@ -145,7 +171,7 @@ func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool {
 }
 
 // Pop removes an element from the map and returns it
-func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
+func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) {
     // Try to get shard.
     shard := m.GetShard(key)
     shard.Lock()
@@ -156,66 +182,66 @@ func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
 }
 
 // IsEmpty checks if map is empty.
-func (m ConcurrentMap) IsEmpty() bool {
+func (m ConcurrentMap[K, V]) IsEmpty() bool {
     return m.Count() == 0
 }
 
-// Tuple is used by the Iter & IterBuffered functions to wrap two variables together over a channel,
-type Tuple struct {
-    Key string
-    Val interface{}
+// Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
+type Tuple[K comparable, V any] struct {
+    Key K
+    Val V
 }
 
 // Iter returns an iterator which could be used in a for range loop.
 //
-// Deprecated: using IterBuffered() will get a better performance
-func (m ConcurrentMap) Iter() <-chan Tuple {
+// Deprecated: using IterBuffered() will get a better performence
+func (m ConcurrentMap[K, V]) Iter() <-chan Tuple[K, V] {
     chans := snapshot(m)
-    ch := make(chan Tuple)
+    ch := make(chan Tuple[K, V])
     go fanIn(chans, ch)
     return ch
 }
 
 // IterBuffered returns a buffered iterator which could be used in a for range loop.
-func (m ConcurrentMap) IterBuffered() <-chan Tuple {
+func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] {
     chans := snapshot(m)
     total := 0
     for _, c := range chans {
         total += cap(c)
     }
-    ch := make(chan Tuple, total)
+    ch := make(chan Tuple[K, V], total)
     go fanIn(chans, ch)
     return ch
 }
 
 // Clear removes all items from map.
-func (m ConcurrentMap) Clear() {
+func (m ConcurrentMap[K, V]) Clear() {
     for item := range m.IterBuffered() {
         m.Remove(item.Key)
     }
 }
 
-// Returns an array of channels that contains elements in each shard,
+// Returns a array of channels that contains elements in each shard,
 // which likely takes a snapshot of `m`.
 // It returns once the size of each buffered channel is determined,
 // before all the channels are populated using goroutines.
-func snapshot(m ConcurrentMap) (chans []chan Tuple) {
+func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) {
     //When you access map items before initializing.
-    if len(m) == 0 {
+    if len(m.shards) == 0 {
         panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`)
     }
-    chans = make([]chan Tuple, SHARD_COUNT)
+    chans = make([]chan Tuple[K, V], SHARD_COUNT)
     wg := sync.WaitGroup{}
     wg.Add(SHARD_COUNT)
     // Foreach shard.
-    for index, shard := range m {
-        go func(index int, shard *ConcurrentMapShared) {
+    for index, shard := range m.shards {
+        go func(index int, shard *ConcurrentMapShared[K, V]) {
             // Foreach key, value pair.
             shard.RLock()
-            chans[index] = make(chan Tuple, len(shard.items))
+            chans[index] = make(chan Tuple[K, V], len(shard.items))
             wg.Done()
             for key, val := range shard.items {
-                chans[index] <- Tuple{key, val}
+                chans[index] <- Tuple[K, V]{key, val}
             }
             shard.RUnlock()
             close(chans[index])
@@ -226,11 +252,11 @@ func snapshot(m ConcurrentMap) (chans []chan Tuple) {
 }
 
 // fanIn reads elements from channels `chans` into channel `out`
-func fanIn(chans []chan Tuple, out chan Tuple) {
+func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) {
     wg := sync.WaitGroup{}
     wg.Add(len(chans))
     for _, ch := range chans {
-        go func(ch chan Tuple) {
+        go func(ch chan Tuple[K, V]) {
             for t := range ch {
                 out <- t
             }
@@ -241,9 +267,9 @@ func fanIn(chans []chan Tuple, out chan Tuple) {
     close(out)
 }
 
-// Items returns all items as map[string]interface{}
-func (m ConcurrentMap) Items() map[string]interface{} {
-    tmp := make(map[string]interface{})
+// Items returns all items as map[string]V
+func (m ConcurrentMap[K, V]) Items() map[K]V {
+    tmp := make(map[K]V)
 
     // Insert items to temporary map.
     for item := range m.IterBuffered() {
@@ -253,18 +279,18 @@ func (m ConcurrentMap) Items() map[string]interface{} {
     return tmp
 }
 
-// IterCb is iterator callback, called for every key,value found in
+// Iterator callbacalled for every key,value found in
 // maps. RLock is held for all calls for a given shard
 // therefore callback sess consistent view of a shard,
 // but not across the shards
-type IterCb func(key string, v interface{}) bool
+type IterCb[K comparable, V any] func(key K, v V) bool
 
-// IterCb callback based iterator, the cheapest way to read
+// Callback based iterator, cheapest way to read
 // all elements in a map.
-func (m ConcurrentMap) IterCb(fn IterCb) {
+func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) {
     escape := false
-    for idx := range m {
-        shard := (m)[idx]
+    for idx := range m.shards {
+        shard := (m.shards)[idx]
         shard.RLock()
         for key, value := range shard.items {
             if !fn(key, value) {
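With IterCb and IterBuffered now generic, iteration is typed end to end: the callback receives the value type directly and the buffered channel yields Tuple[K, V]. A short sketch against the API in the hunks above; returning false from the callback stops the iteration, as the escape flag in IterCb shows:

package main

import (
    "fmt"

    "Spark/utils/cmap"
)

func main() {
    m := cmap.New[string]()
    m.Set(`a`, `alpha`)
    m.Set(`b`, `beta`)

    // IterCb walks every shard under its read lock; returning false stops early.
    m.IterCb(func(key string, v string) bool {
        fmt.Println(key, v)
        return true
    })

    // IterBuffered yields typed Tuple[string, string] values over a channel.
    for item := range m.IterBuffered() {
        fmt.Println(item.Key, item.Val)
    }
}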
@@ -280,15 +306,15 @@ func (m ConcurrentMap) IterCb(fn IterCb) {
 }
 
 // Keys returns all keys as []string
-func (m ConcurrentMap) Keys() []string {
+func (m ConcurrentMap[K, V]) Keys() []K {
     count := m.Count()
-    ch := make(chan string, count)
+    ch := make(chan K, count)
     go func() {
         // Foreach shard.
         wg := sync.WaitGroup{}
         wg.Add(SHARD_COUNT)
-        for _, shard := range m {
-            go func(shard *ConcurrentMapShared) {
+        for _, shard := range m.shards {
+            go func(shard *ConcurrentMapShared[K, V]) {
                 // Foreach key, value pair.
                 shard.RLock()
                 for key := range shard.items {
@@ -303,17 +329,17 @@ func (m ConcurrentMap) Keys() []string {
     }()
 
     // Generate keys
-    keys := make([]string, 0, count)
+    keys := make([]K, 0, count)
     for k := range ch {
         keys = append(keys, k)
     }
     return keys
 }
 
-//MarshalJSON reviles ConcurrentMap "private" variables to json marshal.
-func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
+// Reviles ConcurrentMap "private" variables to json marshal.
+func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) {
     // Create a temporary map, which will hold all item spread across shards.
-    tmp := make(map[string]interface{})
+    tmp := make(map[K]V)
 
     // Insert items to temporary map.
     for item := range m.IterBuffered() {
@@ -321,6 +347,9 @@ func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
     }
     return json.Marshal(tmp)
 }
+func strfnv32[K fmt.Stringer](key K) uint32 {
+    return fnv32(key.String())
+}
 
 func fnv32(key string) uint32 {
     hash := uint32(2166136261)
@@ -333,24 +362,18 @@ func fnv32(key string) uint32 {
     return hash
 }
 
-// Concurrent map uses Interface{} as its value, therefore JSON Unmarshal
-// probably won't know which to type to unmarshal into, in such case
-// we'll end up with a value of type map[string]interface{}, In most cases this isn't
-// out value type, this is why we've decided to remove this functionality.
+// Reverse process of Marshal.
+func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) {
+    tmp := make(map[K]V)
 
-// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) {
-// 	// Reverse process of Marshal.
+    // Unmarshal into a single map.
+    if err := json.Unmarshal(b, &tmp); err != nil {
+        return err
+    }
 
-// 	tmp := make(map[string]interface{})
-
-// 	// Unmarshal into a single map.
-// 	if err := json.Unmarshal(b, &tmp); err != nil {
-// 		return nil
-// 	}
-
-// 	// foreach key,value pair in temporary map insert into our concurrent map.
-// 	for key, val := range tmp {
-// 		m.Set(key, val)
-// 	}
-// 	return nil
-// }
+    // foreach key,value pair in temporary map insert into our concurrent map.
+    for key, val := range tmp {
+        m.Set(key, val)
+    }
+    return nil
+}
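MarshalJSON flattens the shards into one ordinary map before encoding, and the reinstated UnmarshalJSON reverses it by decoding into a temporary map[K]V and calling Set for each pair, so a typed map survives a JSON round trip. A small sketch against the API in the hunks above:

package main

import (
    "encoding/json"
    "fmt"

    "Spark/utils/cmap"
)

func main() {
    src := cmap.New[int64]()
    src.Set(`token-a`, 1700000000)
    src.Set(`token-b`, 1700000100)

    // MarshalJSON gathers every shard into a plain map before encoding.
    buf, err := json.Marshal(src)
    if err != nil {
        panic(err)
    }

    // UnmarshalJSON decodes into a temporary map[K]V and re-Sets each pair;
    // the destination map must be initialized with New before decoding.
    dst := cmap.New[int64]()
    if err := json.Unmarshal(buf, &dst); err != nil {
        panic(err)
    }
    if v, ok := dst.Get(`token-a`); ok {
        fmt.Println(v)
    }
}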
@@ -5,7 +5,7 @@ import (
 )
 
 type hub struct {
-    sessions   cmap.ConcurrentMap
+    sessions   cmap.ConcurrentMap[string, *Session]
     queue      chan *envelope
     register   chan *Session
     unregister chan *Session
@@ -15,7 +15,7 @@ type hub struct {
 
 func newHub() *hub {
     return &hub{
-        sessions:   cmap.New(),
+        sessions:   cmap.New[*Session](),
         queue:      make(chan *envelope),
         register:   make(chan *Session),
         unregister: make(chan *Session),
@@ -38,19 +38,16 @@ loop:
             if len(m.list) > 0 {
                 for _, uuid := range m.list {
                     if s, ok := h.sessions.Get(uuid); ok {
-                        s := s.(*Session)
                         s.writeMessage(m)
                     }
                 }
             } else if m.filter == nil {
-                h.sessions.IterCb(func(uuid string, v interface{}) bool {
-                    s := v.(*Session)
+                h.sessions.IterCb(func(uuid string, s *Session) bool {
                     s.writeMessage(m)
                     return true
                 })
             } else {
-                h.sessions.IterCb(func(uuid string, v interface{}) bool {
-                    s := v.(*Session)
+                h.sessions.IterCb(func(uuid string, s *Session) bool {
                     if m.filter(s) {
                         s.writeMessage(m)
                     }
@@ -60,8 +57,7 @@ loop:
         case m := <-h.exit:
             var keys []string
             h.open = false
-            h.sessions.IterCb(func(uuid string, v interface{}) bool {
-                s := v.(*Session)
+            h.sessions.IterCb(func(uuid string, s *Session) bool {
                 s.writeMessage(m)
                 s.Close()
                 keys = append(keys, uuid)
@@ -301,27 +301,14 @@ func (m *Melody) SendMultiple(msg []byte, list []string) error {
 
 // GetSessionByUUID returns the session with specified uuid.
 func (m *Melody) GetSessionByUUID(uuid string) (*Session, bool) {
-    val, ok := m.hub.sessions.Get(uuid)
-    if !ok {
-        return nil, false
-    }
-    s, ok := val.(*Session)
-    if !ok {
-        m.hub.sessions.Remove(uuid)
-    }
-    return s, ok
+    return m.hub.sessions.Get(uuid)
 }
 
 // IterSessions iterates all sessions.
 func (m *Melody) IterSessions(fn func(uuid string, s *Session) bool) {
-    var invalid []string
-    m.hub.sessions.IterCb(func(uuid string, v interface{}) bool {
-        if s, ok := v.(*Session); !ok {
-            invalid = append(invalid, uuid)
-            return true
-        } else {
-            return fn(uuid, s)
-        }
+    m.hub.sessions.IterCb(func(uuid string, s *Session) bool {
+        return fn(uuid, s)
     })
-    m.hub.sessions.Remove(invalid...)
 }
@@ -394,7 +394,7 @@ function overview(props) {
     <>
         <Image
             preview={{
-                visible: screenBlob,
+                visible: !!screenBlob,
                 src: screenBlob,
                 onVisibleChange: () => {
                     URL.revokeObjectURL(screenBlob);