optimize: update concurrent_map

This commit is contained in:
XZB-1248
2023-05-21 23:11:09 +08:00
parent 76bfa6c1c6
commit bb68f494fd
15 changed files with 191 additions and 215 deletions

2
.gitignore vendored
View File

@@ -3,6 +3,6 @@
/tools /tools
/logs /logs
/.idea /.idea
/Config.json /config.json
dist/ dist/
node_modules/ node_modules/

View File

@@ -62,7 +62,7 @@ const imageQuality = 70
var lock = &sync.Mutex{} var lock = &sync.Mutex{}
var working = false var working = false
var sessions = cmap.New() var sessions = cmap.New[*session]()
var prevDesktop *image.RGBA var prevDesktop *image.RGBA
var displayBounds image.Rectangle var displayBounds image.Rectangle
var errNoImage = errors.New(`DESKTOP.NO_IMAGE_YET`) var errNoImage = errors.New(`DESKTOP.NO_IMAGE_YET`)
@@ -127,8 +127,7 @@ func worker() {
} }
func sendImageDiff(diff []*[]byte) { func sendImageDiff(diff []*[]byte) {
sessions.IterCb(func(uuid string, t any) bool { sessions.IterCb(func(uuid string, desktop *session) bool {
desktop := t.(*session)
desktop.lock.Lock() desktop.lock.Lock()
if !desktop.escape { if !desktop.escape {
if len(desktop.channel) >= frameBuffer { if len(desktop.channel) >= frameBuffer {
@@ -146,9 +145,8 @@ func sendImageDiff(diff []*[]byte) {
func quitAllDesktop(info string) { func quitAllDesktop(info string) {
keys := make([]string, 0) keys := make([]string, 0)
sessions.IterCb(func(uuid string, t any) bool { sessions.IterCb(func(uuid string, desktop *session) bool {
keys = append(keys, uuid) keys = append(keys, uuid)
desktop := t.(*session)
desktop.escape = true desktop.escape = true
desktop.channel <- message{t: 1, info: info} desktop.channel <- message{t: 1, info: info}
return true return true
@@ -346,24 +344,23 @@ func PingDesktop(pack modules.Packet) {
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := sessions.Get(uuid); ok { desktop, ok := sessions.Get(uuid)
desktop = val.(*session) if !ok {
desktop.lastPack = utils.Unix return
} }
desktop.lastPack = utils.Unix
} }
func KillDesktop(pack modules.Packet) { func KillDesktop(pack modules.Packet) {
var uuid string var uuid string
var desktop *session
if val, ok := pack.GetData(`desktop`, reflect.String); !ok { if val, ok := pack.GetData(`desktop`, reflect.String); !ok {
return return
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := sessions.Get(uuid); !ok { desktop, ok := sessions.Get(uuid)
if !ok {
return return
} else {
desktop = val.(*session)
} }
sessions.Remove(uuid) sessions.Remove(uuid)
data, _ := utils.JSON.Marshal(modules.Packet{Act: `DESKTOP_QUIT`, Msg: `${i18n|DESKTOP.SESSION_CLOSED}`}) data, _ := utils.JSON.Marshal(modules.Packet{Act: `DESKTOP_QUIT`, Msg: `${i18n|DESKTOP.SESSION_CLOSED}`})
@@ -383,10 +380,9 @@ func GetDesktop(pack modules.Packet) {
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := sessions.Get(uuid); !ok { desktop, ok := sessions.Get(uuid)
if !ok {
return return
} else {
desktop = val.(*session)
} }
if !desktop.escape { if !desktop.escape {
lock.Lock() lock.Lock()
@@ -450,8 +446,7 @@ func healthCheck() {
timestamp := now.Unix() timestamp := now.Unix()
// stores sessions to be disconnected // stores sessions to be disconnected
keys := make([]string, 0) keys := make([]string, 0)
sessions.IterCb(func(uuid string, t any) bool { sessions.IterCb(func(uuid string, desktop *session) bool {
desktop := t.(*session)
if timestamp-desktop.lastPack > MaxInterval { if timestamp-desktop.lastPack > MaxInterval {
keys = append(keys, uuid) keys = append(keys, uuid)
} }

View File

@@ -1,11 +1,9 @@
package terminal package terminal
import ( import (
"Spark/utils/cmap"
"errors" "errors"
) )
var terminals = cmap.New()
var ( var (
errDataNotFound = errors.New(`no input found in packet`) errDataNotFound = errors.New(`no input found in packet`)
errDataInvalid = errors.New(`can not parse data in packet`) errDataInvalid = errors.New(`can not parse data in packet`)

View File

@@ -6,6 +6,7 @@ import (
"Spark/client/common" "Spark/client/common"
"Spark/modules" "Spark/modules"
"Spark/utils" "Spark/utils"
"Spark/utils/cmap"
"encoding/hex" "encoding/hex"
"github.com/creack/pty" "github.com/creack/pty"
"os" "os"
@@ -23,6 +24,7 @@ type terminal struct {
cmd *exec.Cmd cmd *exec.Cmd
} }
var terminals = cmap.New[*terminal]()
var defaultShell = `` var defaultShell = ``
func init() { func init() {
@@ -88,11 +90,8 @@ func InitTerminal(pack modules.Packet) error {
} }
func InputRawTerminal(input []byte, uuid string) { func InputRawTerminal(input []byte, uuid string) {
var session *terminal session, ok := terminals.Get(uuid)
if !ok {
if val, ok := terminals.Get(uuid); ok {
session = val.(*terminal)
} else {
return return
} }
session.pty.Write(input) session.pty.Write(input)
@@ -165,10 +164,9 @@ func KillTerminal(pack modules.Packet) {
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := terminals.Get(uuid); !ok { session, ok := terminals.Get(uuid)
if !ok {
return return
} else {
session = val.(*terminal)
} }
terminals.Remove(uuid) terminals.Remove(uuid)
data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`}) data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`})
@@ -180,18 +178,16 @@ func KillTerminal(pack modules.Packet) {
func PingTerminal(pack modules.Packet) { func PingTerminal(pack modules.Packet) {
var termUUID string var termUUID string
var termSession *terminal
if val, ok := pack.GetData(`terminal`, reflect.String); !ok { if val, ok := pack.GetData(`terminal`, reflect.String); !ok {
return return
} else { } else {
termUUID = val.(string) termUUID = val.(string)
} }
if val, ok := terminals.Get(termUUID); !ok { session, ok := terminals.Get(termUUID)
if !ok {
return return
} else {
termSession = val.(*terminal)
termSession.lastPack = utils.Unix
} }
session.lastPack = utils.Unix
} }
func doKillTerminal(terminal *terminal) { func doKillTerminal(terminal *terminal) {
@@ -234,11 +230,10 @@ func healthCheck() {
timestamp := now.Unix() timestamp := now.Unix()
// stores sessions to be disconnected // stores sessions to be disconnected
queue := make([]string, 0) queue := make([]string, 0)
terminals.IterCb(func(uuid string, t any) bool { terminals.IterCb(func(uuid string, session *terminal) bool {
termSession := t.(*terminal) if timestamp-session.lastPack > MaxInterval {
if timestamp-termSession.lastPack > MaxInterval {
queue = append(queue, uuid) queue = append(queue, uuid)
doKillTerminal(termSession) doKillTerminal(session)
} }
return true return true
}) })

View File

@@ -4,6 +4,7 @@ import (
"Spark/client/common" "Spark/client/common"
"Spark/modules" "Spark/modules"
"Spark/utils" "Spark/utils"
"Spark/utils/cmap"
"encoding/hex" "encoding/hex"
"io" "io"
"os/exec" "os/exec"
@@ -23,6 +24,7 @@ type terminal struct {
stdin *io.WriteCloser stdin *io.WriteCloser
} }
var terminals = cmap.New[*terminal]()
var defaultCmd = `` var defaultCmd = ``
func init() { func init() {
@@ -108,11 +110,8 @@ func InitTerminal(pack modules.Packet) error {
} }
func InputRawTerminal(input []byte, uuid string) { func InputRawTerminal(input []byte, uuid string) {
var session *terminal session, ok := terminals.Get(uuid)
if !ok {
if val, ok := terminals.Get(uuid); ok {
session = val.(*terminal)
} else {
return return
} }
(*session.stdin).Write(input) (*session.stdin).Write(input)
@@ -152,16 +151,14 @@ func ResizeTerminal(pack modules.Packet) error {
func KillTerminal(pack modules.Packet) { func KillTerminal(pack modules.Packet) {
var uuid string var uuid string
var session *terminal
if val, ok := pack.GetData(`terminal`, reflect.String); !ok { if val, ok := pack.GetData(`terminal`, reflect.String); !ok {
return return
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := terminals.Get(uuid); !ok { session, ok := terminals.Get(uuid)
if !ok {
return return
} else {
session = val.(*terminal)
} }
terminals.Remove(uuid) terminals.Remove(uuid)
data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`}) data, _ := utils.JSON.Marshal(modules.Packet{Act: `TERMINAL_QUIT`, Msg: `${i18n|TERMINAL.SESSION_CLOSED}`})
@@ -179,12 +176,11 @@ func PingTerminal(pack modules.Packet) {
} else { } else {
uuid = val.(string) uuid = val.(string)
} }
if val, ok := terminals.Get(uuid); !ok { session, ok := terminals.Get(uuid)
if !ok {
return return
} else {
session = val.(*terminal)
session.lastPack = utils.Unix
} }
session.lastPack = utils.Unix
} }
func doKillTerminal(terminal *terminal) { func doKillTerminal(terminal *terminal) {
@@ -221,11 +217,10 @@ func healthCheck() {
timestamp := now.Unix() timestamp := now.Unix()
// stores sessions to be disconnected // stores sessions to be disconnected
keys := make([]string, 0) keys := make([]string, 0)
terminals.IterCb(func(uuid string, t any) bool { terminals.IterCb(func(uuid string, session *terminal) bool {
termSession := t.(*terminal) if timestamp-session.lastPack > MaxInterval {
if timestamp-termSession.lastPack > MaxInterval {
keys = append(keys, uuid) keys = append(keys, uuid)
doKillTerminal(termSession) doKillTerminal(session)
} }
return true return true
}) })

View File

@@ -17,7 +17,7 @@ import (
const MaxMessageSize = (2 << 15) + 1024 const MaxMessageSize = (2 << 15) + 1024
var Melody = melody.New() var Melody = melody.New()
var Devices = cmap.New() var Devices = cmap.New[*modules.Device]()
func SendPackByUUID(pack modules.Packet, uuid string) bool { func SendPackByUUID(pack modules.Packet, uuid string) bool {
session, ok := Melody.GetSessionByUUID(uuid) session, ok := Melody.GetSessionByUUID(uuid)
@@ -164,8 +164,7 @@ func CheckDevice(deviceID, connUUID string) (string, bool) {
} }
} else { } else {
tempConnUUID := `` tempConnUUID := ``
Devices.IterCb(func(uuid string, v any) bool { Devices.IterCb(func(uuid string, device *modules.Device) bool {
device := v.(*modules.Device)
if device.ID == deviceID { if device.ID == deviceID {
tempConnUUID = uuid tempConnUUID = uuid
return false return false

View File

@@ -15,7 +15,7 @@ type event struct {
remove chan bool remove chan bool
} }
var events = cmap.New() var events = cmap.New[*event]()
// CallEvent tries to call the callback with the given uuid // CallEvent tries to call the callback with the given uuid
// after that, it will notify the caller via the channel // after that, it will notify the caller via the channel
@@ -23,11 +23,10 @@ func CallEvent(pack modules.Packet, session *melody.Session) {
if len(pack.Event) == 0 { if len(pack.Event) == 0 {
return return
} }
v, ok := events.Get(pack.Event) ev, ok := events.Get(pack.Event)
if !ok { if !ok {
return return
} }
ev := v.(*event)
if session != nil && session.UUID != ev.connection { if session != nil && session.UUID != ev.connection {
return return
} }
@@ -76,12 +75,11 @@ func AddEvent(fn EventCallback, connUUID, trigger string) {
// RemoveEvent deletes the event with the given event trigger. // RemoveEvent deletes the event with the given event trigger.
// The ok will be returned to caller if the event is temp (only once). // The ok will be returned to caller if the event is temp (only once).
func RemoveEvent(trigger string, ok ...bool) { func RemoveEvent(trigger string, ok ...bool) {
v, found := events.Get(trigger) ev, found := events.Get(trigger)
if !found { if !found {
return return
} }
events.Remove(trigger) events.Remove(trigger)
ev := v.(*event)
if ev.remove != nil { if ev.remove != nil {
if len(ok) > 0 { if len(ok) > 0 {
ev.remove <- ok[0] ev.remove <- ok[0]
@@ -89,7 +87,6 @@ func RemoveEvent(trigger string, ok ...bool) {
ev.remove <- false ev.remove <- false
} }
} }
v = nil
ev = nil ev = nil
} }

View File

@@ -1,7 +1,6 @@
package common package common
import ( import (
"Spark/modules"
"Spark/server/config" "Spark/server/config"
"Spark/utils" "Spark/utils"
"Spark/utils/melody" "Spark/utils/melody"
@@ -81,9 +80,8 @@ func getLog(ctx any, event, status, msg string, args map[string]any) string {
} }
} }
if targetInfo { if targetInfo {
val, ok := Devices.Get(connUUID) device, ok := Devices.Get(connUUID)
if ok { if ok {
device := val.(*modules.Device)
args[`target`] = map[string]any{ args[`target`] = map[string]any{
`name`: device.Hostname, `name`: device.Hostname,
`ip`: device.WAN, `ip`: device.WAN,

View File

@@ -28,15 +28,14 @@ type Bridge struct {
OnFinish func(bridge *Bridge) OnFinish func(bridge *Bridge)
} }
var bridges = cmap.New() var bridges = cmap.New[*Bridge]()
func init() { func init() {
go func() { go func() {
for now := range time.NewTicker(15 * time.Second).C { for now := range time.NewTicker(15 * time.Second).C {
var queue []string var queue []string
timestamp := now.Unix() timestamp := now.Unix()
bridges.IterCb(func(k string, v any) bool { bridges.IterCb(func(k string, b *Bridge) bool {
b := v.(*Bridge)
if timestamp-b.creation > 60 && !b.using { if timestamp-b.creation > 60 && !b.using {
b.lock.Lock() b.lock.Lock()
if b.Src != nil && b.Src.Request.Body != nil { if b.Src != nil && b.Src.Request.Body != nil {
@@ -63,12 +62,12 @@ func CheckBridge(ctx *gin.Context) *Bridge {
ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_PARAMETER}`}) ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_PARAMETER}`})
return nil return nil
} }
val, ok := bridges.Get(form.Bridge) b, ok := bridges.Get(form.Bridge)
if !ok { if !ok {
ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_BRIDGE_ID}`}) ctx.AbortWithStatusJSON(http.StatusBadRequest, modules.Packet{Code: -1, Msg: `${i18n|COMMON.INVALID_BRIDGE_ID}`})
return nil return nil
} }
return val.(*Bridge) return b
} }
func BridgePush(ctx *gin.Context) { func BridgePush(ctx *gin.Context) {
@@ -218,12 +217,11 @@ func AddBridgeWithDst(ext any, uuid string, Dst *gin.Context) *Bridge {
} }
func RemoveBridge(uuid string) { func RemoveBridge(uuid string) {
val, ok := bridges.Get(uuid) b, ok := bridges.Get(uuid)
if !ok { if !ok {
return return
} }
bridges.Remove(uuid) bridges.Remove(uuid)
b := val.(*Bridge)
if b.Src != nil && b.Src.Request.Body != nil { if b.Src != nil && b.Src.Request.Body != nil {
b.Src.Request.Body.Close() b.Src.Request.Body.Close()
} }

View File

@@ -70,8 +70,7 @@ func OnDevicePack(data []byte, session *melody.Session) error {
// If so, then find the session and let client quit. // If so, then find the session and let client quit.
// This will keep only one connection remained per device. // This will keep only one connection remained per device.
exSession := `` exSession := ``
common.Devices.IterCb(func(uuid string, v any) bool { common.Devices.IterCb(func(uuid string, device *modules.Device) bool {
device := v.(*modules.Device)
if device.ID == pack.Device.ID { if device.ID == pack.Device.ID {
exSession = uuid exSession = uuid
target, ok := common.Melody.GetSessionByUUID(uuid) target, ok := common.Melody.GetSessionByUUID(uuid)
@@ -94,14 +93,13 @@ func OnDevicePack(data []byte, session *melody.Session) error {
}, },
}) })
} else { } else {
val, ok := common.Devices.Get(session.UUID) device, ok := common.Devices.Get(session.UUID)
if ok { if ok {
deviceInfo := val.(*modules.Device) device.CPU = pack.Device.CPU
deviceInfo.CPU = pack.Device.CPU device.RAM = pack.Device.RAM
deviceInfo.RAM = pack.Device.RAM device.Net = pack.Device.Net
deviceInfo.Net = pack.Device.Net device.Disk = pack.Device.Disk
deviceInfo.Disk = pack.Device.Disk device.Uptime = pack.Device.Uptime
deviceInfo.Uptime = pack.Device.Uptime
} }
} }
common.SendPack(modules.Packet{Code: 0}, session) common.SendPack(modules.Packet{Code: 0}, session)
@@ -268,8 +266,7 @@ func ExecDeviceCmd(ctx *gin.Context) {
// GetDevices will return all info about all clients. // GetDevices will return all info about all clients.
func GetDevices(ctx *gin.Context) { func GetDevices(ctx *gin.Context) {
devices := map[string]any{} devices := map[string]any{}
common.Devices.IterCb(func(uuid string, v any) bool { common.Devices.IterCb(func(uuid string, device *modules.Device) bool {
device := v.(*modules.Device)
devices[uuid] = *device devices[uuid] = *device
return true return true
}) })

View File

@@ -32,7 +32,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var blocked = cmap.New() var blocked = cmap.New[int64]()
var lastRequest = time.Now().Unix() var lastRequest = time.Now().Unix()
func main() { func main() {
@@ -212,14 +212,13 @@ func wsOnMessageBinary(session *melody.Session, data []byte) {
} }
func wsOnDisconnect(session *melody.Session) { func wsOnDisconnect(session *melody.Session) {
if val, ok := common.Devices.Get(session.UUID); ok { if device, ok := common.Devices.Get(session.UUID); ok {
deviceInfo := val.(*modules.Device) terminal.CloseSessionsByDevice(device.ID)
terminal.CloseSessionsByDevice(deviceInfo.ID) desktop.CloseSessionsByDevice(device.ID)
desktop.CloseSessionsByDevice(deviceInfo.ID)
common.Info(nil, `CLIENT_OFFLINE`, ``, ``, map[string]any{ common.Info(nil, `CLIENT_OFFLINE`, ``, ``, map[string]any{
`device`: map[string]any{ `device`: map[string]any{
`name`: deviceInfo.Hostname, `name`: device.Hostname,
`ip`: deviceInfo.WAN, `ip`: device.WAN,
}, },
}) })
} else { } else {
@@ -289,10 +288,9 @@ func pingDevice(s *melody.Session) {
trigger := utils.GetStrUUID() trigger := utils.GetStrUUID()
common.SendPack(modules.Packet{Act: `PING`, Event: trigger}, s) common.SendPack(modules.Packet{Act: `PING`, Event: trigger}, s)
common.AddEventOnce(func(packet modules.Packet, session *melody.Session) { common.AddEventOnce(func(packet modules.Packet, session *melody.Session) {
val, ok := common.Devices.Get(s.UUID) device, ok := common.Devices.Get(s.UUID)
if ok { if ok {
deviceInfo := val.(*modules.Device) device.Latency = uint(time.Now().UnixMilli()-t) / 2
deviceInfo.Latency = uint(time.Now().UnixMilli()-t) / 2
} }
}, s.UUID, trigger, 3*time.Second) }, s.UUID, trigger, 3*time.Second)
} }
@@ -300,12 +298,12 @@ func pingDevice(s *melody.Session) {
func checkAuth() gin.HandlerFunc { func checkAuth() gin.HandlerFunc {
// Token as key and update timestamp as value. // Token as key and update timestamp as value.
// Stores authenticated tokens. // Stores authenticated tokens.
tokens := cmap.New() tokens := cmap.New[int64]()
go func() { go func() {
for now := range time.NewTicker(60 * time.Second).C { for now := range time.NewTicker(60 * time.Second).C {
var queue []string var queue []string
tokens.IterCb(func(key string, v any) bool { tokens.IterCb(func(key string, t int64) bool {
if now.Unix()-v.(int64) > 1800 { if now.Unix()-t > 1800 {
queue = append(queue, key) queue = append(queue, key)
} }
return true return true
@@ -313,8 +311,8 @@ func checkAuth() gin.HandlerFunc {
tokens.Remove(queue...) tokens.Remove(queue...)
queue = nil queue = nil
blocked.IterCb(func(addr string, v any) bool { blocked.IterCb(func(addr string, t int64) bool {
if now.Unix() > v.(int64) { if now.Unix() > t {
queue = append(queue, addr) queue = append(queue, addr)
} }
return true return true
@@ -347,7 +345,7 @@ func checkAuth() gin.HandlerFunc {
if !passed { if !passed {
addr := common.GetRealIP(ctx) addr := common.GetRealIP(ctx)
if expire, ok := blocked.Get(addr); ok { if expire, ok := blocked.Get(addr); ok {
if now < expire.(int64) { if now < expire {
ctx.AbortWithStatusJSON(http.StatusTooManyRequests, modules.Packet{Code: 1}) ctx.AbortWithStatusJSON(http.StatusTooManyRequests, modules.Packet{Code: 1})
return return
} }

View File

@@ -2,36 +2,62 @@ package cmap
import ( import (
"encoding/json" "encoding/json"
"fmt"
"sync" "sync"
) )
const SHARD_COUNT = 32 var SHARD_COUNT = 32
// ConcurrentMap is a "thread" safe map of type string:Anything. type Stringer interface {
fmt.Stringer
comparable
}
// A "thread" safe map of type string:Anything.
// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
type ConcurrentMap []*ConcurrentMapShared type ConcurrentMap[K comparable, V any] struct {
shards []*ConcurrentMapShared[K, V]
sharding func(key K) uint32
}
// ConcurrentMapShared is a "thread" safe string to anything map. // A "thread" safe string to anything map.
type ConcurrentMapShared struct { type ConcurrentMapShared[K comparable, V any] struct {
items map[string]interface{} items map[K]V
sync.RWMutex // Read Write mutex, guards access to internal map. sync.RWMutex // Read Write mutex, guards access to internal map.
} }
// New creates a new concurrent map. func create[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
func New() ConcurrentMap { m := ConcurrentMap[K, V]{
m := make(ConcurrentMap, SHARD_COUNT) sharding: sharding,
shards: make([]*ConcurrentMapShared[K, V], SHARD_COUNT),
}
for i := 0; i < SHARD_COUNT; i++ { for i := 0; i < SHARD_COUNT; i++ {
m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} m.shards[i] = &ConcurrentMapShared[K, V]{items: make(map[K]V)}
} }
return m return m
} }
// GetShard returns shard under given key // Creates a new concurrent map.
func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { func New[V any]() ConcurrentMap[string, V] {
return m[uint(fnv32(key))%uint(SHARD_COUNT)] return create[string, V](fnv32)
} }
func (m ConcurrentMap) MSet(data map[string]interface{}) { // Creates a new concurrent map.
func NewStringer[K Stringer, V any]() ConcurrentMap[K, V] {
return create[K, V](strfnv32[K])
}
// Creates a new concurrent map.
func NewWithCustomShardingFunction[K comparable, V any](sharding func(key K) uint32) ConcurrentMap[K, V] {
return create[K, V](sharding)
}
// GetShard returns shard under given key
func (m ConcurrentMap[K, V]) GetShard(key K) *ConcurrentMapShared[K, V] {
return m.shards[uint(m.sharding(key))%uint(SHARD_COUNT)]
}
func (m ConcurrentMap[K, V]) MSet(data map[K]V) {
for key, value := range data { for key, value := range data {
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
@@ -40,8 +66,8 @@ func (m ConcurrentMap) MSet(data map[string]interface{}) {
} }
} }
// Set sets the given value under the specified key. // Sets the given value under the specified key.
func (m ConcurrentMap) Set(key string, value interface{}) { func (m ConcurrentMap[K, V]) Set(key K, value V) {
// Get map shard. // Get map shard.
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
@@ -49,14 +75,14 @@ func (m ConcurrentMap) Set(key string, value interface{}) {
shard.Unlock() shard.Unlock()
} }
// UpsertCb is callback to return new element to be inserted into the map // Callback to return new element to be inserted into the map
// It is called while lock is held, therefore it MUST NOT // It is called while lock is held, therefore it MUST NOT
// try to access other keys in same map, as it can lead to deadlock since // try to access other keys in same map, as it can lead to deadlock since
// Go sync.RWLock is not reentrant // Go sync.RWLock is not reentrant
type UpsertCb func(exist bool, valueInMap interface{}, newValue interface{}) interface{} type UpsertCb[V any] func(exist bool, valueInMap V, newValue V) V
// Upsert means Insert or Update - updates existing element or inserts a new one using UpsertCb // Insert or Update - updates existing element or inserts a new one using UpsertCb
func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res interface{}) { func (m ConcurrentMap[K, V]) Upsert(key K, value V, cb UpsertCb[V]) (res V) {
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
v, ok := shard.items[key] v, ok := shard.items[key]
@@ -66,8 +92,8 @@ func (m ConcurrentMap) Upsert(key string, value interface{}, cb UpsertCb) (res i
return res return res
} }
// SetIfAbsent sets the given value under the specified key if no value was associated with it. // Sets the given value under the specified key if no value was associated with it.
func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool { func (m ConcurrentMap[K, V]) SetIfAbsent(key K, value V) bool {
// Get map shard. // Get map shard.
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
@@ -80,7 +106,7 @@ func (m ConcurrentMap) SetIfAbsent(key string, value interface{}) bool {
} }
// Get retrieves an element from map under given key. // Get retrieves an element from map under given key.
func (m ConcurrentMap) Get(key string) (interface{}, bool) { func (m ConcurrentMap[K, V]) Get(key K) (V, bool) {
// Get shard // Get shard
shard := m.GetShard(key) shard := m.GetShard(key)
shard.RLock() shard.RLock()
@@ -91,10 +117,10 @@ func (m ConcurrentMap) Get(key string) (interface{}, bool) {
} }
// Count returns the number of elements within the map. // Count returns the number of elements within the map.
func (m ConcurrentMap) Count() int { func (m ConcurrentMap[K, V]) Count() int {
count := 0 count := 0
for i := 0; i < SHARD_COUNT; i++ { for i := 0; i < SHARD_COUNT; i++ {
shard := m[i] shard := m.shards[i]
shard.RLock() shard.RLock()
count += len(shard.items) count += len(shard.items)
shard.RUnlock() shard.RUnlock()
@@ -102,8 +128,8 @@ func (m ConcurrentMap) Count() int {
return count return count
} }
// Has looks up an item under specified key // Looks up an item under specified key
func (m ConcurrentMap) Has(key string) bool { func (m ConcurrentMap[K, V]) Has(key K) bool {
// Get shard // Get shard
shard := m.GetShard(key) shard := m.GetShard(key)
shard.RLock() shard.RLock()
@@ -114,9 +140,9 @@ func (m ConcurrentMap) Has(key string) bool {
} }
// Remove removes an element from the map. // Remove removes an element from the map.
func (m ConcurrentMap) Remove(key ...string) { func (m ConcurrentMap[K, V]) Remove(keys ...K) {
// Try to get shard. // Try to get shard.
for _, k := range key { for _, k := range keys {
shard := m.GetShard(k) shard := m.GetShard(k)
shard.Lock() shard.Lock()
delete(shard.items, k) delete(shard.items, k)
@@ -126,12 +152,12 @@ func (m ConcurrentMap) Remove(key ...string) {
// RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held // RemoveCb is a callback executed in a map.RemoveCb() call, while Lock is held
// If returns true, the element will be removed from the map // If returns true, the element will be removed from the map
type RemoveCb func(key string, v interface{}, exists bool) bool type RemoveCb[K any, V any] func(key K, v V, exists bool) bool
// RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params // RemoveCb locks the shard containing the key, retrieves its current value and calls the callback with those params
// If callback returns true and element exists, it will remove it from the map // If callback returns true and element exists, it will remove it from the map
// Returns the value returned by the callback (even if element was not present in the map) // Returns the value returned by the callback (even if element was not present in the map)
func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool { func (m ConcurrentMap[K, V]) RemoveCb(key K, cb RemoveCb[K, V]) bool {
// Try to get shard. // Try to get shard.
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
@@ -145,7 +171,7 @@ func (m ConcurrentMap) RemoveCb(key string, cb RemoveCb) bool {
} }
// Pop removes an element from the map and returns it // Pop removes an element from the map and returns it
func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) { func (m ConcurrentMap[K, V]) Pop(key K) (v V, exists bool) {
// Try to get shard. // Try to get shard.
shard := m.GetShard(key) shard := m.GetShard(key)
shard.Lock() shard.Lock()
@@ -156,66 +182,66 @@ func (m ConcurrentMap) Pop(key string) (v interface{}, exists bool) {
} }
// IsEmpty checks if map is empty. // IsEmpty checks if map is empty.
func (m ConcurrentMap) IsEmpty() bool { func (m ConcurrentMap[K, V]) IsEmpty() bool {
return m.Count() == 0 return m.Count() == 0
} }
// Tuple is used by the Iter & IterBuffered functions to wrap two variables together over a channel, // Used by the Iter & IterBuffered functions to wrap two variables together over a channel,
type Tuple struct { type Tuple[K comparable, V any] struct {
Key string Key K
Val interface{} Val V
} }
// Iter returns an iterator which could be used in a for range loop. // Iter returns an iterator which could be used in a for range loop.
// //
// Deprecated: using IterBuffered() will get a better performance // Deprecated: using IterBuffered() will get a better performence
func (m ConcurrentMap) Iter() <-chan Tuple { func (m ConcurrentMap[K, V]) Iter() <-chan Tuple[K, V] {
chans := snapshot(m) chans := snapshot(m)
ch := make(chan Tuple) ch := make(chan Tuple[K, V])
go fanIn(chans, ch) go fanIn(chans, ch)
return ch return ch
} }
// IterBuffered returns a buffered iterator which could be used in a for range loop. // IterBuffered returns a buffered iterator which could be used in a for range loop.
func (m ConcurrentMap) IterBuffered() <-chan Tuple { func (m ConcurrentMap[K, V]) IterBuffered() <-chan Tuple[K, V] {
chans := snapshot(m) chans := snapshot(m)
total := 0 total := 0
for _, c := range chans { for _, c := range chans {
total += cap(c) total += cap(c)
} }
ch := make(chan Tuple, total) ch := make(chan Tuple[K, V], total)
go fanIn(chans, ch) go fanIn(chans, ch)
return ch return ch
} }
// Clear removes all items from map. // Clear removes all items from map.
func (m ConcurrentMap) Clear() { func (m ConcurrentMap[K, V]) Clear() {
for item := range m.IterBuffered() { for item := range m.IterBuffered() {
m.Remove(item.Key) m.Remove(item.Key)
} }
} }
// Returns an array of channels that contains elements in each shard, // Returns a array of channels that contains elements in each shard,
// which likely takes a snapshot of `m`. // which likely takes a snapshot of `m`.
// It returns once the size of each buffered channel is determined, // It returns once the size of each buffered channel is determined,
// before all the channels are populated using goroutines. // before all the channels are populated using goroutines.
func snapshot(m ConcurrentMap) (chans []chan Tuple) { func snapshot[K comparable, V any](m ConcurrentMap[K, V]) (chans []chan Tuple[K, V]) {
//When you access map items before initializing. //When you access map items before initializing.
if len(m) == 0 { if len(m.shards) == 0 {
panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`) panic(`cmap.ConcurrentMap is not initialized. Should run New() before usage.`)
} }
chans = make([]chan Tuple, SHARD_COUNT) chans = make([]chan Tuple[K, V], SHARD_COUNT)
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT) wg.Add(SHARD_COUNT)
// Foreach shard. // Foreach shard.
for index, shard := range m { for index, shard := range m.shards {
go func(index int, shard *ConcurrentMapShared) { go func(index int, shard *ConcurrentMapShared[K, V]) {
// Foreach key, value pair. // Foreach key, value pair.
shard.RLock() shard.RLock()
chans[index] = make(chan Tuple, len(shard.items)) chans[index] = make(chan Tuple[K, V], len(shard.items))
wg.Done() wg.Done()
for key, val := range shard.items { for key, val := range shard.items {
chans[index] <- Tuple{key, val} chans[index] <- Tuple[K, V]{key, val}
} }
shard.RUnlock() shard.RUnlock()
close(chans[index]) close(chans[index])
@@ -226,11 +252,11 @@ func snapshot(m ConcurrentMap) (chans []chan Tuple) {
} }
// fanIn reads elements from channels `chans` into channel `out` // fanIn reads elements from channels `chans` into channel `out`
func fanIn(chans []chan Tuple, out chan Tuple) { func fanIn[K comparable, V any](chans []chan Tuple[K, V], out chan Tuple[K, V]) {
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(len(chans)) wg.Add(len(chans))
for _, ch := range chans { for _, ch := range chans {
go func(ch chan Tuple) { go func(ch chan Tuple[K, V]) {
for t := range ch { for t := range ch {
out <- t out <- t
} }
@@ -241,9 +267,9 @@ func fanIn(chans []chan Tuple, out chan Tuple) {
close(out) close(out)
} }
// Items returns all items as map[string]interface{} // Items returns all items as map[string]V
func (m ConcurrentMap) Items() map[string]interface{} { func (m ConcurrentMap[K, V]) Items() map[K]V {
tmp := make(map[string]interface{}) tmp := make(map[K]V)
// Insert items to temporary map. // Insert items to temporary map.
for item := range m.IterBuffered() { for item := range m.IterBuffered() {
@@ -253,18 +279,18 @@ func (m ConcurrentMap) Items() map[string]interface{} {
return tmp return tmp
} }
// IterCb is iterator callback, called for every key,value found in // Iterator callbacalled for every key,value found in
// maps. RLock is held for all calls for a given shard // maps. RLock is held for all calls for a given shard
// therefore callback sess consistent view of a shard, // therefore callback sess consistent view of a shard,
// but not across the shards // but not across the shards
type IterCb func(key string, v interface{}) bool type IterCb[K comparable, V any] func(key K, v V) bool
// IterCb callback based iterator, the cheapest way to read // Callback based iterator, cheapest way to read
// all elements in a map. // all elements in a map.
func (m ConcurrentMap) IterCb(fn IterCb) { func (m ConcurrentMap[K, V]) IterCb(fn IterCb[K, V]) {
escape := false escape := false
for idx := range m { for idx := range m.shards {
shard := (m)[idx] shard := (m.shards)[idx]
shard.RLock() shard.RLock()
for key, value := range shard.items { for key, value := range shard.items {
if !fn(key, value) { if !fn(key, value) {
@@ -280,15 +306,15 @@ func (m ConcurrentMap) IterCb(fn IterCb) {
} }
// Keys returns all keys as []string // Keys returns all keys as []string
func (m ConcurrentMap) Keys() []string { func (m ConcurrentMap[K, V]) Keys() []K {
count := m.Count() count := m.Count()
ch := make(chan string, count) ch := make(chan K, count)
go func() { go func() {
// Foreach shard. // Foreach shard.
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
wg.Add(SHARD_COUNT) wg.Add(SHARD_COUNT)
for _, shard := range m { for _, shard := range m.shards {
go func(shard *ConcurrentMapShared) { go func(shard *ConcurrentMapShared[K, V]) {
// Foreach key, value pair. // Foreach key, value pair.
shard.RLock() shard.RLock()
for key := range shard.items { for key := range shard.items {
@@ -303,17 +329,17 @@ func (m ConcurrentMap) Keys() []string {
}() }()
// Generate keys // Generate keys
keys := make([]string, 0, count) keys := make([]K, 0, count)
for k := range ch { for k := range ch {
keys = append(keys, k) keys = append(keys, k)
} }
return keys return keys
} }
//MarshalJSON reviles ConcurrentMap "private" variables to json marshal. // Reviles ConcurrentMap "private" variables to json marshal.
func (m ConcurrentMap) MarshalJSON() ([]byte, error) { func (m ConcurrentMap[K, V]) MarshalJSON() ([]byte, error) {
// Create a temporary map, which will hold all item spread across shards. // Create a temporary map, which will hold all item spread across shards.
tmp := make(map[string]interface{}) tmp := make(map[K]V)
// Insert items to temporary map. // Insert items to temporary map.
for item := range m.IterBuffered() { for item := range m.IterBuffered() {
@@ -321,6 +347,9 @@ func (m ConcurrentMap) MarshalJSON() ([]byte, error) {
} }
return json.Marshal(tmp) return json.Marshal(tmp)
} }
func strfnv32[K fmt.Stringer](key K) uint32 {
return fnv32(key.String())
}
func fnv32(key string) uint32 { func fnv32(key string) uint32 {
hash := uint32(2166136261) hash := uint32(2166136261)
@@ -333,24 +362,18 @@ func fnv32(key string) uint32 {
return hash return hash
} }
// Concurrent map uses Interface{} as its value, therefore JSON Unmarshal // Reverse process of Marshal.
// probably won't know which to type to unmarshal into, in such case func (m *ConcurrentMap[K, V]) UnmarshalJSON(b []byte) (err error) {
// we'll end up with a value of type map[string]interface{}, In most cases this isn't tmp := make(map[K]V)
// out value type, this is why we've decided to remove this functionality.
// func (m *ConcurrentMap) UnmarshalJSON(b []byte) (err error) { // Unmarshal into a single map.
// // Reverse process of Marshal. if err := json.Unmarshal(b, &tmp); err != nil {
return err
}
// tmp := make(map[string]interface{}) // foreach key,value pair in temporary map insert into our concurrent map.
for key, val := range tmp {
// // Unmarshal into a single map. m.Set(key, val)
// if err := json.Unmarshal(b, &tmp); err != nil { }
// return nil return nil
// } }
// // foreach key,value pair in temporary map insert into our concurrent map.
// for key, val := range tmp {
// m.Set(key, val)
// }
// return nil
// }

View File

@@ -5,7 +5,7 @@ import (
) )
type hub struct { type hub struct {
sessions cmap.ConcurrentMap sessions cmap.ConcurrentMap[string, *Session]
queue chan *envelope queue chan *envelope
register chan *Session register chan *Session
unregister chan *Session unregister chan *Session
@@ -15,7 +15,7 @@ type hub struct {
func newHub() *hub { func newHub() *hub {
return &hub{ return &hub{
sessions: cmap.New(), sessions: cmap.New[*Session](),
queue: make(chan *envelope), queue: make(chan *envelope),
register: make(chan *Session), register: make(chan *Session),
unregister: make(chan *Session), unregister: make(chan *Session),
@@ -38,19 +38,16 @@ loop:
if len(m.list) > 0 { if len(m.list) > 0 {
for _, uuid := range m.list { for _, uuid := range m.list {
if s, ok := h.sessions.Get(uuid); ok { if s, ok := h.sessions.Get(uuid); ok {
s := s.(*Session)
s.writeMessage(m) s.writeMessage(m)
} }
} }
} else if m.filter == nil { } else if m.filter == nil {
h.sessions.IterCb(func(uuid string, v interface{}) bool { h.sessions.IterCb(func(uuid string, s *Session) bool {
s := v.(*Session)
s.writeMessage(m) s.writeMessage(m)
return true return true
}) })
} else { } else {
h.sessions.IterCb(func(uuid string, v interface{}) bool { h.sessions.IterCb(func(uuid string, s *Session) bool {
s := v.(*Session)
if m.filter(s) { if m.filter(s) {
s.writeMessage(m) s.writeMessage(m)
} }
@@ -60,8 +57,7 @@ loop:
case m := <-h.exit: case m := <-h.exit:
var keys []string var keys []string
h.open = false h.open = false
h.sessions.IterCb(func(uuid string, v interface{}) bool { h.sessions.IterCb(func(uuid string, s *Session) bool {
s := v.(*Session)
s.writeMessage(m) s.writeMessage(m)
s.Close() s.Close()
keys = append(keys, uuid) keys = append(keys, uuid)

View File

@@ -301,27 +301,14 @@ func (m *Melody) SendMultiple(msg []byte, list []string) error {
// GetSessionByUUID returns the session with specified uuid. // GetSessionByUUID returns the session with specified uuid.
func (m *Melody) GetSessionByUUID(uuid string) (*Session, bool) { func (m *Melody) GetSessionByUUID(uuid string) (*Session, bool) {
val, ok := m.hub.sessions.Get(uuid) return m.hub.sessions.Get(uuid)
if !ok {
return nil, false
}
s, ok := val.(*Session)
if !ok {
m.hub.sessions.Remove(uuid)
}
return s, ok
} }
// IterSessions iterates all sessions. // IterSessions iterates all sessions.
func (m *Melody) IterSessions(fn func(uuid string, s *Session) bool) { func (m *Melody) IterSessions(fn func(uuid string, s *Session) bool) {
var invalid []string var invalid []string
m.hub.sessions.IterCb(func(uuid string, v interface{}) bool { m.hub.sessions.IterCb(func(uuid string, s *Session) bool {
if s, ok := v.(*Session); !ok { return fn(uuid, s)
invalid = append(invalid, uuid)
return true
} else {
return fn(uuid, s)
}
}) })
m.hub.sessions.Remove(invalid...) m.hub.sessions.Remove(invalid...)
} }

View File

@@ -394,7 +394,7 @@ function overview(props) {
<> <>
<Image <Image
preview={{ preview={{
visible: screenBlob, visible: !!screenBlob,
src: screenBlob, src: screenBlob,
onVisibleChange: () => { onVisibleChange: () => {
URL.revokeObjectURL(screenBlob); URL.revokeObjectURL(screenBlob);