From 7469a3362a7d738b682ea571a853fbc8ba465368 Mon Sep 17 00:00:00 2001 From: aler9 <46489434+aler9@users.noreply.github.com> Date: Sat, 8 May 2021 22:05:22 +0200 Subject: [PATCH] server: add author to OnSessionOpen --- examples/server-tls/main.go | 14 +++++----- examples/server/main.go | 14 +++++----- server.go | 2 +- server_publish_test.go | 32 ++++++++++----------- server_read_test.go | 28 +++++++++---------- server_test.go | 56 ++++++++++++++++++------------------- serverconn.go | 9 ++++-- serverhandler.go | 33 ++++++++++++++++++---- serversession.go | 23 +++++++++++---- 9 files changed, 125 insertions(+), 86 deletions(-) diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index b3f76ddf..8b2cf3e6 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -23,32 +23,32 @@ type serverHandler struct { } // called after a connection is opened. -func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { +func (sh *serverHandler) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) { log.Printf("conn opened") } // called after a connection is closed. -func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { - log.Printf("conn closed (%v)", err) +func (sh *serverHandler) OnConnClose(ctx *gortsplib.ServerHandlerOnConnCloseCtx) { + log.Printf("conn closed (%v)", ctx.Error) } // called after a session is opened. -func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { +func (sh *serverHandler) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) { log.Printf("session opened") } // called after a session is closed. -func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { +func (sh *serverHandler) OnSessionClose(ctx *gortsplib.ServerHandlerOnSessionCloseCtx) { log.Printf("session closed") sh.mutex.Lock() defer sh.mutex.Unlock() - if ss == sh.publisher { + if ctx.Session == sh.publisher { sh.publisher = nil sh.sdp = nil } else { - delete(sh.readers, ss) + delete(sh.readers, ctx.Session) } } diff --git a/examples/server/main.go b/examples/server/main.go index abd8e367..b14b03cd 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -22,32 +22,32 @@ type serverHandler struct { } // called after a connection is opened. -func (sh *serverHandler) OnConnOpen(sc *gortsplib.ServerConn) { +func (sh *serverHandler) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) { log.Printf("conn opened") } // called after a connection is closed. -func (sh *serverHandler) OnConnClose(sc *gortsplib.ServerConn, err error) { - log.Printf("conn closed (%v)", err) +func (sh *serverHandler) OnConnClose(ctx *gortsplib.ServerHandlerOnConnCloseCtx) { + log.Printf("conn closed (%v)", ctx.Error) } // called after a session is opened. -func (sh *serverHandler) OnSessionOpen(ss *gortsplib.ServerSession) { +func (sh *serverHandler) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) { log.Printf("session opened") } // called after a session is closed. -func (sh *serverHandler) OnSessionClose(ss *gortsplib.ServerSession, err error) { +func (sh *serverHandler) OnSessionClose(ctx *gortsplib.ServerHandlerOnSessionCloseCtx) { log.Printf("session closed") sh.mutex.Lock() defer sh.mutex.Unlock() - if ss == sh.publisher { + if ctx.Session == sh.publisher { sh.publisher = nil sh.sdp = nil } else { - delete(sh.readers, ss) + delete(sh.readers, ctx.Session) } } diff --git a/server.go b/server.go index 4a1919ed..8d080a8a 100644 --- a/server.go +++ b/server.go @@ -302,7 +302,7 @@ outer: continue } - ss := newServerSession(s, id, &wg) + ss := newServerSession(s, id, &wg, req.sc) s.sessions[id] = ss ss.request <- req } diff --git a/server_publish_test.go b/server_publish_test.go index 6483a6ba..30cad6e6 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -213,8 +213,8 @@ func TestServerPublishErrorAnnounce(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, ca.err, err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, ca.err, ctx.Error.Error()) close(connClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { @@ -404,8 +404,8 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - serverErr <- err + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + serverErr <- ctx.Error }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -492,8 +492,8 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - serverErr <- err + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + serverErr <- ctx.Error }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -595,8 +595,8 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - serverErr <- err + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + serverErr <- ctx.Error }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ @@ -715,16 +715,16 @@ func TestServerPublish(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnOpen: func(sc *ServerConn) { + onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { close(connOpened) }, - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, - onSessionOpen: func(ss *ServerSession) { + onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { @@ -1231,10 +1231,10 @@ func TestServerPublishTimeout(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { @@ -1368,10 +1368,10 @@ func TestServerPublishWithoutTeardown(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { diff --git a/server_read_test.go b/server_read_test.go index 771227ea..86ec945d 100644 --- a/server_read_test.go +++ b/server_read_test.go @@ -120,8 +120,8 @@ func TestServerReadErrorSetupDifferentPaths(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "can't setup tracks with different paths", err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, "can't setup tracks with different paths", ctx.Error.Error()) close(connClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -193,8 +193,8 @@ func TestServerReadErrorSetupTrackTwice(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "track 0 has already been setup", err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, "track 0 has already been setup", ctx.Error.Error()) close(connClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -276,16 +276,16 @@ func TestServerRead(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnOpen: func(sc *ServerConn) { + onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { close(connOpened) }, - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, - onSessionOpen: func(ss *ServerSession) { + onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -514,7 +514,7 @@ func TestServerReadTCPResponseBeforeFrames(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(writerTerminate) <-writerDone }, @@ -693,7 +693,7 @@ func TestServerReadPlayPausePlay(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(writerTerminate) <-writerDone }, @@ -817,7 +817,7 @@ func TestServerReadPlayPausePause(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(writerTerminate) <-writerDone }, @@ -942,7 +942,7 @@ func TestServerReadTimeout(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { @@ -1035,10 +1035,10 @@ func TestServerReadWithoutTeardown(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { diff --git a/server_test.go b/server_test.go index 0eebe6be..842df0cd 100644 --- a/server_test.go +++ b/server_test.go @@ -31,10 +31,10 @@ func readResponseIgnoreFrames(br *bufio.Reader) (*base.Response, error) { } type testServerHandler struct { - onConnOpen func(*ServerConn) - onConnClose func(*ServerConn, error) - onSessionOpen func(*ServerSession) - onSessionClose func(*ServerSession, error) + onConnOpen func(*ServerHandlerOnConnOpenCtx) + onConnClose func(*ServerHandlerOnConnCloseCtx) + onSessionOpen func(*ServerHandlerOnSessionOpenCtx) + onSessionClose func(*ServerHandlerOnSessionCloseCtx) onDescribe func(*ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) onAnnounce func(*ServerHandlerOnAnnounceCtx) (*base.Response, error) onSetup func(*ServerHandlerOnSetupCtx) (*base.Response, error) @@ -46,27 +46,27 @@ type testServerHandler struct { onGetParameter func(*ServerHandlerOnGetParameterCtx) (*base.Response, error) } -func (sh *testServerHandler) OnConnOpen(sc *ServerConn) { +func (sh *testServerHandler) OnConnOpen(ctx *ServerHandlerOnConnOpenCtx) { if sh.onConnOpen != nil { - sh.onConnOpen(sc) + sh.onConnOpen(ctx) } } -func (sh *testServerHandler) OnConnClose(sc *ServerConn, err error) { +func (sh *testServerHandler) OnConnClose(ctx *ServerHandlerOnConnCloseCtx) { if sh.onConnClose != nil { - sh.onConnClose(sc, err) + sh.onConnClose(ctx) } } -func (sh *testServerHandler) OnSessionOpen(ss *ServerSession) { +func (sh *testServerHandler) OnSessionOpen(ctx *ServerHandlerOnSessionOpenCtx) { if sh.onSessionOpen != nil { - sh.onSessionOpen(ss) + sh.onSessionOpen(ctx) } } -func (sh *testServerHandler) OnSessionClose(ss *ServerSession, err error) { +func (sh *testServerHandler) OnSessionClose(ctx *ServerHandlerOnSessionCloseCtx) { if sh.onSessionClose != nil { - sh.onSessionClose(ss, err) + sh.onSessionClose(ctx) } } @@ -226,15 +226,15 @@ func TestServerHighLevelPublishRead(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { mutex.Lock() defer mutex.Unlock() - if ss == publisher { + if ctx.Session == publisher { publisher = nil sdp = nil } else { - delete(readers, ss) + delete(readers, ctx.Session) } }, onDescribe: func(ctx *ServerHandlerOnDescribeCtx) (*base.Response, []byte, error) { @@ -447,10 +447,10 @@ func TestServerConnClose(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnOpen: func(sc *ServerConn) { - sc.Close() + onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { + ctx.Conn.Close() }, - onConnClose: func(sc *ServerConn, err error) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) }, }, @@ -498,8 +498,8 @@ func TestServerErrorCSeqMissing(t *testing.T) { connClosed := make(chan struct{}) h := &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "CSeq is missing", err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, "CSeq is missing", ctx.Error.Error()) close(connClosed) }, } @@ -530,8 +530,8 @@ func TestServerErrorCSeqMissing(t *testing.T) { func TestServerErrorInvalidMethod(t *testing.T) { h := &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "unhandled request (INVALID rtsp://localhost:8554/)", err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, "unhandled request (INVALID rtsp://localhost:8554/)", ctx.Error.Error()) }, } @@ -885,10 +885,10 @@ func TestServerSessionClose(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionOpen: func(ss *ServerSession) { - ss.Close() + onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { + ctx.Session.Close() }, - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -937,7 +937,7 @@ func TestServerSessionAutoClose(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onSessionClose: func(ss *ServerSession, err error) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) }, onSetup: func(ctx *ServerHandlerOnSetupCtx) (*base.Response, error) { @@ -999,8 +999,8 @@ func TestServerErrorInvalidPath(t *testing.T) { t.Run(string(method), func(t *testing.T) { s := &Server{ Handler: &testServerHandler{ - onConnClose: func(sc *ServerConn, err error) { - require.Equal(t, "invalid path", err.Error()) + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Equal(t, "invalid path", ctx.Error.Error()) }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ diff --git a/serverconn.go b/serverconn.go index fa8e35f6..cd5ca6fa 100644 --- a/serverconn.go +++ b/serverconn.go @@ -113,7 +113,9 @@ func (sc *ServerConn) run() { defer sc.wg.Done() if h, ok := sc.s.Handler.(ServerHandlerOnConnOpen); ok { - h.OnConnOpen(sc) + h.OnConnOpen(&ServerHandlerOnConnOpenCtx{ + Conn: sc, + }) } conn := func() net.Conn { @@ -266,7 +268,10 @@ func (sc *ServerConn) run() { close(sc.sessionRemove) if h, ok := sc.s.Handler.(ServerHandlerOnConnClose); ok { - h.OnConnClose(sc, err) + h.OnConnClose(&ServerHandlerOnConnCloseCtx{ + Conn: sc, + Error: err, + }) } } diff --git a/serverhandler.go b/serverhandler.go index ff19949c..2825cc0b 100644 --- a/serverhandler.go +++ b/serverhandler.go @@ -9,24 +9,47 @@ import ( type ServerHandler interface { } +// ServerHandlerOnConnOpenCtx is the context of a connection opening. +type ServerHandlerOnConnOpenCtx struct { + Conn *ServerConn +} + // ServerHandlerOnConnOpen can be implemented by a ServerHandler. type ServerHandlerOnConnOpen interface { - OnConnOpen(*ServerConn) + OnConnOpen(*ServerHandlerOnConnOpenCtx) +} + +// ServerHandlerOnConnCloseCtx is the context of a connection closure. +type ServerHandlerOnConnCloseCtx struct { + Conn *ServerConn + Error error } // ServerHandlerOnConnClose can be implemented by a ServerHandler. type ServerHandlerOnConnClose interface { - OnConnClose(*ServerConn, error) + OnConnClose(*ServerHandlerOnConnCloseCtx) +} + +// ServerHandlerOnSessionOpenCtx is the context of a session opening. +type ServerHandlerOnSessionOpenCtx struct { + Session *ServerSession + Conn *ServerConn } // ServerHandlerOnSessionOpen can be implemented by a ServerHandler. type ServerHandlerOnSessionOpen interface { - OnSessionOpen(*ServerSession) + OnSessionOpen(*ServerHandlerOnSessionOpenCtx) +} + +// ServerHandlerOnSessionCloseCtx is the context of a session closure. +type ServerHandlerOnSessionCloseCtx struct { + Session *ServerSession + Error error } // ServerHandlerOnSessionClose can be implemented by a ServerHandler. type ServerHandlerOnSessionClose interface { - OnSessionClose(*ServerSession, error) + OnSessionClose(*ServerHandlerOnSessionCloseCtx) } // ServerHandlerOnRequest can be implemented by a ServerHandler. @@ -152,7 +175,7 @@ type ServerHandlerOnSetParameter interface { OnSetParameter(*ServerHandlerOnSetParameterCtx) (*base.Response, error) } -// ServerHandlerOnFrameCtx is the context of a frame request. +// ServerHandlerOnFrameCtx is the context of a frame. type ServerHandlerOnFrameCtx struct { Session *ServerSession TrackID int diff --git a/serversession.go b/serversession.go index e9ebc8e9..461b195f 100644 --- a/serversession.go +++ b/serversession.go @@ -114,9 +114,10 @@ type ServerSessionAnnouncedTrack struct { // ServerSession is a server-side RTSP session. type ServerSession struct { - s *Server - id string - wg *sync.WaitGroup + s *Server + id string + wg *sync.WaitGroup + author *ServerConn conns map[*ServerConn]struct{} connsWG sync.WaitGroup @@ -139,15 +140,18 @@ type ServerSession struct { parentTerminate chan struct{} } -func newServerSession(s *Server, +func newServerSession( + s *Server, id string, wg *sync.WaitGroup, + author *ServerConn, ) *ServerSession { ss := &ServerSession{ s: s, id: id, wg: wg, + author: author, conns: make(map[*ServerConn]struct{}), lastRequestTime: time.Now(), request: make(chan request), @@ -214,7 +218,11 @@ func (ss *ServerSession) run() { defer ss.wg.Done() if h, ok := ss.s.Handler.(ServerHandlerOnSessionOpen); ok { - h.OnSessionOpen(ss) + h.OnSessionOpen(&ServerHandlerOnSessionOpenCtx{ + Session: ss, + Conn: ss.author, + }) + ss.author = nil } err := func() error { @@ -365,7 +373,10 @@ func (ss *ServerSession) run() { close(ss.connRemove) if h, ok := ss.s.Handler.(ServerHandlerOnSessionClose); ok { - h.OnSessionClose(ss, err) + h.OnSessionClose(&ServerHandlerOnSessionCloseCtx{ + Session: ss, + Error: err, + }) } }