diff --git a/examples/server-tls/main.go b/examples/server-tls/main.go index dc832a51..fdab6e8c 100644 --- a/examples/server-tls/main.go +++ b/examples/server-tls/main.go @@ -72,7 +72,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/examples/server-udp/main.go b/examples/server-udp/main.go index b5501b9e..fb226d0a 100644 --- a/examples/server-udp/main.go +++ b/examples/server-udp/main.go @@ -71,7 +71,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/examples/server/main.go b/examples/server/main.go index 45b95f2b..73067cc0 100644 --- a/examples/server/main.go +++ b/examples/server/main.go @@ -71,7 +71,7 @@ func handleConn(conn *gortsplib.ServerConn) { } // called after receiving a SETUP request. - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, Header: base.Header{ diff --git a/pkg/base/url.go b/pkg/base/url.go index d42a5580..a840ca96 100644 --- a/pkg/base/url.go +++ b/pkg/base/url.go @@ -3,37 +3,9 @@ package base import ( "fmt" "net/url" - "strconv" "strings" ) -func stringsReverseIndex(s, substr string) int { - for i := len(s) - 1 - len(substr); i >= 0; i-- { - if s[i:i+len(substr)] == substr { - return i - } - } - return -1 -} - -// PathSplitControlAttribute splits a path and query from a control attribute. -func PathSplitControlAttribute(pathAndQuery string) (int, string, bool) { - i := stringsReverseIndex(pathAndQuery, "/trackID=") - - // URL doesn't contain trackID - we assume it's track 0 - if i < 0 { - return 0, pathAndQuery, true - } - - tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) - if err != nil || tmp < 0 { - return 0, "", false - } - trackID := int(tmp) - - return trackID, pathAndQuery[:i], true -} - // PathSplitQuery splits a path from a query. func PathSplitQuery(pathAndQuery string) (string, string) { i := strings.Index(pathAndQuery, "?") diff --git a/serverconn.go b/serverconn.go index 66b1d941..bf9bc2d4 100644 --- a/serverconn.go +++ b/serverconn.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "strconv" "strings" "sync/atomic" "time" @@ -30,6 +31,61 @@ var ( errServerCSeqMissing = errors.New("CSeq is missing") ) +func stringsReverseIndex(s, substr string) int { + for i := len(s) - 1 - len(substr); i >= 0; i-- { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} + +func extractTrackIDAndPath(url *base.URL, + thMode *headers.TransportMode, + publishTracks []ServerConnAnnouncedTrack, + publishPath string) (int, string, error) { + + pathAndQuery, ok := url.RTSPPathAndQuery() + if !ok { + return 0, "", fmt.Errorf("invalid URL (%s)", url) + } + + if thMode == nil || *thMode == headers.TransportModePlay { + i := stringsReverseIndex(pathAndQuery, "/trackID=") + + // URL doesn't contain trackID - it's track zero + if i < 0 { + if !strings.HasSuffix(pathAndQuery, "/") { + return 0, "", fmt.Errorf("path must end with a slash (%v)", pathAndQuery) + } + pathAndQuery = pathAndQuery[:len(pathAndQuery)-1] + + // we assume it's track 0 + return 0, pathAndQuery, nil + } + + tmp, err := strconv.ParseInt(pathAndQuery[i+len("/trackID="):], 10, 64) + if err != nil || tmp < 0 { + return 0, "", fmt.Errorf("unable to parse track ID (%v)", pathAndQuery) + } + trackID := int(tmp) + pathAndQuery = pathAndQuery[:i] + + path, _ := base.PathSplitQuery(pathAndQuery) + + return trackID, path, nil + } + + for trackID, track := range publishTracks { + u, _ := track.track.URL() + if u.String() == url.String() { + return trackID, publishPath, nil + } + } + + return 0, "", fmt.Errorf("invalid track path (%s)", pathAndQuery) +} + // ServerConnState is the state of the connection. type ServerConnState int @@ -92,7 +148,7 @@ type ServerConnReadHandlers struct { OnAnnounce func(req *base.Request, tracks Tracks) (*base.Response, error) // called after receiving a SETUP request. - OnSetup func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) + OnSetup func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) // called after receiving a PLAY request. OnPlay func(req *base.Request) (*base.Response, error) @@ -127,20 +183,22 @@ type ServerConn struct { br *bufio.Reader bw *bufio.Writer state ServerConnState - readHandlers ServerConnReadHandlers tracks map[int]ServerConnTrack streamProtocol *StreamProtocol - announcedTracks []ServerConnAnnouncedTrack - doEnableFrames bool - framesEnabled bool - readTimeoutEnabled bool - - // writer + // frame mode only + doEnableFrames bool + framesEnabled bool + readTimeoutEnabled bool frameRingBuffer *ringbuffer.RingBuffer backgroundWriteDone chan struct{} - // background record + // read only + readHandlers ServerConnReadHandlers + + // publish only + publishPath string + publishTracks []ServerConnAnnouncedTrack backgroundRecordTerminate chan struct{} backgroundRecordDone chan struct{} udpTimeout int32 @@ -457,14 +515,14 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { if res.StatusCode == 200 { sc.state = ServerConnStatePreRecord + sc.publishPath = reqPath - sc.announcedTracks = make([]ServerConnAnnouncedTrack, len(tracks)) - + sc.publishTracks = make([]ServerConnAnnouncedTrack, len(tracks)) for trackID, track := range tracks { clockRate, _ := track.ClockRate() v := time.Now().Unix() - sc.announcedTracks[trackID] = ServerConnAnnouncedTrack{ + sc.publishTracks[trackID] = ServerConnAnnouncedTrack{ track: track, rtcpReceiver: rtcpreceiver.New(nil, clockRate), udpLastFrameTime: &v, @@ -488,13 +546,6 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, err } - pathAndQuery, ok := req.URL.RTSPPathAndQuery() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", req.URL) - } - th, err := headers.ReadTransport(req.Header["Transport"]) if err != nil { return &base.Response{ @@ -524,25 +575,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, nil } - trackID, err := func() (int, error) { - if th.Mode == nil || *th.Mode == headers.TransportModePlay { - trackID, _, ok := base.PathSplitControlAttribute(pathAndQuery) - if !ok { - return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery) - } - - return trackID, nil - } - - for trackID, track := range sc.announcedTracks { - u, _ := track.track.URL() - if u.String() == req.URL.String() { - return trackID, nil - } - } - - return 0, fmt.Errorf("invalid track path (%s)", pathAndQuery) - }() + trackID, path, err := extractTrackIDAndPath(req.URL, th.Mode, + sc.publishTracks, sc.publishPath) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -590,7 +624,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { } } - res, err := sc.readHandlers.OnSetup(req, th, trackID) + res, err := sc.readHandlers.OnSetup(req, th, path, trackID) if res.StatusCode == 200 { sc.streamProtocol = &th.Protocol @@ -697,7 +731,7 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) { }, fmt.Errorf("no tracks have been setup") } - if len(sc.tracks) != len(sc.announcedTracks) { + if len(sc.tracks) != len(sc.publishTracks) { return &base.Response{ StatusCode: base.StatusBadRequest, }, fmt.Errorf("not all tracks have been setup") @@ -860,7 +894,7 @@ outer: // forward frame only if it has been set up if _, ok := sc.tracks[frame.TrackID]; ok { if sc.state == ServerConnStateRecord { - sc.announcedTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), + sc.publishTracks[frame.TrackID].rtcpReceiver.ProcessFrame(time.Now(), frame.StreamType, frame.Payload) } sc.readHandlers.OnFrame(frame.TrackID, frame.StreamType, frame.Payload) @@ -961,7 +995,7 @@ func (sc *ServerConn) backgroundRecord() { } now := time.Now() - for _, track := range sc.announcedTracks { + for _, track := range sc.publishTracks { last := time.Unix(atomic.LoadInt64(track.udpLastFrameTime), 0) if now.Sub(last) >= sc.conf.ReadTimeout { @@ -973,7 +1007,7 @@ func (sc *ServerConn) backgroundRecord() { case <-receiverReportTicker.C: now := time.Now() - for trackID, track := range sc.announcedTracks { + for trackID, track := range sc.publishTracks { r := track.rtcpReceiver.Report(now) sc.WriteFrame(trackID, StreamTypeRTP, r) } diff --git a/serverconn_test.go b/serverconn_test.go index 86b6f584..050d0319 100644 --- a/serverconn_test.go +++ b/serverconn_test.go @@ -145,30 +145,11 @@ func (ts *testServ) handleConn(conn *ServerConn) { }, nil } - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { - switch conn.State() { - case ServerConnStateInitial, ServerConnStatePrePlay: - pathAndQuery, ok := req.URL.RTSPPathAndQuery() - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", req.URL) - } - - _, pathAndQuery, ok = base.PathSplitControlAttribute(pathAndQuery) - if !ok { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", req.URL) - } - - reqPath, _ := base.PathSplitQuery(pathAndQuery) - - if reqPath != "teststream" { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, fmt.Errorf("invalid path (%s)", req.URL) - } + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + if path != "teststream" { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, fmt.Errorf("invalid path (%s)", req.URL) } return &base.Response{ diff --git a/serverconnpublish_test.go b/serverconnpublish_test.go index 15ff63ea..5c6f1f38 100644 --- a/serverconnpublish_test.go +++ b/serverconnpublish_test.go @@ -21,35 +21,58 @@ func TestServerConnPublishSetupPath(t *testing.T) { name string control string url string + path string trackID int }{ { "normal", "trackID=0", "rtsp://localhost:8554/teststream/trackID=0", + "teststream", 0, }, { "unordered id", "trackID=2", "rtsp://localhost:8554/teststream/trackID=2", + "teststream", 0, }, { "custom param name", "testing=0", "rtsp://localhost:8554/teststream/testing=0", + "teststream", 0, }, { "query", "?testing=0", "rtsp://localhost:8554/teststream?testing=0", + "teststream", + 0, + }, + { + "subpath", + "trackID=0", + "rtsp://localhost:8554/test/stream/trackID=0", + "test/stream", + 0, + }, + { + "subpath and query", + "?testing=0", + "rtsp://localhost:8554/test/stream?testing=0", + "test/stream", 0, }, } { t.Run(ca.name, func(t *testing.T) { - setupDone := make(chan int) + type pathTrackIDPair struct { + path string + trackID int + } + setupDone := make(chan pathTrackIDPair) s, err := Serve("127.0.0.1:8554") require.NoError(t, err) @@ -70,8 +93,8 @@ func TestServerConnPublishSetupPath(t *testing.T) { }, nil } - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { - setupDone <- trackID + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + setupDone <- pathTrackIDPair{path, trackID} return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -116,7 +139,7 @@ func TestServerConnPublishSetupPath(t *testing.T) { err = base.Request{ Method: base.Announce, - URL: base.MustParseURL("rtsp://localhost:8554/teststream"), + URL: base.MustParseURL("rtsp://localhost:8554/" + ca.path), Header: base.Header{ "CSeq": base.HeaderValue{"1"}, "Content-Type": base.HeaderValue{"application/sdp"}, @@ -153,8 +176,9 @@ func TestServerConnPublishSetupPath(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - trackID := <-setupDone - require.Equal(t, ca.trackID, trackID) + pair := <-setupDone + require.Equal(t, ca.path, pair.path) + require.Equal(t, ca.trackID, pair.trackID) err = res.Read(bconn.Reader) require.NoError(t, err) @@ -197,7 +221,7 @@ func TestServerConnPublishReceivePackets(t *testing.T) { }, nil } - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, }, nil diff --git a/serverconnread_test.go b/serverconnread_test.go index ed51698d..55b36469 100644 --- a/serverconnread_test.go +++ b/serverconnread_test.go @@ -17,21 +17,47 @@ func TestServerConnReadSetupPath(t *testing.T) { for _, ca := range []struct { name string url string + path string trackID int }{ { "normal", "rtsp://localhost:8554/teststream/trackID=0", + "teststream", 0, }, { "unordered id", "rtsp://localhost:8554/teststream/trackID=2", + "teststream", 2, }, + { + // this is needed to support reading mpegts with ffmpeg + "without track id", + "rtsp://localhost:8554/teststream/", + "teststream", + 0, + }, + { + "subpath", + "rtsp://localhost:8554/test/stream/trackID=0", + "test/stream", + 0, + }, + { + "subpath without track id", + "rtsp://localhost:8554/test/stream/", + "test/stream", + 0, + }, } { t.Run(ca.name, func(t *testing.T) { - setupDone := make(chan int) + type pathTrackIDPair struct { + path string + trackID int + } + setupDone := make(chan pathTrackIDPair) s, err := Serve("127.0.0.1:8554") require.NoError(t, err) @@ -46,8 +72,8 @@ func TestServerConnReadSetupPath(t *testing.T) { require.NoError(t, err) defer conn.Close() - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { - setupDone <- trackID + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { + setupDone <- pathTrackIDPair{path, trackID} return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -87,8 +113,9 @@ func TestServerConnReadSetupPath(t *testing.T) { }.Write(bconn.Writer) require.NoError(t, err) - trackID := <-setupDone - require.Equal(t, ca.trackID, trackID) + pair := <-setupDone + require.Equal(t, ca.path, pair.path) + require.Equal(t, ca.trackID, pair.trackID) var res base.Response err = res.Read(bconn.Reader) @@ -124,7 +151,7 @@ func TestServerConnReadReceivePackets(t *testing.T) { require.NoError(t, err) defer conn.Close() - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -231,94 +258,6 @@ func TestServerConnReadReceivePackets(t *testing.T) { } } -func TestServerConnReadWithoutSetupTrackID(t *testing.T) { - s, err := Serve("127.0.0.1:8554") - require.NoError(t, err) - defer s.Close() - - serverDone := make(chan struct{}) - defer func() { <-serverDone }() - go func() { - defer close(serverDone) - - conn, err := s.Accept() - require.NoError(t, err) - defer conn.Close() - - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - onPlay := func(req *base.Request) (*base.Response, error) { - go func() { - time.Sleep(100 * time.Millisecond) - conn.WriteFrame(0, StreamTypeRTP, []byte("\x00\x00\x00\x00")) - }() - - return &base.Response{ - StatusCode: base.StatusOK, - }, nil - } - - err = <-conn.Read(ServerConnReadHandlers{ - OnSetup: onSetup, - OnPlay: onPlay, - }) - require.Equal(t, io.EOF, err) - }() - - conn, err := net.Dial("tcp", "localhost:8554") - require.NoError(t, err) - defer conn.Close() - bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) - - err = base.Request{ - Method: base.Setup, - URL: base.MustParseURL("rtsp://localhost:8554/teststream"), - Header: base.Header{ - "CSeq": base.HeaderValue{"1"}, - "Transport": headers.Transport{ - Protocol: StreamProtocolTCP, - Delivery: func() *base.StreamDelivery { - v := base.StreamDeliveryUnicast - return &v - }(), - Mode: func() *headers.TransportMode { - v := headers.TransportModePlay - return &v - }(), - InterleavedIds: &[2]int{0, 1}, - }.Write(), - }, - }.Write(bconn.Writer) - require.NoError(t, err) - - var res base.Response - err = res.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - err = base.Request{ - Method: base.Play, - URL: base.MustParseURL("rtsp://localhost:8554/teststream"), - Header: base.Header{ - "CSeq": base.HeaderValue{"2"}, - }, - }.Write(bconn.Writer) - require.NoError(t, err) - - err = res.Read(bconn.Reader) - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - - var fr base.InterleavedFrame - fr.Payload = make([]byte, 2048) - err = fr.Read(bconn.Reader) - require.NoError(t, err) -} - func TestServerConnReadTCPResponseBeforeFrames(t *testing.T) { s, err := Serve("127.0.0.1:8554") require.NoError(t, err) @@ -338,7 +277,7 @@ func TestServerConnReadTCPResponseBeforeFrames(t *testing.T) { writerTerminate := make(chan struct{}) defer close(writerTerminate) - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -446,7 +385,7 @@ func TestServerConnReadPlayMultiple(t *testing.T) { writerTerminate := make(chan struct{}) defer close(writerTerminate) - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, }, nil @@ -561,7 +500,7 @@ func TestServerConnReadPauseMultiple(t *testing.T) { writerTerminate := make(chan struct{}) defer close(writerTerminate) - onSetup := func(req *base.Request, th *headers.Transport, trackID int) (*base.Response, error) { + onSetup := func(req *base.Request, th *headers.Transport, path string, trackID int) (*base.Response, error) { return &base.Response{ StatusCode: base.StatusOK, }, nil diff --git a/serverudpl.go b/serverudpl.go index 978875cd..226297ff 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -121,8 +121,8 @@ func (s *serverUDPListener) run() { if clientData.isPublishing { now := time.Now() - atomic.StoreInt64(clientData.sc.announcedTracks[clientData.trackID].udpLastFrameTime, now.Unix()) - clientData.sc.announcedTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n]) + atomic.StoreInt64(clientData.sc.publishTracks[clientData.trackID].udpLastFrameTime, now.Unix()) + clientData.sc.publishTracks[clientData.trackID].rtcpReceiver.ProcessFrame(now, s.streamType, buf[:n]) } clientData.sc.readHandlers.OnFrame(clientData.trackID, s.streamType, buf[:n])