fix: race condition

In client.go

    if room, ok := message.(outgoing.Room); ok {
            c.info.RoomID = room.ID
    }

this part isn't thread safe. It could happen that user disconnected but
wasn't removed from a room, because the disconnecting go routine
couldn't see the roomID yet.
This commit is contained in:
Jannis Mattheis
2024-10-11 14:47:26 +02:00
parent 54f9fb6b9e
commit a0f3c37498
14 changed files with 163 additions and 74 deletions

View File

@@ -41,7 +41,6 @@ type ClientMessage struct {
type ClientInfo struct { type ClientInfo struct {
ID xid.ID ID xid.ID
RoomID string
Authenticated bool Authenticated bool
AuthenticatedUser string AuthenticatedUser string
Write chan outgoing.Message Write chan outgoing.Message
@@ -60,7 +59,6 @@ func newClient(conn *websocket.Conn, req *http.Request, read chan ClientMessage,
Authenticated: authenticated, Authenticated: authenticated,
AuthenticatedUser: authenticatedUser, AuthenticatedUser: authenticatedUser,
ID: xid.New(), ID: xid.New(),
RoomID: "",
Addr: ip, Addr: ip,
Write: make(chan outgoing.Message, 1), Write: make(chan outgoing.Message, 1),
}, },
@@ -158,10 +156,6 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) {
continue continue
} }
if room, ok := message.(outgoing.Room); ok {
c.info.RoomID = room.ID
}
if err := writeJSON(c.conn, typed); err != nil { if err := writeJSON(c.conn, typed); err != nil {
c.printWebSocketError("write", err) c.printWebSocketError("write", err)
c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error()) c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error())

View File

@@ -16,13 +16,9 @@ func init() {
type ClientAnswer outgoing.P2PMessage type ClientAnswer outgoing.P2PMessage
func (e *ClientAnswer) Execute(rooms *Rooms, current ClientInfo) error { func (e *ClientAnswer) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
session, ok := room.Sessions[e.SID] session, ok := room.Sessions[e.SID]

View File

@@ -16,13 +16,9 @@ func init() {
type ClientICE outgoing.P2PMessage type ClientICE outgoing.P2PMessage
func (e *ClientICE) Execute(rooms *Rooms, current ClientInfo) error { func (e *ClientICE) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
session, ok := room.Sessions[e.SID] session, ok := room.Sessions[e.SID]

View File

@@ -3,6 +3,6 @@ package ws
type Connected struct{} type Connected struct{}
func (e Connected) Execute(rooms *Rooms, current ClientInfo) error { func (e Connected) Execute(rooms *Rooms, current ClientInfo) error {
rooms.connected[current.ID] = true rooms.connected[current.ID] = ""
return nil return nil
} }

View File

@@ -23,7 +23,7 @@ type Create struct {
} }
func (e *Create) Execute(rooms *Rooms, current ClientInfo) error { func (e *Create) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID != "" { if rooms.connected[current.ID] != "" {
return fmt.Errorf("cannot join room, you are already in one") return fmt.Errorf("cannot join room, you are already in one")
} }
@@ -74,6 +74,7 @@ func (e *Create) Execute(rooms *Rooms, current ClientInfo) error {
}, },
}, },
} }
rooms.connected[current.ID] = room.ID
rooms.Rooms[e.ID] = room rooms.Rooms[e.ID] = room
room.notifyInfoChanged() room.notifyInfoChanged()
usersJoinedTotal.Inc() usersJoinedTotal.Inc()

View File

@@ -18,14 +18,15 @@ func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error {
} }
func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) { func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) {
roomID := rooms.connected[current.ID]
delete(rooms.connected, current.ID) delete(rooms.connected, current.ID)
current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason} current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason}
if current.RoomID == "" { if roomID == "" {
return return
} }
room, ok := rooms.Rooms[current.RoomID] room, ok := rooms.Rooms[roomID]
if !ok { if !ok {
// room may already be removed // room may already be removed
return return
@@ -63,12 +64,12 @@ func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) {
delete(rooms.connected, member.ID) delete(rooms.connected, member.ID)
member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft} member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft}
} }
rooms.closeRoom(current.RoomID) rooms.closeRoom(roomID)
return return
} }
if len(room.Users) == 0 { if len(room.Users) == 0 {
rooms.closeRoom(current.RoomID) rooms.closeRoom(roomID)
return return
} }

