diff --git a/serverconn.go b/serverconn.go index bf9bc2d4..3daab6a6 100644 --- a/serverconn.go +++ b/serverconn.go @@ -43,7 +43,7 @@ func stringsReverseIndex(s, substr string) int { func extractTrackIDAndPath(url *base.URL, thMode *headers.TransportMode, publishTracks []ServerConnAnnouncedTrack, - publishPath string) (int, string, error) { + setupPath *string) (int, string, error) { pathAndQuery, ok := url.RTSPPathAndQuery() if !ok { @@ -73,13 +73,17 @@ func extractTrackIDAndPath(url *base.URL, path, _ := base.PathSplitQuery(pathAndQuery) + if setupPath != nil && path != *setupPath { + return 0, "", fmt.Errorf("can't setup tracks with different paths") + } + return trackID, path, nil } for trackID, track := range publishTracks { u, _ := track.track.URL() if u.String() == url.String() { - return trackID, publishPath, nil + return trackID, *setupPath, nil } } @@ -184,7 +188,8 @@ type ServerConn struct { bw *bufio.Writer state ServerConnState tracks map[int]ServerConnTrack - streamProtocol *StreamProtocol + setupProtocol *StreamProtocol + setupPath *string // frame mode only doEnableFrames bool @@ -197,7 +202,6 @@ type ServerConn struct { readHandlers ServerConnReadHandlers // publish only - publishPath string publishTracks []ServerConnAnnouncedTrack backgroundRecordTerminate chan struct{} backgroundRecordDone chan struct{} @@ -245,7 +249,7 @@ func (sc *ServerConn) State() ServerConnState { // StreamProtocol returns the setupped tracks protocol. func (sc *ServerConn) StreamProtocol() *StreamProtocol { - return sc.streamProtocol + return sc.setupProtocol } // HasTrack checks whether a track has been setup. @@ -312,7 +316,7 @@ func (sc *ServerConn) zone() string { func (sc *ServerConn) frameModeEnable() { switch sc.state { case ServerConnStatePlay: - if *sc.streamProtocol == StreamProtocolTCP { + if *sc.setupProtocol == StreamProtocolTCP { sc.doEnableFrames = true } else { // readers can send RTCP frames, they cannot sent RTP frames @@ -322,7 +326,7 @@ func (sc *ServerConn) frameModeEnable() { } case ServerConnStateRecord: - if *sc.streamProtocol == StreamProtocolTCP { + if *sc.setupProtocol == StreamProtocolTCP { sc.doEnableFrames = true sc.readTimeoutEnabled = true @@ -348,7 +352,7 @@ func (sc *ServerConn) frameModeEnable() { func (sc *ServerConn) frameModeDisable() { switch sc.state { case ServerConnStatePlay: - if *sc.streamProtocol == StreamProtocolTCP { + if *sc.setupProtocol == StreamProtocolTCP { sc.framesEnabled = false sc.frameRingBuffer.Close() <-sc.backgroundWriteDone @@ -363,7 +367,7 @@ func (sc *ServerConn) frameModeDisable() { close(sc.backgroundRecordTerminate) <-sc.backgroundRecordDone - if *sc.streamProtocol == StreamProtocolTCP { + if *sc.setupProtocol == StreamProtocolTCP { sc.readTimeoutEnabled = false sc.nconn.SetReadDeadline(time.Time{}) @@ -515,7 +519,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if res.StatusCode == 200 { sc.state = ServerConnStatePreRecord - sc.publishPath = reqPath + sc.setupPath = &reqPath sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks)) for trackID, track := range tracks { @@ -553,6 +557,26 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("transport header: %s", err) } + if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast { + return &base.Response{ + StatusCode: base.StatusUnsupportedTransport, + }, nil + } + + trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, + sc.publishTracks, sc.setupPath) + if err != nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, err + } + + if _, ok := sc.tracks[trackID]; ok { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("track %d has already been setup", trackID) + } + switch sc.state { case ServerConnStateInitial, ServerConnStatePrePlay: // play if th.Mode != nil && *th.Mode != headers.TransportModePlay { @@ -569,32 +593,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } - if th.Delivery != nil && *th.Delivery == base.StreamDeliveryMulticast { - return &base.Response{ - StatusCode: base.StatusUnsupportedTransport, - }, nil - } - - trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, - sc.publishTracks, sc.publishPath) - if err != nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, err - } - - if _, ok := sc.tracks[trackID]; ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("track %d has already been setup", trackID) - } - - if sc.streamProtocol != nil && *sc.streamProtocol != th.Protocol { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("can't setup tracks with different protocols") - } - if th.Protocol == StreamProtocolUDP { if sc.udpRTPListener == nil { return &base.Response{ @@ -624,10 +622,16 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } + if sc.setupProtocol != nil && *sc.setupProtocol != th.Protocol { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("can't setup tracks with different protocols") + } + res, err := sc.readHandlers.OnSetup(req, th, path, trackID) if res.StatusCode == 200 { - sc.streamProtocol = &th.Protocol + sc.setupProtocol = &th.Protocol if sc.tracks == nil { sc.tracks = make(map[int]ServerConnTrack) @@ -668,6 +672,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { switch sc.state { case ServerConnStateInitial: sc.state = ServerConnStatePrePlay + sc.setupPath = &path } // workaround to prevent a bug in rtspclientsink @@ -949,7 +954,7 @@ func (sc *ServerConn) Read(readHandlers ServerConnReadHandlers) chan error { // WriteFrame writes a frame. func (sc *ServerConn) WriteFrame(trackID int, streamType StreamType, payload []byte) { - if *sc.streamProtocol == StreamProtocolUDP { + if *sc.setupProtocol == StreamProtocolUDP { track := sc.tracks[trackID] if streamType == StreamTypeRTP { @@ -990,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() { for { select { case <-checkStreamTicker.C: - if *sc.streamProtocol != StreamProtocolUDP { + if *sc.setupProtocol != StreamProtocolUDP { continue } diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index 5c6f1f38..fecdc381 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -187,6 +187,101 @@ func TestServerConnPublishSetupPath(t *testing.T) { } } +func TestServerConnPublishSetupDifferentPaths(t *testing.T) { + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + <-conn.Read(ServerConnReadHandlers{ + OnSetup: onSetup, + }) + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + track.Media.Attributes = append(track.Media.Attributes, psdp.Attribute{ + Key: "control", + Value: "trackID=0", + }) + + sout := &psdp.SessionDescription{ + SessionName: psdp.SessionName("Stream"), + Origin: psdp.Origin{ + Username: "-", + NetworkType: "IN", + AddressType: "IP4", + UnicastAddress: "127.0.0.1", + }, + TimeDescriptions: []psdp.TimeDescription{ + {Timing: psdp.Timing{0, 0}}, //nolint:govet + }, + MediaDescriptions: []*psdp.MediaDescription{ + track.Media, + }, + } + + byts, _ := sout.Marshal() + + err = base.Request{ + Method: base.Announce, + URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: byts, + }.Write(bconn.Writer) + require.NoError(t, err) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIds: &[2]int{0, 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/test2stream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) +} + func TestServerConnPublishReceivePackets(t *testing.T) { for _, proto := range []string{ "udp", diff --git a/serverconnread_test.go b/serverconnread_test.go index d5a0a824..5f7de823 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -131,6 +131,81 @@ func TestServerConnReadSetupPath(t *testing.T) { } } +func TestServerConnReadSetupDifferentPaths(t *testing.T) { + s, err := Serve("127.0.0.1:8554") + require.NoError(t, err) + defer s.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := s.Accept() + require.NoError(t, err) + defer conn.Close() + + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + } + + <-conn.Read(ServerConnReadHandlers{ + OnSetup: onSetup, + }) + }() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + th := &headers.Transport{ + Protocol: StreamProtocolTCP, + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Mode: func() *headers.TransportMode { + v := headers.TransportModePlay + return &v + }(), + InterleavedIds: &[2]int{0, 1}, + } + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/teststream/trackID=0"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + var res base.Response + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) + + th.InterleavedIds = &[2]int{2, 3} + + err = base.Request{ + Method: base.Setup, + URL: base.MustParseURL("rtsp://localhost:8554/test12stream/trackID=1"), + Header: base.Header{ + "CSeq": base.HeaderValue{"2"}, + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = res.Read(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) +} + func TestServerConnReadReceivePackets(t *testing.T) { for _, proto := range []string{ "udp",