diff --git a/pkg/liberrors/server.go b/pkg/liberrors/server.go index f79cc708..e6572a25 100644 --- a/pkg/liberrors/server.go +++ b/pkg/liberrors/server.go @@ -253,3 +253,11 @@ type ErrServerSessionNotInUse struct{} func (e ErrServerSessionNotInUse) Error() string { return "not in use" } + +// ErrServerUnexpectedFrame is an error that can be returned by a server. +type ErrServerUnexpectedFrame struct{} + +// Error implements the error interface. +func (e ErrServerUnexpectedFrame) Error() string { + return "received unexpected interleaved frame" +} diff --git a/server_publish_test.go b/server_publish_test.go index 0e395d6f..eb1c61ab 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -832,8 +832,14 @@ func TestServerPublish(t *testing.T) { } func TestServerPublishErrorInvalidProtocol(t *testing.T) { + errorRecv := make(chan struct{}) + s := &Server{ Handler: &testServerHandler{ + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.EqualError(t, ctx.Error, "received unexpected interleaved frame") + close(errorRecv) + }, onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, @@ -937,6 +943,8 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { Payload: []byte{0x01, 0x02, 0x03, 0x04}, }, make([]byte, 1024)) require.NoError(t, err) + + <-errorRecv } func TestServerPublishRTCPReport(t *testing.T) { diff --git a/serverconn.go b/serverconn.go index be76210b..1f410f75 100644 --- a/serverconn.go +++ b/serverconn.go @@ -210,21 +210,27 @@ func (sc *ServerConn) readFuncStandard(readRequest chan readReq) error { sc.nconn.SetReadDeadline(time.Time{}) for { - req, err := sc.conn.ReadRequest() + any, err := sc.conn.ReadInterleavedFrameOrRequest() if err != nil { return err } - cres := make(chan error) - select { - case readRequest <- readReq{req: req, res: cres}: - err = <-cres - if err != nil { - return err + switch what := any.(type) { + case *base.Request: + cres := make(chan error) + select { + case readRequest <- readReq{req: what, res: cres}: + err = <-cres + if err != nil { + return err + } + + case <-sc.ctx.Done(): + return liberrors.ErrServerTerminated{} } - case <-sc.ctx.Done(): - return liberrors.ErrServerTerminated{} + default: + return liberrors.ErrServerUnexpectedFrame{} } } }