View File

@@ -16,13 +16,9 @@ func init() {
type HostICE outgoing.P2PMessage type HostICE outgoing.P2PMessage
func (e *HostICE) Execute(rooms *Rooms, current ClientInfo) error { func (e *HostICE) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
session, ok := room.Sessions[e.SID] session, ok := room.Sessions[e.SID]

View File

@@ -16,13 +16,9 @@ func init() {
type HostOffer outgoing.P2PMessage type HostOffer outgoing.P2PMessage
func (e *HostOffer) Execute(rooms *Rooms, current ClientInfo) error { func (e *HostOffer) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
session, ok := room.Sessions[e.SID] session, ok := room.Sessions[e.SID]

View File

@@ -16,7 +16,7 @@ type Join struct {
} }
func (e *Join) Execute(rooms *Rooms, current ClientInfo) error { func (e *Join) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID != "" { if rooms.connected[current.ID] != "" {
return fmt.Errorf("cannot join room, you are already in one") return fmt.Errorf("cannot join room, you are already in one")
} }
@@ -40,6 +40,7 @@ func (e *Join) Execute(rooms *Rooms, current ClientInfo) error {
Addr: current.Addr, Addr: current.Addr,
Write: current.Write, Write: current.Write,
} }
rooms.connected[current.ID] = room.ID
room.notifyInfoChanged() room.notifyInfoChanged()
usersJoinedTotal.Inc() usersJoinedTotal.Inc()

View File

@@ -1,9 +1,5 @@
package ws package ws
import (
"fmt"
)
func init() { func init() {
register("name", func() Event { register("name", func() Event {
return &Name{} return &Name{}
@@ -15,13 +11,9 @@ type Name struct {
} }
func (e *Name) Execute(rooms *Rooms, current ClientInfo) error { func (e *Name) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
room.Users[current.ID].Name = e.UserName room.Users[current.ID].Name = e.UserName

View File

@@ -1,9 +1,5 @@
package ws package ws
import (
"fmt"
)
func init() { func init() {
register("share", func() Event { register("share", func() Event {
return &StartShare{} return &StartShare{}
@@ -13,13 +9,9 @@ func init() {
type StartShare struct{} type StartShare struct{}
func (e *StartShare) Execute(rooms *Rooms, current ClientInfo) error { func (e *StartShare) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
room.Users[current.ID].Streaming = true room.Users[current.ID].Streaming = true

View File

@@ -2,7 +2,6 @@ package ws
import ( import (
"bytes" "bytes"
"fmt"
"github.com/screego/server/ws/outgoing" "github.com/screego/server/ws/outgoing"
) )
@@ -16,13 +15,9 @@ func init() {
type StopShare struct{} type StopShare struct{}
func (e *StopShare) Execute(rooms *Rooms, current ClientInfo) error { func (e *StopShare) Execute(rooms *Rooms, current ClientInfo) error {
if current.RoomID == "" { room, err := rooms.CurrentRoom(current)
return fmt.Errorf("not in a room") if err != nil {
} return err
room, ok := rooms.Rooms[current.RoomID]
if !ok {
return fmt.Errorf("room with id %s does not exist", current.RoomID)
} }
room.Users[current.ID].Streaming = false room.Users[current.ID].Streaming = false

View File

@@ -20,7 +20,7 @@ func NewRooms(tServer turn.Server, users *auth.Users, conf config.Config) *Rooms
return &Rooms{ return &Rooms{
Rooms: map[string]*Room{}, Rooms: map[string]*Room{},
Incoming: make(chan ClientMessage), Incoming: make(chan ClientMessage),
connected: map[xid.ID]bool{}, connected: map[xid.ID]string{},
turnServer: tServer, turnServer: tServer,
users: users, users: users,
config: conf, config: conf,
@@ -51,7 +51,23 @@ type Rooms struct {
users *auth.Users users *auth.Users
config config.Config config config.Config
r *rand.Rand r *rand.Rand
connected map[xid.ID]bool connected map[xid.ID]string
}
func (r *Rooms) CurrentRoom(info ClientInfo) (*Room, error) {
roomID, ok := r.connected[info.ID]
if !ok {
return nil, fmt.Errorf("not connected")
}
if roomID == "" {
return nil, fmt.Errorf("not in a room")
}
room, ok := r.Rooms[roomID]
if !ok {
return nil, fmt.Errorf("room with id %s does not exist", roomID)
}
return room, nil
} }
func (r *Rooms) RandUserName() string { func (r *Rooms) RandUserName() string {
@@ -81,7 +97,8 @@ func (r *Rooms) Upgrade(w http.ResponseWriter, req *http.Request) {
func (r *Rooms) Start() { func (r *Rooms) Start() {
for msg := range r.Incoming { for msg := range r.Incoming {
if !msg.SkipConnectedCheck && !r.connected[msg.Info.ID] { _, connected := r.connected[msg.Info.ID]
if !msg.SkipConnectedCheck && !connected {
log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore") log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore")
continue continue
} }

112
ws/rooms_test.go Normal file
View File

@@ -0,0 +1,112 @@
package ws
import (
"encoding/json"
"fmt"
"math/rand"
"sync"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/rs/xid"
)
const SERVER = "ws://localhost:5050/stream"
func TestMultipleClients(t *testing.T) {
t.Skip("only for manual testing")
r := rand.New(rand.NewSource(time.Now().UnixMicro()))
var wg sync.WaitGroup
for j := 0; j < 100; j++ {
name := fmt.Sprint(1)
users := r.Intn(5000)
for i := 0; i < users; i++ {
wg.Add(1)
go func() {
defer wg.Done()
testClient(r.Int63(), name)
}()
if i%100 == 0 {
time.Sleep(10 * time.Millisecond)
}
}
time.Sleep(50 * time.Millisecond)
}
wg.Wait()
}
func testClient(i int64, room string) {
r := rand.New(rand.NewSource(i))
conn, _, err := websocket.DefaultDialer.Dial(SERVER, nil)
if err != nil {
panic(err)
}
go func() {
for {
_ = conn.SetReadDeadline(time.Now().Add(10 * time.Second))
_, _, err := conn.ReadMessage()
if err != nil {
return
}
}
}()
defer conn.Close()
ops := r.Intn(100)
for i := 0; i < ops; i++ {
m := msg(r, room)
err = conn.WriteMessage(websocket.TextMessage, m)
if err != nil {
fmt.Println("err", err)
}
time.Sleep(30 * time.Millisecond)
}
}
func msg(r *rand.Rand, room string) []byte {
typed := Typed{}
var e Event
switch r.Intn(8) {
case 0:
typed.Type = "clientanswer"
e = &ClientAnswer{SID: xid.New(), Value: nil}
case 1:
typed.Type = "clientice"
e = &ClientICE{SID: xid.New(), Value: nil}
case 2:
typed.Type = "hostice"
e = &HostICE{SID: xid.New(), Value: nil}
case 3:
typed.Type = "hostoffer"
e = &HostOffer{SID: xid.New(), Value: nil}
case 4:
typed.Type = "name"
e = &Name{UserName: "a"}
case 5:
typed.Type = "share"
e = &StartShare{}
case 6:
typed.Type = "stopshare"
e = &StopShare{}
case 7:
typed.Type = "create"
e = &Create{ID: room, CloseOnOwnerLeave: r.Intn(2) == 0, JoinIfExist: r.Intn(2) == 0, Mode: ConnectionSTUN, UserName: "hello"}
}
b, err := json.Marshal(e)
if err != nil {
panic(err)
}
typed.Payload = json.RawMessage(b)
b, err = json.Marshal(typed)
if err != nil {
panic(err)
}
return b
}