diff --git a/server.go b/server.go index 8d080a8a..95bde3f0 100644 --- a/server.go +++ b/server.go @@ -225,7 +225,7 @@ func (s *Server) Start(address string) error { return err } - s.terminate = make(chan struct{}) + s.terminate = make(chan struct{}, 1) s.done = make(chan struct{}) go s.run() @@ -233,6 +233,22 @@ func (s *Server) Start(address string) error { return nil } +// Close closes all the server resources and waits for the server to exit. +func (s *Server) Close() error { + select { + case s.terminate <- struct{}{}: + default: + } + <-s.done + return nil +} + +// Wait waits until a fatal error. +func (s *Server) Wait() error { + <-s.done + return s.exitError +} + func (s *Server) run() { s.sessions = make(map[string]*ServerSession) s.conns = make(map[*ServerConn]struct{}) @@ -386,19 +402,6 @@ outer: close(s.done) } -// Close closes all the server resources and waits for the server to exit. -func (s *Server) Close() error { - close(s.terminate) - <-s.done - return nil -} - -// Wait waits until a fatal error. -func (s *Server) Wait() error { - <-s.done - return s.exitError -} - // StartAndWait starts the server and waits until a fatal error. func (s *Server) StartAndWait(address string) error { err := s.Start(address) diff --git a/server_test.go b/server_test.go index 842df0cd..793bb35d 100644 --- a/server_test.go +++ b/server_test.go @@ -422,6 +422,17 @@ func TestServerHighLevelPublishRead(t *testing.T) { } } +func TestServerClose(t *testing.T) { + s := &Server{ + Handler: &testServerHandler{}, + } + + err := s.Start("127.0.0.1:8554") + require.NoError(t, err) + s.Close() + s.Close() +} + func TestServerErrorWrongUDPPorts(t *testing.T) { t.Run("non consecutive", func(t *testing.T) { s := &Server{ @@ -449,6 +460,7 @@ func TestServerConnClose(t *testing.T) { Handler: &testServerHandler{ onConnOpen: func(ctx *ServerHandlerOnConnOpenCtx) { ctx.Conn.Close() + ctx.Conn.Close() }, onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { close(connClosed) @@ -887,6 +899,7 @@ func TestServerSessionClose(t *testing.T) { Handler: &testServerHandler{ onSessionOpen: func(ctx *ServerHandlerOnSessionOpenCtx) { ctx.Session.Close() + ctx.Session.Close() }, onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { close(sessionClosed) diff --git a/serverconn.go b/serverconn.go index cd5ca6fa..96e7338b 100644 --- a/serverconn.go +++ b/serverconn.go @@ -63,7 +63,7 @@ type ServerConn struct { // in sessionRemove chan *ServerSession - innerTerminate chan struct{} + terminate chan struct{} parentTerminate chan struct{} } @@ -77,7 +77,7 @@ func newServerConn( wg: wg, nconn: nconn, sessionRemove: make(chan *ServerSession), - innerTerminate: make(chan struct{}, 1), + terminate: make(chan struct{}, 1), parentTerminate: make(chan struct{}), } @@ -90,7 +90,7 @@ func newServerConn( // Close closes the ServerConn. func (sc *ServerConn) Close() error { select { - case sc.innerTerminate <- struct{}{}: + case sc.terminate <- struct{}{}: default: } return nil @@ -214,7 +214,7 @@ func (sc *ServerConn) run() { sc.sessionsWG.Done() } - case <-sc.innerTerminate: + case <-sc.terminate: return liberrors.ErrServerTerminated{} } } diff --git a/serversession.go b/serversession.go index 461b195f..6bd7ba53 100644 --- a/serversession.go +++ b/serversession.go @@ -136,7 +136,7 @@ type ServerSession struct { // in request chan request connRemove chan *ServerConn - innerTerminate chan struct{} + terminate chan struct{} parentTerminate chan struct{} } @@ -156,7 +156,7 @@ func newServerSession( lastRequestTime: time.Now(), request: make(chan request), connRemove: make(chan *ServerConn), - innerTerminate: make(chan struct{}, 1), + terminate: make(chan struct{}, 1), parentTerminate: make(chan struct{}), } @@ -169,7 +169,7 @@ func newServerSession( // Close closes the ServerSession. func (ss *ServerSession) Close() error { select { - case ss.innerTerminate <- struct{}{}: + case ss.terminate <- struct{}{}: default: } return nil @@ -311,7 +311,7 @@ func (ss *ServerSession) run() { ss.WriteFrame(trackID, StreamTypeRTCP, r) } - case <-ss.innerTerminate: + case <-ss.terminate: return liberrors.ErrServerTerminated{} } }