From bb68f494fdf43e5b8d50ff75408c960660d765d8 Mon Sep 17 00:00:00 2001 From: XZB-1248 Date: Sun, 21 May 2023 23:11:09 +0800 Subject: [PATCH] optimize: update concurrent_map --- .gitignore | 2 +- client/service/desktop/desktop.go | 29 ++- client/service/terminal/terminal.go | 2 - client/service/terminal/terminal_others.go | 29 ++- client/service/terminal/terminal_windows.go | 29 ++- server/common/common.go | 5 +- server/common/event.go | 9 +- server/common/log.go | 4 +- server/handler/bridge/bridge.go | 12 +- server/handler/utility/utility.go | 19 +- server/main.go | 30 ++- utils/cmap/concurrent_map.go | 201 +++++++++++--------- utils/melody/hub.go | 14 +- utils/melody/melody.go | 19 +- web/src/pages/overview.jsx | 2 +- 15 files changed, 191 insertions(+), 215 deletions(-) diff --git a/.gitignore b/.gitignore index 5176f85..1a1e632 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,6 @@ /tools /logs /.idea -/Config.json +/config.json dist/ node_modules/ \ No newline at end of file diff --git a/client/service/desktop/desktop.go b/client/service/desktop/desktop.go index 64bcd51..838830c 100644 --- a/client/service/desktop/desktop.go +++ b/client/service/desktop/desktop.go @@ -62,7 +62,7 @@ const imageQuality = 70 var lock = &sync.Mutex{} var working = false -var sessions = cmap.New() +var sessions = cmap.New[*session]() var prevDesktop *image.RGBA var displayBounds image.Rectangle var errNoImage = errors.New(`DESKTOP.NO_IMAGE_YET`) @@ -127,8 +127,7 @@ func worker() { } func sendImageDiff(diff []*[]byte) { - sessions.IterCb(func(uuid string, t any) bool { - desktop := t.(*session) + sessions.IterCb(func(uuid string, desktop *session) bool { desktop.lock.Lock() if !desktop.escape { if len(desktop.channel) >= frameBuffer { @@ -146,9 +145,8 @@ func sendImageDiff(diff []*[]byte) { func quitAllDesktop(info string) { keys := make([]string, 0) - sessions.IterCb(func(uuid string, t any) bool { + sessions.IterCb(func(uuid string, desktop *session) bool { keys = append(keys, uuid) - desktop := t.(*session) desktop.escape = true desktop.channel <- message{t: 1, info: info} return true @@ -346,24 +344,23 @@ func PingDesktop(pack modules.Packet) { } else { uuid = val.(string) } - if val, ok := sessions.Get(uuid); ok { - desktop = val.(*session) - desktop.lastPack = utils.Unix + desktop, ok := sessions.Get(uuid) + if !ok { + return } + desktop.lastPack = utils.Unix } func KillDesktop(pack modules.Packet) { var uuid string - var desktop *session if val, ok := pack.GetData(`desktop`, reflect.String); !ok { return } else { uuid = val.(string) } - if val, ok := sessions.Get(uuid); !ok { + desktop, ok := sessions.Get(uuid) + if !ok { return - } else { - desktop = val.(*session) } sessions.Remove(uuid) data, _ := utils.JSON.Marshal(modules.Packet{Act: `DESKTOP_QUIT`, Msg: `${i18n|DESKTOP.SESSION_CLOSED}`}) @@ -383,10 +380,9 @@ func GetDesktop(pack modules.Packet) { } else { uuid = val.(string) } - if val, ok := sessions.Get(uuid); !ok { + desktop, ok := sessions.Get(uuid) + if !ok { return - } else { - desktop = val.(*session) } if !desktop.escape { lock.Lock() @@ -450,8 +446,7 @@ func healthCheck() { timestamp := now.Unix() // stores sessions to be disconnected keys := make([]string, 0) - sessions.IterCb(func(uuid string, t any) bool { - desktop := t.(*session) + sessions.IterCb(func(uuid string, desktop *session) bool { if timestamp-desktop.lastPack > MaxInterval { keys = append(keys, uuid) } diff --git a/client/service/terminal/terminal.go b/client/service/terminal/terminal.go index f7e8462..6e94168 100644 --- a/client/service/terminal/terminal.go +++ b/client/service/terminal/terminal.go @@ -1,11 +1,9 @@ package terminal import ( - "Spark/utils/cmap" "errors" ) -var terminals = cmap.New() var ( errDataNotFound = errors.New(`no input found in packet`) errDataInvalid = errors.New(`can not parse data in packet`) diff --git a/client/service/terminal/terminal_others.go b/client/service/terminal/terminal_others.go index 9b34254..e3c39dd 100644 --- a/client/service/terminal/terminal_others.go +++ b/client/service/terminal/terminal_others.go @@ -6,6 +6,7 @@ import ( "Spark/client/common" "Spark/modules" "Spark/utils" + "Spark/utils/cmap" "encoding/hex" "github.com/creack/pty" "os" @@ -23,6 +24,7 @@ type terminal struct { cmd *exec.Cmd } +var terminals = cmap.New[*terminal]() var defaultShell = `` func init() { @@ -88,11 +90,8 @@ func InitTerminal(pack modules.Packet) error { } func InputRawTerminal(input []byte, uuid string) { - var session *terminal - - if val, ok := terminals.Get(uuid); ok { - session = val.(*terminal) - } else { + session, ok := terminals.Get(uuid) + if !ok { return } session.pty.Write(input) @@ -165,10 +164,9 @@ func KillTerminal(pack modules.Packet) { } else { uuid = val.(string) } - if val, ok := terminals.Get(uuid); !ok { + session, ok := terminals.Get(uuid) + if !ok { return - } else { - session = val.(*terminal) } terminals.Remove(uuid) data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`}) @@ -180,18 +178,16 @@ func KillTerminal(pack modules.Packet) { func PingTerminal(pack modules.Packet) { var termUUID string - var termSession *terminal if val, ok := pack.GetData(`terminal`, reflect.String); !ok { return } else { termUUID = val.(string) } - if val, ok := terminals.Get(termUUID); !ok { + session, ok := terminals.Get(termUUID) + if !ok { return - } else { - termSession = val.(*terminal) - termSession.lastPack = utils.Unix } + session.lastPack = utils.Unix } func doKillTerminal(terminal *terminal) { @@ -234,11 +230,10 @@ func healthCheck() { timestamp := now.Unix() // stores sessions to be disconnected queue := make([]string, 0) - terminals.IterCb(func(uuid string, t any) bool { - termSession := t.(*terminal) - if timestamp-termSession.lastPack > MaxInterval { + terminals.IterCb(func(uuid string, session *terminal) bool { + if timestamp-session.lastPack > MaxInterval { queue = append(queue, uuid) - doKillTerminal(termSession) + doKillTerminal(session) } return true }) diff --git a/client/service/terminal/terminal_windows.go b/client/service/terminal/terminal_windows.go index ec383cb..31c70f2 100644 --- a/client/service/terminal/terminal_windows.go +++ b/client/service/terminal/terminal_windows.go @@ -4,6 +4,7 @@ import ( "Spark/client/common" "Spark/modules" "Spark/utils" + "Spark/utils/cmap" "encoding/hex" "io" "os/exec" @@ -23,6 +24,7 @@ type terminal struct { stdin *io.WriteCloser } +var terminals = cmap.New[*terminal]() var defaultCmd = `` func init() { @@ -108,11 +110,8 @@ func InitTerminal(pack modules.Packet) error { } func InputRawTerminal(input []byte, uuid string) { - var session *terminal - - if val, ok := terminals.Get(uuid); ok { - session = val.(*terminal) - } else { + session, ok := terminals.Get(uuid) + if !ok { return } (*session.stdin).Write(input) @@ -152,16 +151,14 @@ func ResizeTerminal(pack modules.Packet) error { func KillTerminal(pack modules.Packet) { var uuid string - var session *terminal if val, ok := pack.GetData(`terminal`, reflect.String); !ok { return } else { uuid = val.(string) } - if val, ok := terminals.Get(uuid); !ok { + session, ok := terminals.Get(uuid) + if !ok { return - } else { - session = val.(*terminal) } terminals.Remove(uuid) data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`}) @@ -179,12 +176,11 @@ func PingTerminal(pack modules.Packet) { } else { uuid = val.(string) } - if val, ok := terminals.Get(uuid); !ok { + session, ok := terminals.Get(uuid) + if !ok { return - } else { - session = val.(*terminal) - session.lastPack = utils.Unix } + session.lastPack = utils.Unix } func doKillTerminal(terminal *terminal) { @@ -221,11 +217,10 @@ func healthCheck() { timestamp := now.Unix() // stores sessions to be disconnected keys := make([]string, 0) - terminals.IterCb(func(uuid string, t any) bool { - termSession := t.(*terminal) - if timestamp-termSession.lastPack > MaxInterval { + terminals.IterCb(func(uuid string, session *terminal) bool { + if timestamp-session.lastPack > MaxInterval { keys = append(keys, uuid) - doKillTerminal(termSession) + doKillTerminal(session) } return true }) diff --git a/server/common/common.go b/server/common/common.go index 245f1a0..037988e 100644 --- a/server/common/common.go +++ b/server/common/common.go @@ -17,7 +17,7 @@ import ( const MaxMessageSize = (2 << 15) + 1024 var Melody = melody.New() -var Devices = cmap.New() +var Devices = cmap.New[*modules.Device]() func SendPackByUUID(pack modules.Packet, uuid string) bool { session, ok := Melody.GetSessionByUUID(uuid) @@ -164,8 +164,7 @@ func CheckDevice(deviceID, connUUID string) (string, bool) { } } else { tempConnUUID := `` - Devices.IterCb(func(uuid string, v any) bool { - device := v.(*modules.Device) + Devices.IterCb(func(uuid string, device *modules.Device) bool { if device.ID == deviceID { tempConnUUID = uuid return false diff --git a/server/common/event.go b/server/common/event.go index 0fcc848..d692e99 100644 --- a/server/common/event.go +++ b/server/common/event.go @@ -15,7 +15,7 @@ type event struct { remove chan bool } -var events = cmap.New() +var events = cmap.New[*event]() // CallEvent tries to call the callback with the given uuid // after that, it will notify the caller via the channel @@ -23,11 +23,10 @@ func CallEvent(pack modules.Packet, session *melody.Session) { if len(pack.Event) == 0 { return } - v, ok := events.Get(pack.Event) + ev, ok := events.Get(pack.Event) if !ok { return } - ev := v.(*event) if session != nil && session.UUID != ev.connection { return } @@ -76,12 +75,11 @@ func AddEvent(fn EventCallback, connUUID, trigger string) { // RemoveEvent deletes the event with the given event trigger. // The ok will be returned to caller if the event is temp (only once). func RemoveEvent(trigger string, ok ...bool) { - v, found := events.Get(trigger) + ev, found := events.Get(trigger) if !found { return } events.Remove(trigger) - ev := v.(*event) if ev.remove != nil { if len(ok) > 0 { ev.remove <- ok[0] @@ -89,7 +87,6 @@ func RemoveEvent(trigger string, ok ...bool) { ev.remove <- false } } - v = nil ev = nil } diff --git a/server/common/log.go b/server/common/log.go index 5db5952..5c98827 100644 --- a/server/common/log.go +++ b/server/common/log.go @@ -1,7 +1,6 @@ package common import ( - "Spark/modules" "Spark/server/config" "Spark/utils" "Spark/utils/melody" @@ -81,9 +80,8 @@ func getLog(ctx any, event, status, msg string, args map[string]any) string { } } if targetInfo { - val, ok := Devices.Get(connUUID) + device, ok := Devices.Get(connUUID) if ok { - device := val.(*modules.Device) args[`target`] = map[string]any{ `name`: device.Hostname, `ip`: device.WAN, diff --git a/server/handler/bridge/bridge.go b/server/handler/bridge/bridge.go index 34820d4..cc12089 100644 --- a/server/handler/bridge/bridge.go +++ b/server/handler/bridge/bridge.go @@ -28,15 +28,14 @@ type Bridge struct { OnFinish func(bridge *Bridge) } -var bridges = cmap.New() +var bridges = cmap.New[*Bridge]() func init() { go func() { for now := range time.NewTicker(15 * time.Second).C { var queue []string timestamp := now.Unix() - bridges.IterCb(func(k string, v any) bool { - b := v.(*Bridge) + bridges.IterCb(func(k string, b *Bridge) bool { if timestamp-b.creation > 60 && !b.using { b.lock.Lock() if b.Src != nil && b.Src.Request.Body != nil { @@ -63,12 +62,12 @@ func CheckBridge(ctx *gin.Context) *Bridge { ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_PARAMETER}`}) return nil } - val, ok := bridges.Get(form.Bridge) + b, ok := bridges.Get(form.Bridge) if !ok { ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_BRIDGE_ID}`}) return nil } - return val.(*Bridge) + return b } func BridgePush(ctx *gin.Context) { @@ -218,12 +217,11 @@ func AddBridgeWithDst(ext any, uuid string, Dst *gin.Context) *Bridge { } func RemoveBridge(uuid string) { - val, ok := bridges.Get(uuid) + b, ok := bridges.Get(uuid) if !ok { return } bridges.Remove(uuid) - b := val.(*Bridge) if b.Src != nil && b.Src.Request.Body != nil { b.Src.Request.Body.Close() } diff --git a/server/handler/utility/utility.go b/server/handler/utility/utility.go index 0e652d8..5bba88f 100644 --- a/server/handler/utility/utility.go +++ b/server/handler/utility/utility.go @@ -70,8 +70,7 @@ func OnDevicePack(data []byte, session *melody.Session) error { // If so, then find the session and let client quit. // This will keep only one connection remained per device. exSession := `` - common.Devices.IterCb(func(uuid string, v any) bool { - device := v.(*modules.Device) + common.Devices.IterCb(func(uuid string, device *modules.Device) bool { if device.ID == pack.Device.ID { exSession = uuid target, ok := common.Melody.GetSessionByUUID(uuid) @@ -94,14 +93,13 @@ func OnDevicePack(data []byte, session *melody.Session) error { }, }) } else { - val, ok := common.Devices.Get(session.UUID) + device, ok := common.Devices.Get(session.UUID) if ok { - deviceInfo := val.(*modules.Device) - deviceInfo.CPU = pack.Device.CPU - deviceInfo.RAM = pack.Device.RAM - deviceInfo.Net = pack.Device.Net - deviceInfo.Disk = pack.Device.Disk - deviceInfo.Uptime = pack.Device.Uptime + device.CPU = pack.Device.CPU + device.RAM = pack.Device.RAM + device.Net = pack.Device.Net + device.Disk = pack.Device.Disk + device.Uptime = pack.Device.Uptime } } common.SendPack(modules.Packet{Code: 0}, session) @@ -268,8 +266,7 @@ func ExecDeviceCmd(ctx *gin.Context) { // GetDevices will return all info about all clients. func GetDevices(ctx *gin.Context) { devices := map[string]any{} - common.Devices.IterCb(func(uuid string, v any) bool { - device := v.(*modules.Device) + common.Devices.IterCb(func(uuid string, device *modules.Device) bool { devices[uuid] = *device return true }) diff --git a/server/main.go b/server/main.go index e03af4c..7197d9d 100644 --- a/server/main.go +++ b/server/main.go @@ -32,7 +32,7 @@ import ( "github.com/gin-gonic/gin" ) -var blocked = cmap.New() +var blocked = cmap.New[int64]() var lastRequest = time.Now().Unix() func main() { @@ -212,14 +212,13 @@ func wsOnMessageBinary(session *melody.Session, data []byte) { } func wsOnDisconnect(session *melody.Session) { - if val, ok := common.Devices.Get(session.UUID); ok { - deviceInfo := val.(*modules.Device) - terminal.CloseSessionsByDevice(deviceInfo.ID) - desktop.CloseSessionsByDevice(deviceInfo.ID) + if device, ok := common.Devices.Get(session.UUID); ok { + terminal.CloseSessionsByDevice(device.ID) + desktop.CloseSessionsByDevice(device.ID) common.Info(nil, `CLIENT_OFFLINE`, ``, ``, map[string]any{ `device`: map[string]any{ - `name`: deviceInfo.Hostname, - `ip`: deviceInfo.WAN, + `name`: device.Hostname, + `ip`: device.WAN, }, }) } else { @@ -289,10 +288,9 @@ func pingDevice(s *melody.Session) { trigger := utils.GetStrUUID() common.SendPack(modules.Packet{Act: `PING`, Event: trigger}, s) common.AddEventOnce(func(packet modules.Packet, session *melody.Session) { - val, ok := common.Devices.Get(s.UUID) + device, ok := common.Devices.Get(s.UUID) if ok { - deviceInfo := val.(*modules.Device) - deviceInfo.Latency = uint(time.Now().UnixMilli()-t) / 2 + device.Latency = uint(time.Now().UnixMilli()-t) / 2 } }, s.UUID, trigger, 3*time.Second) } @@ -300,12 +298,12 @@ func pingDevice(s *melody.Session) { func checkAuth() gin.HandlerFunc { // Token as key and update timestamp as value. // Stores authenticated tokens. - tokens := cmap.New() + tokens := cmap.New[int64]() go func() { for now := range time.NewTicker(60 * time.Second).C { var queue []string - tokens.IterCb(func(key string, v any) bool { - if now.Unix()-v.(int64) > 1800 { + tokens.IterCb(func(key string, t int64) bool { + if now.Unix()-t > 1800 { queue = append(queue, key) } return true @@ -313,8 +311,8 @@ func checkAuth() gin.HandlerFunc { tokens.Remove(queue...) queue = nil - blocked.IterCb(func(addr string, v any) bool { - if now.Unix() > v.(int64) { + blocked.IterCb(func(addr string, t int64) bool { + if now.Unix() > t { queue = append(queue, addr) } return true @@ -347,7 +345,7 @@ func checkAuth() gin.HandlerFunc { if !passed { addr := common.GetRealIP(ctx) if expire, ok := blocked.Get(addr); ok { - if now < expire.(int64) { + if now < expire { ctx.AbortWithStatusJSON(http.StatusTooManyRequests, modules.Packet{Code: 1}) return } diff --git a/utils/cmap/concurrent_map.go b/utils/cmap/concurrent_map.go index 687a827..91a4ce0 100644 --- a/utils/cmap/concurrent_map.go +++ b/utils/cmap/concurrent_map.go @@ -2,36 +2,62 @@ package cmap import ( "encoding/json" + "fmt" "sync" ) -const SHARD_COUNT = 32 +var SHARD_COUNT = 32 -// ConcurrentMap is a "thread" safe map of type string:Anything. +type Stringer interface { + fmt.Stringer + comparable +} + +// A "thread" safe map of type string:Anything. // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. -type ConcurrentMap []*ConcurrentMapShared +type ConcurrentMap[K comparable, V any] struct { + shards []*ConcurrentMapShared[K, V] + sharding func(key K) uint32 +} -// ConcurrentMapShared is a "thread" safe string to anything map. -type ConcurrentMapShared struct { - items map[string]interface{} +// A "thread" safe string to anything map. +type ConcurrentMapShared[K comparable, V any] struct { + items map[K]V sync.RWMutex // Read Write mutex, guards access to internal map. } -// New creates a new concurrent map. -func New() ConcurrentMap { - m := make(ConcurrentMap, SHARD_COUNT) +func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] { + m := ConcurrentMap[K, V]{ + sharding: sharding, + shards: make([]*ConcurrentMapShared[K, V], SHARD_COUNT), + } for i := 0; i < SHARD_COUNT; i++ { - m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} + m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)} } return m } -// GetShard returns shard under given key -func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { - return m[uint(fnv32(key))%uint(SHARD_COUNT)] +// Creates a new concurrent map. +func New[V any]() ConcurrentMap[string, V] { + return create[string, V](fnv32) } -func (m ConcurrentMap) MSet(data map[string]interface{}) { +// Creates a new concurrent map. +func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] { + return create[K, V](strfnv32[K]) +} + +// Creates a new concurrent map. +func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] { + return create[K, V](sharding) +} + +// GetShard returns shard under given key +func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] { + return m.shards[uint(m.sharding(key))%uint(SHARD_COUNT)] +} + +func (m ConcurrentMap[K, V]) MSet(data map[K]V) { for key, value := range data { shard := m.GetShard(key) shard.Lock() @@ -40,8 +66,8 @@ func (m ConcurrentMap) MSet(data map[string]interface{}) { } } -// Set sets the given value under the specified key. -func (m ConcurrentMap) Set(key string, value interface{}) { +// Sets the given value under the specified key. +func (m ConcurrentMap[K, V]) Set(key K, value V) { // Get map shard. shard := m.GetShard(key) shard.Lock() @@ -49,14 +75,14 @@ func (m ConcurrentMap) Set(key string, value interface{}) { shard.Unlock() } -// UpsertCb is callback to return new element to be inserted into the map +// Callback to return new element to be inserted into the map // It is called while lock is held, therefore it MUST NOT // try to access other keys in same map, as it can lead to deadlock since // Go sync.RWLock is not reentrant -type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{} +type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V -// Upsert means Insert or Update - updates existing element or inserts a new one using UpsertCb -func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) { +// Insert or Update - updates existing element or inserts a new one using UpsertCb +func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) { shard := m.GetShard(key) shard.Lock() v, ok := shard.items[key] @@ -66,8 +92,8 @@ func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res i return res } -// SetIfAbsent sets the given value under the specified key if no value was associated with it. -func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { +// Sets the given value under the specified key if no value was associated with it. +func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool { // Get map shard. shard := m.GetShard(key) shard.Lock() @@ -80,7 +106,7 @@ func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { } // Get retrieves an element from map under given key. -func (m ConcurrentMap) Get(key string) (interface{}, bool) { +func (m ConcurrentMap[K, V]) Get(key K) (V, bool) { // Get shard shard := m.GetShard(key) shard.RLock() @@ -91,10 +117,10 @@ func (m ConcurrentMap) Get(key string) (interface{}, bool) { } // Count returns the number of elements within the map. -func (m ConcurrentMap) Count() int { +func (m ConcurrentMap[K, V]) Count() int { count := 0 for i := 0; i < SHARD_COUNT; i++ { - shard := m[i] + shard := m.shards[i] shard.RLock() count += len(shard.items) shard.RUnlock() @@ -102,8 +128,8 @@ func (m ConcurrentMap) Count() int { return count } -// Has looks up an item under specified key -func (m ConcurrentMap) Has(key string) bool { +// Looks up an item under specified key +func (m ConcurrentMap[K, V]) Has(key K) bool { // Get shard shard := m.GetShard(key) shard.RLock() @@ -114,9 +140,9 @@ func (m ConcurrentMap) Has(key string) bool { } // Remove removes an element from the map. -func (m ConcurrentMap) Remove(key ...string) { +func (m ConcurrentMap[K, V]) Remove(keys ...K) { // Try to get shard. - for _, k := range key { + for _, k := range keys { shard := m.GetShard(k) shard.Lock() delete(shard.items, k) @@ -126,12 +152,12 @@ func (m ConcurrentMap) Remove(key ...string) { // RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held // If returns true, the element will be removed from the map -type RemoveCb func(key string, v interface{}, exists bool) bool +type RemoveCb[K any, V any] func(key K, v V, exists bool) bool // RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params // If callback returns true and element exists, it will remove it from the map // Returns the value returned by the callback (even if element was not present in the map) -func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool { +func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool { // Try to get shard. shard := m.GetShard(key) shard.Lock() @@ -145,7 +171,7 @@ func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool { } // Pop removes an element from the map and returns it -func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) { +func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) { // Try to get shard. shard := m.GetShard(key) shard.Lock() @@ -156,66 +182,66 @@ func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) { } // IsEmpty checks if map is empty. -func (m ConcurrentMap) IsEmpty() bool { +func (m ConcurrentMap[K, V]) IsEmpty() bool { return m.Count() == 0 } -// Tuple is used by the Iter & IterBuffered functions to wrap two variables together over a channel, -type Tuple struct { - Key string - Val interface{} +// Used by the Iter & IterBuffered functions to wrap two variables together over a channel, +type Tuple[K comparable, V any] struct { + Key K + Val V } // Iter returns an iterator which could be used in a for range loop. // -// Deprecated: using IterBuffered() will get a better performance -func (m ConcurrentMap) Iter() <-chan Tuple { +// Deprecated: using IterBuffered() will get a better performence +func (m ConcurrentMap[K, V]) Iter() <-chan Tuple[K, V] { chans := snapshot(m) - ch := make(chan Tuple) + ch := make(chan Tuple[K, V]) go fanIn(chans, ch) return ch } // IterBuffered returns a buffered iterator which could be used in a for range loop. -func (m ConcurrentMap) IterBuffered() <-chan Tuple { +func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] { chans := snapshot(m) total := 0 for _, c := range chans { total += cap(c) } - ch := make(chan Tuple, total) + ch := make(chan Tuple[K, V], total) go fanIn(chans, ch) return ch } // Clear removes all items from map. -func (m ConcurrentMap) Clear() { +func (m ConcurrentMap[K, V]) Clear() { for item := range m.IterBuffered() { m.Remove(item.Key) } } -// Returns an array of channels that contains elements in each shard, +// Returns a array of channels that contains elements in each shard, // which likely takes a snapshot of `m`. // It returns once the size of each buffered channel is determined, // before all the channels are populated using goroutines. -func snapshot(m ConcurrentMap) (chans []chan Tuple) { +func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) { //When you access map items before initializing. - if len(m) == 0 { + if len(m.shards) == 0 { panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`) } - chans = make([]chan Tuple, SHARD_COUNT) + chans = make([]chan Tuple[K, V], SHARD_COUNT) wg := sync.WaitGroup{} wg.Add(SHARD_COUNT) // Foreach shard. - for index, shard := range m { - go func(index int, shard *ConcurrentMapShared) { + for index, shard := range m.shards { + go func(index int, shard *ConcurrentMapShared[K, V]) { // Foreach key, value pair. shard.RLock() - chans[index] = make(chan Tuple, len(shard.items)) + chans[index] = make(chan Tuple[K, V], len(shard.items)) wg.Done() for key, val := range shard.items { - chans[index] <- Tuple{key, val} + chans[index] <- Tuple[K, V]{key, val} } shard.RUnlock() close(chans[index]) @@ -226,11 +252,11 @@ func snapshot(m ConcurrentMap) (chans []chan Tuple) { } // fanIn reads elements from channels `chans` into channel `out` -func fanIn(chans []chan Tuple, out chan Tuple) { +func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) { wg := sync.WaitGroup{} wg.Add(len(chans)) for _, ch := range chans { - go func(ch chan Tuple) { + go func(ch chan Tuple[K, V]) { for t := range ch { out <- t } @@ -241,9 +267,9 @@ func fanIn(chans []chan Tuple, out chan Tuple) { close(out) } -// Items returns all items as map[string]interface{} -func (m ConcurrentMap) Items() map[string]interface{} { - tmp := make(map[string]interface{}) +// Items returns all items as map[string]V +func (m ConcurrentMap[K, V]) Items() map[K]V { + tmp := make(map[K]V) // Insert items to temporary map. for item := range m.IterBuffered() { @@ -253,18 +279,18 @@ func (m ConcurrentMap) Items() map[string]interface{} { return tmp } -// IterCb is iterator callback, called for every key,value found in +// Iterator callbacalled for every key,value found in // maps. RLock is held for all calls for a given shard // therefore callback sess consistent view of a shard, // but not across the shards -type IterCb func(key string, v interface{}) bool +type IterCb[K comparable, V any] func(key K, v V) bool -// IterCb callback based iterator, the cheapest way to read +// Callback based iterator, cheapest way to read // all elements in a map. -func (m ConcurrentMap) IterCb(fn IterCb) { +func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) { escape := false - for idx := range m { - shard := (m)[idx] + for idx := range m.shards { + shard := (m.shards)[idx] shard.RLock() for key, value := range shard.items { if !fn(key, value) { @@ -280,15 +306,15 @@ func (m ConcurrentMap) IterCb(fn IterCb) { } // Keys returns all keys as []string -func (m ConcurrentMap) Keys() []string { +func (m ConcurrentMap[K, V]) Keys() []K { count := m.Count() - ch := make(chan string, count) + ch := make(chan K, count) go func() { // Foreach shard. wg := sync.WaitGroup{} wg.Add(SHARD_COUNT) - for _, shard := range m { - go func(shard *ConcurrentMapShared) { + for _, shard := range m.shards { + go func(shard *ConcurrentMapShared[K, V]) { // Foreach key, value pair. shard.RLock() for key := range shard.items { @@ -303,17 +329,17 @@ func (m ConcurrentMap) Keys() []string { }() // Generate keys - keys := make([]string, 0, count) + keys := make([]K, 0, count) for k := range ch { keys = append(keys, k) } return keys } -//MarshalJSON reviles ConcurrentMap "private" variables to json marshal. -func (m ConcurrentMap) MarshalJSON() ([]byte, error) { +// Reviles ConcurrentMap "private" variables to json marshal. +func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) { // Create a temporary map, which will hold all item spread across shards. - tmp := make(map[string]interface{}) + tmp := make(map[K]V) // Insert items to temporary map. for item := range m.IterBuffered() { @@ -321,6 +347,9 @@ func (m ConcurrentMap) MarshalJSON() ([]byte, error) { } return json.Marshal(tmp) } +func strfnv32[K fmt.Stringer](key K) uint32 { + return fnv32(key.String()) +} func fnv32(key string) uint32 { hash := uint32(2166136261) @@ -333,24 +362,18 @@ func fnv32(key string) uint32 { return hash } -// Concurrent map uses Interface{} as its value, therefore JSON Unmarshal -// probably won't know which to type to unmarshal into, in such case -// we'll end up with a value of type map[string]interface{}, In most cases this isn't -// out value type, this is why we've decided to remove this functionality. +// Reverse process of Marshal. +func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) { + tmp := make(map[K]V) -// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) { -// // Reverse process of Marshal. + // Unmarshal into a single map. + if err := json.Unmarshal(b, &tmp); err != nil { + return err + } -// tmp := make(map[string]interface{}) - -// // Unmarshal into a single map. -// if err := json.Unmarshal(b, &tmp); err != nil { -// return nil -// } - -// // foreach key,value pair in temporary map insert into our concurrent map. -// for key, val := range tmp { -// m.Set(key, val) -// } -// return nil -// } + // foreach key,value pair in temporary map insert into our concurrent map. + for key, val := range tmp { + m.Set(key, val) + } + return nil +} diff --git a/utils/melody/hub.go b/utils/melody/hub.go index 1a9f2d2..8478d3b 100644 --- a/utils/melody/hub.go +++ b/utils/melody/hub.go @@ -5,7 +5,7 @@ import ( ) type hub struct { - sessions cmap.ConcurrentMap + sessions cmap.ConcurrentMap[string, *Session] queue chan *envelope register chan *Session unregister chan *Session @@ -15,7 +15,7 @@ type hub struct { func newHub() *hub { return &hub{ - sessions: cmap.New(), + sessions: cmap.New[*Session](), queue: make(chan *envelope), register: make(chan *Session), unregister: make(chan *Session), @@ -38,19 +38,16 @@ loop: if len(m.list) > 0 { for _, uuid := range m.list { if s, ok := h.sessions.Get(uuid); ok { - s := s.(*Session) s.writeMessage(m) } } } else if m.filter == nil { - h.sessions.IterCb(func(uuid string, v interface{}) bool { - s := v.(*Session) + h.sessions.IterCb(func(uuid string, s *Session) bool { s.writeMessage(m) return true }) } else { - h.sessions.IterCb(func(uuid string, v interface{}) bool { - s := v.(*Session) + h.sessions.IterCb(func(uuid string, s *Session) bool { if m.filter(s) { s.writeMessage(m) } @@ -60,8 +57,7 @@ loop: case m := <-h.exit: var keys []string h.open = false - h.sessions.IterCb(func(uuid string, v interface{}) bool { - s := v.(*Session) + h.sessions.IterCb(func(uuid string, s *Session) bool { s.writeMessage(m) s.Close() keys = append(keys, uuid) diff --git a/utils/melody/melody.go b/utils/melody/melody.go index 247e21b..cb311e8 100644 --- a/utils/melody/melody.go +++ b/utils/melody/melody.go @@ -301,27 +301,14 @@ func (m *Melody) SendMultiple(msg []byte, list []string) error { // GetSessionByUUID returns the session with specified uuid. func (m *Melody) GetSessionByUUID(uuid string) (*Session, bool) { - val, ok := m.hub.sessions.Get(uuid) - if !ok { - return nil, false - } - s, ok := val.(*Session) - if !ok { - m.hub.sessions.Remove(uuid) - } - return s, ok + return m.hub.sessions.Get(uuid) } // IterSessions iterates all sessions. func (m *Melody) IterSessions(fn func(uuid string, s *Session) bool) { var invalid []string - m.hub.sessions.IterCb(func(uuid string, v interface{}) bool { - if s, ok := v.(*Session); !ok { - invalid = append(invalid, uuid) - return true - } else { - return fn(uuid, s) - } + m.hub.sessions.IterCb(func(uuid string, s *Session) bool { + return fn(uuid, s) }) m.hub.sessions.Remove(invalid...) } diff --git a/web/src/pages/overview.jsx b/web/src/pages/overview.jsx index 1d88f4e..2789fb1 100644 --- a/web/src/pages/overview.jsx +++ b/web/src/pages/overview.jsx @@ -394,7 +394,7 @@ function overview(props) { <> { URL.revokeObjectURL(screenBlob);