diff --git a/server_publish_test.go b/server_publish_test.go index 0b7181c0..fb3095da 100644 --- a/server_publish_test.go +++ b/server_publish_test.go @@ -305,17 +305,12 @@ func TestServerPublishSetupPath(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL(ca.url), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": th.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) @@ -386,17 +381,12 @@ func TestServerPublishErrorSetupDifferentPaths(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/test2stream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": th.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) @@ -468,22 +458,21 @@ func TestServerPublishErrorSetupTrackTwice(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": th.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), @@ -570,22 +559,21 @@ func TestServerPublishErrorRecordPartialTracks(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": th.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), @@ -739,22 +727,21 @@ func TestServerPublish(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + var th headers.Transport err = th.Read(res.Header["Transport"]) require.NoError(t, err) @@ -937,22 +924,21 @@ func TestServerPublishNonStandardFrameSize(t *testing.T) { InterleavedIDs: &[2]int{0, 1}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), @@ -1042,22 +1028,21 @@ func TestServerPublishErrorInvalidProtocol(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + var th headers.Transport err = th.Read(res.Header["Transport"]) require.NoError(t, err) @@ -1141,10 +1126,6 @@ func TestServerPublishRTCPReport(t *testing.T) { require.NoError(t, err) defer l2.Close() - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), @@ -1162,12 +1143,15 @@ func TestServerPublishRTCPReport(t *testing.T) { Protocol: headers.TransportProtocolUDP, ClientPorts: &[2]int{34556, 34557}, }.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + var th headers.Transport err = th.Read(res.Header["Transport"]) require.NoError(t, err) @@ -1327,22 +1311,21 @@ func TestServerPublishTimeout(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + var th headers.Transport err = th.Read(res.Header["Transport"]) require.NoError(t, err) @@ -1454,22 +1437,21 @@ func TestServerPublishWithoutTeardown(t *testing.T) { inTH.InterleavedIDs = &[2]int{0, 1} } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + var th headers.Transport err = th.Read(res.Header["Transport"]) require.NoError(t, err) @@ -1567,22 +1549,21 @@ func TestServerPublishUDPChangeConn(t *testing.T) { ClientPorts: &[2]int{35466, 35467}, } - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Setup, URL: mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), Header: base.Header{ "CSeq": base.HeaderValue{"2"}, "Transport": inTH.Write(), - "Session": base.HeaderValue{sx.Session}, }, }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + var sx headers.Session + err = sx.Read(res.Header["Session"]) + require.NoError(t, err) + res, err = writeReqReadRes(conn, br, base.Request{ Method: base.Record, URL: mustParseURL("rtsp://localhost:8554/teststream"), diff --git a/server_test.go b/server_test.go index e5c791ef..169517c3 100644 --- a/server_test.go +++ b/server_test.go @@ -1186,12 +1186,6 @@ func TestServerErrorInvalidPath(t *testing.T) { }) require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) - - var sx headers.Session - err = sx.Read(res.Header["Session"]) - require.NoError(t, err) - - sxID = sx.Session } if method == base.Play || method == base.Record || method == base.Pause { diff --git a/serverconn.go b/serverconn.go index a794e516..4c2b2f2a 100644 --- a/serverconn.go +++ b/serverconn.go @@ -543,11 +543,15 @@ func (sc *ServerConn) handleRequestInSession( ) (*base.Response, error) { // handle directly in Session if sc.session != nil { - // the connection can't communicate with two sessions at once. - if sxID != sc.session.secretID { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerLinkedToOtherSession{} + // the SETUP request after ANNOUNCE don't have the session ID + // since ANNOUNCE didn't provide it. + if req.Method != base.Setup || sxID != "" { + // the connection can't communicate with two sessions at once. + if sxID != sc.session.secretID { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerLinkedToOtherSession{} + } } cres := make(chan sessionRequestRes) diff --git a/serversession.go b/serversession.go index ec0f4480..e66aaf18 100644 --- a/serversession.go +++ b/serversession.go @@ -281,17 +281,20 @@ func (ss *ServerSession) run() { var returnedSession *ServerSession if err == nil || err == errSwitchReadFunc { - if res.Header == nil { - res.Header = make(base.Header) - } + // ANNOUNCE responses don't contain the session header. + if req.req.Method != base.Announce { + if res.Header == nil { + res.Header = make(base.Header) + } - res.Header["Session"] = headers.Session{ - Session: ss.secretID, - Timeout: func() *uint { - v := uint(ss.s.sessionTimeout / time.Second) - return &v - }(), - }.Write() + res.Header["Session"] = headers.Session{ + Session: ss.secretID, + Timeout: func() *uint { + v := uint(ss.s.sessionTimeout / time.Second) + return &v + }(), + }.Write() + } if req.req.Method != base.Teardown { returnedSession = ss