mirror of
https://github.com/screego/server.git
synced 2025-12-24 12:57:51 +08:00
fix: race condition
In client.go
if room, ok := message.(outgoing.Room); ok {
c.info.RoomID = room.ID
}
this part isn't thread safe. It could happen that user disconnected but
wasn't removed from a room, because the disconnecting go routine
couldn't see the roomID yet.
This commit is contained in:
@@ -41,7 +41,6 @@ type ClientMessage struct {
|
||||
|
||||
type ClientInfo struct {
|
||||
ID xid.ID
|
||||
RoomID string
|
||||
Authenticated bool
|
||||
AuthenticatedUser string
|
||||
Write chan outgoing.Message
|
||||
@@ -60,7 +59,6 @@ func newClient(conn *websocket.Conn, req *http.Request, read chan ClientMessage,
|
||||
Authenticated: authenticated,
|
||||
AuthenticatedUser: authenticatedUser,
|
||||
ID: xid.New(),
|
||||
RoomID: "",
|
||||
Addr: ip,
|
||||
Write: make(chan outgoing.Message, 1),
|
||||
},
|
||||
@@ -158,10 +156,6 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) {
|
||||
continue
|
||||
}
|
||||
|
||||
if room, ok := message.(outgoing.Room); ok {
|
||||
c.info.RoomID = room.ID
|
||||
}
|
||||
|
||||
if err := writeJSON(c.conn, typed); err != nil {
|
||||
c.printWebSocketError("write", err)
|
||||
c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error())
|
||||
|
||||
@@ -16,13 +16,9 @@ func init() {
|
||||
type ClientAnswer outgoing.P2PMessage
|
||||
|
||||
func (e *ClientAnswer) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := room.Sessions[e.SID]
|
||||
|
||||
@@ -16,13 +16,9 @@ func init() {
|
||||
type ClientICE outgoing.P2PMessage
|
||||
|
||||
func (e *ClientICE) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := room.Sessions[e.SID]
|
||||
|
||||
@@ -3,6 +3,6 @@ package ws
|
||||
type Connected struct{}
|
||||
|
||||
func (e Connected) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
rooms.connected[current.ID] = true
|
||||
rooms.connected[current.ID] = ""
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,7 +23,7 @@ type Create struct {
|
||||
}
|
||||
|
||||
func (e *Create) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID != "" {
|
||||
if rooms.connected[current.ID] != "" {
|
||||
return fmt.Errorf("cannot join room, you are already in one")
|
||||
}
|
||||
|
||||
@@ -74,6 +74,7 @@ func (e *Create) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
},
|
||||
},
|
||||
}
|
||||
rooms.connected[current.ID] = room.ID
|
||||
rooms.Rooms[e.ID] = room
|
||||
room.notifyInfoChanged()
|
||||
usersJoinedTotal.Inc()
|
||||
|
||||
@@ -18,14 +18,15 @@ func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
}
|
||||
|
||||
func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) {
|
||||
roomID := rooms.connected[current.ID]
|
||||
delete(rooms.connected, current.ID)
|
||||
current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason}
|
||||
|
||||
if current.RoomID == "" {
|
||||
if roomID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
room, ok := rooms.Rooms[roomID]
|
||||
if !ok {
|
||||
// room may already be removed
|
||||
return
|
||||
@@ -63,12 +64,12 @@ func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) {
|
||||
delete(rooms.connected, member.ID)
|
||||
member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft}
|
||||
}
|
||||
rooms.closeRoom(current.RoomID)
|
||||
rooms.closeRoom(roomID)
|
||||
return
|
||||
}
|
||||
|
||||
if len(room.Users) == 0 {
|
||||
rooms.closeRoom(current.RoomID)
|
||||
rooms.closeRoom(roomID)
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -16,13 +16,9 @@ func init() {
|
||||
type HostICE outgoing.P2PMessage
|
||||
|
||||
func (e *HostICE) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := room.Sessions[e.SID]
|
||||
|
||||
@@ -16,13 +16,9 @@ func init() {
|
||||
type HostOffer outgoing.P2PMessage
|
||||
|
||||
func (e *HostOffer) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
session, ok := room.Sessions[e.SID]
|
||||
|
||||
@@ -16,7 +16,7 @@ type Join struct {
|
||||
}
|
||||
|
||||
func (e *Join) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID != "" {
|
||||
if rooms.connected[current.ID] != "" {
|
||||
return fmt.Errorf("cannot join room, you are already in one")
|
||||
}
|
||||
|
||||
@@ -40,6 +40,7 @@ func (e *Join) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
Addr: current.Addr,
|
||||
Write: current.Write,
|
||||
}
|
||||
rooms.connected[current.ID] = room.ID
|
||||
room.notifyInfoChanged()
|
||||
usersJoinedTotal.Inc()
|
||||
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
register("name", func() Event {
|
||||
return &Name{}
|
||||
@@ -15,13 +11,9 @@ type Name struct {
|
||||
}
|
||||
|
||||
func (e *Name) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
room.Users[current.ID].Name = e.UserName
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
register("share", func() Event {
|
||||
return &StartShare{}
|
||||
@@ -13,13 +9,9 @@ func init() {
|
||||
type StartShare struct{}
|
||||
|
||||
func (e *StartShare) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
room.Users[current.ID].Streaming = true
|
||||
|
||||
@@ -2,7 +2,6 @@ package ws
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/screego/server/ws/outgoing"
|
||||
)
|
||||
@@ -16,13 +15,9 @@ func init() {
|
||||
type StopShare struct{}
|
||||
|
||||
func (e *StopShare) Execute(rooms *Rooms, current ClientInfo) error {
|
||||
if current.RoomID == "" {
|
||||
return fmt.Errorf("not in a room")
|
||||
}
|
||||
|
||||
room, ok := rooms.Rooms[current.RoomID]
|
||||
if !ok {
|
||||
return fmt.Errorf("room with id %s does not exist", current.RoomID)
|
||||
room, err := rooms.CurrentRoom(current)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
room.Users[current.ID].Streaming = false
|
||||
|
||||
23
ws/rooms.go
23
ws/rooms.go
@@ -20,7 +20,7 @@ func NewRooms(tServer turn.Server, users *auth.Users, conf config.Config) *Rooms
|
||||
return &Rooms{
|
||||
Rooms: map[string]*Room{},
|
||||
Incoming: make(chan ClientMessage),
|
||||
connected: map[xid.ID]bool{},
|
||||
connected: map[xid.ID]string{},
|
||||
turnServer: tServer,
|
||||
users: users,
|
||||
config: conf,
|
||||
@@ -51,7 +51,23 @@ type Rooms struct {
|
||||
users *auth.Users
|
||||
config config.Config
|
||||
r *rand.Rand
|
||||
connected map[xid.ID]bool
|
||||
connected map[xid.ID]string
|
||||
}
|
||||
|
||||
func (r *Rooms) CurrentRoom(info ClientInfo) (*Room, error) {
|
||||
roomID, ok := r.connected[info.ID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
if roomID == "" {
|
||||
return nil, fmt.Errorf("not in a room")
|
||||
}
|
||||
room, ok := r.Rooms[roomID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("room with id %s does not exist", roomID)
|
||||
}
|
||||
|
||||
return room, nil
|
||||
}
|
||||
|
||||
func (r *Rooms) RandUserName() string {
|
||||
@@ -81,7 +97,8 @@ func (r *Rooms) Upgrade(w http.ResponseWriter, req *http.Request) {
|
||||
|
||||
func (r *Rooms) Start() {
|
||||
for msg := range r.Incoming {
|
||||
if !msg.SkipConnectedCheck && !r.connected[msg.Info.ID] {
|
||||
_, connected := r.connected[msg.Info.ID]
|
||||
if !msg.SkipConnectedCheck && !connected {
|
||||
log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore")
|
||||
continue
|
||||
}
|
||||
|
||||
112
ws/rooms_test.go
Normal file
112
ws/rooms_test.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
const SERVER = "ws://localhost:5050/stream"
|
||||
|
||||
func TestMultipleClients(t *testing.T) {
|
||||
t.Skip("only for manual testing")
|
||||
r := rand.New(rand.NewSource(time.Now().UnixMicro()))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for j := 0; j < 100; j++ {
|
||||
name := fmt.Sprint(1)
|
||||
|
||||
users := r.Intn(5000)
|
||||
for i := 0; i < users; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
testClient(r.Int63(), name)
|
||||
}()
|
||||
if i%100 == 0 {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func testClient(i int64, room string) {
|
||||
r := rand.New(rand.NewSource(i))
|
||||
conn, _, err := websocket.DefaultDialer.Dial(SERVER, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||||
_, _, err := conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer conn.Close()
|
||||
|
||||
ops := r.Intn(100)
|
||||
for i := 0; i < ops; i++ {
|
||||
m := msg(r, room)
|
||||
err = conn.WriteMessage(websocket.TextMessage, m)
|
||||
if err != nil {
|
||||
fmt.Println("err", err)
|
||||
}
|
||||
time.Sleep(30 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func msg(r *rand.Rand, room string) []byte {
|
||||
typed := Typed{}
|
||||
var e Event
|
||||
switch r.Intn(8) {
|
||||
case 0:
|
||||
typed.Type = "clientanswer"
|
||||
e = &ClientAnswer{SID: xid.New(), Value: nil}
|
||||
case 1:
|
||||
typed.Type = "clientice"
|
||||
e = &ClientICE{SID: xid.New(), Value: nil}
|
||||
case 2:
|
||||
typed.Type = "hostice"
|
||||
e = &HostICE{SID: xid.New(), Value: nil}
|
||||
case 3:
|
||||
typed.Type = "hostoffer"
|
||||
e = &HostOffer{SID: xid.New(), Value: nil}
|
||||
case 4:
|
||||
typed.Type = "name"
|
||||
e = &Name{UserName: "a"}
|
||||
case 5:
|
||||
typed.Type = "share"
|
||||
e = &StartShare{}
|
||||
case 6:
|
||||
typed.Type = "stopshare"
|
||||
e = &StopShare{}
|
||||
case 7:
|
||||
typed.Type = "create"
|
||||
e = &Create{ID: room, CloseOnOwnerLeave: r.Intn(2) == 0, JoinIfExist: r.Intn(2) == 0, Mode: ConnectionSTUN, UserName: "hello"}
|
||||
}
|
||||
|
||||
b, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
typed.Payload = json.RawMessage(b)
|
||||
|
||||
b, err = json.Marshal(typed)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return b
|
||||
}
|
||||
Reference in New Issue
Block a user