diff --git a/client.go b/client.go index 57691af5..bdfc3442 100644 --- a/client.go +++ b/client.go @@ -161,14 +161,14 @@ func (c *Client) DialReadContext(ctx context.Context, address string) (*ClientCo return nil, err } - tracks, _, err := conn.Describe(u) + tracks, baseURL, _, err := conn.Describe(u) if err != nil { conn.Close() return nil, err } for _, track := range tracks { - _, err := conn.Setup(headers.TransportModePlay, track, 0, 0) + _, err := conn.Setup(headers.TransportModePlay, baseURL, track, 0, 0) if err != nil { conn.Close() return nil, err @@ -229,7 +229,7 @@ func (c *Client) DialPublishContext(ctx context.Context, address string, tracks } for _, track := range tracks { - _, err := conn.Setup(headers.TransportModeRecord, track, 0, 0) + _, err := conn.Setup(headers.TransportModeRecord, u, track, 0, 0) if err != nil { conn.Close() return nil, err diff --git a/client_publish_test.go b/client_publish_test.go index 2d3fe7d2..d51efc0a 100644 --- a/client_publish_test.go +++ b/client_publish_test.go @@ -191,7 +191,7 @@ func TestClientPublishSerial(t *testing.T) { }) }() - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.NoError(t, err) @@ -199,7 +199,7 @@ func TestClientPublishSerial(t *testing.T) { conn.Close() <-done - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.Error(t, err) }) @@ -332,7 +332,7 @@ func TestClientPublishParallel(t *testing.T) { defer t.Stop() for range t.C { - err := conn.WriteFrame(track.ID, StreamTypeRTP, + err := conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) if err != nil { return @@ -480,21 +480,21 @@ func TestClientPublishPauseSerial(t *testing.T) { require.NoError(t, err) defer conn.Close() - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.NoError(t, err) _, err = conn.Pause() require.NoError(t, err) - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.Error(t, err) _, err = conn.Record() require.NoError(t, err) - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.NoError(t, err) }) @@ -625,7 +625,7 @@ func TestClientPublishPauseParallel(t *testing.T) { defer t.Stop() for range t.C { - err := conn.WriteFrame(track.ID, StreamTypeRTP, + err := conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) if err != nil { return @@ -756,7 +756,7 @@ func TestClientPublishAutomaticProtocol(t *testing.T) { require.NoError(t, err) defer conn.Close() - err = conn.WriteFrame(track.ID, StreamTypeRTP, + err = conn.WriteFrame(0, StreamTypeRTP, []byte{0x01, 0x02, 0x03, 0x04}) require.NoError(t, err) } @@ -909,11 +909,11 @@ func TestClientPublishRTCPReport(t *testing.T) { }, Payload: []byte{0x01, 0x02, 0x03, 0x04}, }).Marshal() - err = conn.WriteFrame(track.ID, StreamTypeRTP, byts) + err = conn.WriteFrame(0, StreamTypeRTP, byts) require.NoError(t, err) time.Sleep(1300 * time.Millisecond) - err = conn.WriteFrame(track.ID, StreamTypeRTP, byts) + err = conn.WriteFrame(0, StreamTypeRTP, byts) require.NoError(t, err) } diff --git a/client_read_test.go b/client_read_test.go index 654732b5..d3e63d17 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -150,19 +150,13 @@ func TestClientReadTracks(t *testing.T) { require.Equal(t, Tracks{ { - ID: 0, - BaseURL: mustParseURL("rtsp://localhost:8554/teststream/"), - Media: track1.Media, + Media: track1.Media, }, { - ID: 1, - BaseURL: mustParseURL("rtsp://localhost:8554/teststream/"), - Media: track2.Media, + Media: track2.Media, }, { - ID: 2, - BaseURL: mustParseURL("rtsp://localhost:8554/teststream/"), - Media: track3.Media, + Media: track3.Media, }, }, conn.Tracks()) } @@ -449,6 +443,140 @@ func TestClientRead(t *testing.T) { } } +func TestClientReadPartial(t *testing.T) { + listenIP := multicastCapableIP(t) + l, err := net.Listen("tcp", listenIP+":8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + conn, err := l.Accept() + require.NoError(t, err) + defer conn.Close() + bconn := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) + + req, err := readRequest(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream"), req.URL) + + track1, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + + track2, err := NewTrackH264(96, []byte("123456"), []byte("123456")) + require.NoError(t, err) + + tracks := cloneAndClearTracks(Tracks{track1, track2}) + + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://" + listenIP + ":8554/teststream/"}, + }, + Body: tracks.Write(), + }.Write(bconn.Writer) + require.NoError(t, err) + + req, err = readRequest(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/trackID=1"), req.URL) + + var inTH headers.Transport + err = inTH.Read(req.Header["Transport"]) + require.NoError(t, err) + require.Equal(t, &[2]int{0, 1}, inTH.InterleavedIDs) + + th := headers.Transport{ + Delivery: func() *base.StreamDelivery { + v := base.StreamDeliveryUnicast + return &v + }(), + Protocol: base.StreamProtocolTCP, + InterleavedIDs: inTH.InterleavedIDs, + } + + err = base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": th.Write(), + }, + }.Write(bconn.Writer) + require.NoError(t, err) + + req, err = readRequest(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) + + err = base.Response{ + StatusCode: base.StatusOK, + }.Write(bconn.Writer) + require.NoError(t, err) + + err = base.InterleavedFrame{ + TrackID: 0, + StreamType: StreamTypeRTP, + Payload: []byte{0x01, 0x02, 0x03, 0x04}, + }.Write(bconn.Writer) + require.NoError(t, err) + + req, err = readRequest(bconn.Reader) + require.NoError(t, err) + require.Equal(t, base.Teardown, req.Method) + require.Equal(t, mustParseURL("rtsp://"+listenIP+":8554/teststream/"), req.URL) + + err = base.Response{ + StatusCode: base.StatusOK, + }.Write(bconn.Writer) + require.NoError(t, err) + }() + + c := &Client{ + Protocol: func() *ClientProtocol { + v := ClientProtocolTCP + return &v + }(), + } + + u, err := base.ParseURL("rtsp://" + listenIP + ":8554/teststream") + require.NoError(t, err) + + conn, err := c.Dial(u.Scheme, u.Host) + require.NoError(t, err) + defer conn.Close() + + tracks, baseURL, _, err := conn.Describe(u) + require.NoError(t, err) + + _, err = conn.Setup(headers.TransportModePlay, baseURL, tracks[1], 0, 0) + require.NoError(t, err) + + _, err = conn.Play(nil) + require.NoError(t, err) + + done := make(chan struct{}) + frameRecv := make(chan struct{}) + go func() { + defer close(done) + conn.ReadFrames(func(id int, streamType StreamType, payload []byte) { + require.Equal(t, 0, id) + require.Equal(t, StreamTypeRTP, streamType) + require.Equal(t, []byte{0x01, 0x02, 0x03, 0x04}, payload) + close(frameRecv) + }) + }() + + <-frameRecv + conn.Close() + <-done +} + func TestClientReadNoContentBase(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) @@ -2005,11 +2133,11 @@ func TestClientReadSeek(t *testing.T) { _, err = conn.Options(u) require.NoError(t, err) - tracks, _, err := conn.Describe(u) + tracks, baseURL, _, err := conn.Describe(u) require.NoError(t, err) for _, track := range tracks { - _, err := conn.Setup(headers.TransportModePlay, track, 0, 0) + _, err := conn.Setup(headers.TransportModePlay, baseURL, track, 0, 0) require.NoError(t, err) } diff --git a/client_test.go b/client_test.go index b55a48a6..ddd7ed6b 100644 --- a/client_test.go +++ b/client_test.go @@ -95,7 +95,7 @@ func TestClientSession(t *testing.T) { _, err = conn.Options(u) require.NoError(t, err) - _, _, err = conn.Describe(u) + _, _, _, err = conn.Describe(u) require.NoError(t, err) } @@ -174,6 +174,6 @@ func TestClientAuth(t *testing.T) { _, err = conn.Options(u) require.NoError(t, err) - _, _, err = conn.Describe(u) + _, _, _, err = conn.Describe(u) require.NoError(t, err) } diff --git a/clientconn.go b/clientconn.go index 1a1240bc..d981d29b 100644 --- a/clientconn.go +++ b/clientconn.go @@ -92,6 +92,7 @@ type announceReq struct { type setupReq struct { mode headers.TransportMode + baseURL *base.URL track *Track rtpPort int rtcpPort int @@ -112,9 +113,10 @@ type pauseReq struct { } type clientRes struct { - tracks Tracks - res *base.Response - err error + tracks Tracks + baseURL *base.URL + res *base.Response + err error } // ClientConn is a client-side RTSP connection. @@ -235,17 +237,20 @@ func (cc *ClientConn) Close() error { // Tracks returns all the tracks that the connection is reading or publishing. func (cc *ClientConn) Tracks() Tracks { - var ret Tracks - - for _, track := range cc.tracks { - ret = append(ret, track.track) + ids := make([]int, len(cc.tracks)) + pos := 0 + for id := range cc.tracks { + ids[pos] = id + pos++ } - - // sort by ID to generate correct SDPs - sort.Slice(ret, func(i, j int) bool { - return ret[i].ID < ret[j].ID + sort.Slice(ids, func(a, b int) bool { + return ids[a] < ids[b] }) + var ret Tracks + for _, id := range ids { + ret = append(ret, cc.tracks[id].track) + } return ret } @@ -260,15 +265,15 @@ outer: req.res <- clientRes{res: res, err: err} case req := <-cc.describe: - tracks, res, err := cc.doDescribe(req.url) - req.res <- clientRes{tracks: tracks, res: res, err: err} + tracks, baseURL, res, err := cc.doDescribe(req.url) + req.res <- clientRes{tracks: tracks, baseURL: baseURL, res: res, err: err} case req := <-cc.announce: res, err := cc.doAnnounce(req.url, req.tracks) req.res <- clientRes{res: res, err: err} case req := <-cc.setup: - res, err := cc.doSetup(req.mode, req.track, req.rtpPort, req.rtcpPort) + res, err := cc.doSetup(req.mode, req.baseURL, req.track, req.rtpPort, req.rtcpPort) req.res <- clientRes{res: res, err: err} case req := <-cc.play: @@ -388,7 +393,7 @@ func (cc *ClientConn) switchProtocolIfTimeout(err error) error { } for _, track := range prevTracks { - _, err := cc.doSetup(headers.TransportModePlay, track.track, 0, 0) + _, err := cc.doSetup(headers.TransportModePlay, prevBaseURL, track.track, 0, 0) if err != nil { return err } @@ -969,14 +974,14 @@ func (cc *ClientConn) Options(u *base.URL) (*base.Response, error) { } } -func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) { +func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.URL, *base.Response, error) { err := cc.checkState(map[clientConnState]struct{}{ clientConnStateInitial: {}, clientConnStatePrePlay: {}, clientConnStatePreRecord: {}, }) if err != nil { - return nil, nil, err + return nil, nil, nil, err } res, err := cc.do(&base.Request{ @@ -987,7 +992,7 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) { }, }, false) if err != nil { - return nil, nil, err + return nil, nil, nil, err } if res.StatusCode != base.StatusOK { @@ -1001,7 +1006,7 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) { u, err := base.ParseURL(res.Header["Location"][0]) if err != nil { - return nil, nil, err + return nil, nil, nil, err } cc.scheme = u.Scheme @@ -1009,27 +1014,27 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) { err = cc.connOpen() if err != nil { - return nil, nil, err + return nil, nil, nil, err } _, err = cc.doOptions(u) if err != nil { - return nil, nil, err + return nil, nil, nil, err } return cc.doDescribe(u) } - return nil, res, liberrors.ErrClientInvalidStatusCode{Code: res.StatusCode, Message: res.StatusMessage} + return nil, nil, res, liberrors.ErrClientInvalidStatusCode{Code: res.StatusCode, Message: res.StatusMessage} } ct, ok := res.Header["Content-Type"] if !ok || len(ct) != 1 { - return nil, nil, liberrors.ErrClientContentTypeMissing{} + return nil, nil, nil, liberrors.ErrClientContentTypeMissing{} } if ct[0] != "application/sdp" { - return nil, nil, liberrors.ErrClientContentTypeUnsupported{CT: ct} + return nil, nil, nil, liberrors.ErrClientContentTypeUnsupported{CT: ct} } baseURL, err := func() (*base.URL, error) { @@ -1054,27 +1059,27 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) { return u, nil }() if err != nil { - return nil, nil, err + return nil, nil, nil, err } - tracks, err := ReadTracks(res.Body, baseURL) + tracks, err := ReadTracks(res.Body) if err != nil { - return nil, nil, err + return nil, nil, nil, err } - return tracks, res, nil + return tracks, baseURL, res, nil } // Describe writes a DESCRIBE request and reads a Response. -func (cc *ClientConn) Describe(u *base.URL) (Tracks, *base.Response, error) { +func (cc *ClientConn) Describe(u *base.URL) (Tracks, *base.URL, *base.Response, error) { cres := make(chan clientRes) select { case cc.describe <- describeReq{url: u, res: cres}: res := <-cres - return res.tracks, res.res, res.err + return res.tracks, res.baseURL, res.res, res.err case <-cc.ctx.Done(): - return nil, nil, liberrors.ErrClientTerminated{} + return nil, nil, nil, liberrors.ErrClientTerminated{} } } @@ -1090,11 +1095,7 @@ func (cc *ClientConn) doAnnounce(u *base.URL, tracks Tracks) (*base.Response, er // (tested with ffmpeg and gstreamer) baseURL := u.Clone() - // set ID, base URL, control attribute of tracks for i, t := range tracks { - t.ID = i - t.BaseURL = baseURL - if !t.hasControlAttribute() { t.Media.Attributes = append(t.Media.Attributes, psdp.Attribute{ Key: "control", @@ -1142,6 +1143,7 @@ func (cc *ClientConn) Announce(u *base.URL, tracks Tracks) (*base.Response, erro func (cc *ClientConn) doSetup( mode headers.TransportMode, + baseURL *base.URL, track *Track, rtpPort int, rtcpPort int) (*base.Response, error) { @@ -1160,7 +1162,7 @@ func (cc *ClientConn) doSetup( return nil, liberrors.ErrClientCannotReadPublishAtSameTime{} } - if cc.streamBaseURL != nil && *track.BaseURL != *cc.streamBaseURL { + if cc.streamBaseURL != nil && *baseURL != *cc.streamBaseURL { return nil, liberrors.ErrClientCannotSetupTracksDifferentURLs{} } @@ -1192,6 +1194,8 @@ func (cc *ClientConn) doSetup( Mode: &mode, } + trackID := len(cc.tracks) + switch proto { case ClientProtocolUDP: if (rtpPort == 0 && rtcpPort != 0) || @@ -1236,10 +1240,10 @@ func (cc *ClientConn) doSetup( v1 := base.StreamDeliveryUnicast th.Delivery = &v1 th.Protocol = base.StreamProtocolTCP - th.InterleavedIDs = &[2]int{(track.ID * 2), (track.ID * 2) + 1} + th.InterleavedIDs = &[2]int{(trackID * 2), (trackID * 2) + 1} } - trackURL, err := track.URL() + trackURL, err := track.URL(baseURL) if err != nil { if proto == ClientProtocolUDP { rtpListener.close() @@ -1277,7 +1281,7 @@ func (cc *ClientConn) doSetup( v := ClientProtocolTCP cc.protocol = &v - return cc.doSetup(mode, track, 0, 0) + return cc.doSetup(mode, baseURL, track, 0, 0) } return res, liberrors.ErrClientInvalidStatusCode{Code: res.StatusCode, Message: res.StatusMessage} @@ -1360,7 +1364,7 @@ func (cc *ClientConn) doSetup( if thRes.ServerPorts != nil { rtpListener.remotePort = thRes.ServerPorts[0] } - rtpListener.trackID = track.ID + rtpListener.trackID = trackID rtpListener.streamType = StreamTypeRTP cct.udpRTPListener = rtpListener @@ -1369,7 +1373,7 @@ func (cc *ClientConn) doSetup( if thRes.ServerPorts != nil { rtcpListener.remotePort = thRes.ServerPorts[1] } - rtcpListener.trackID = track.ID + rtcpListener.trackID = trackID rtcpListener.streamType = StreamTypeRTCP cct.udpRTCPListener = rtcpListener @@ -1377,14 +1381,14 @@ func (cc *ClientConn) doSetup( rtpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtpListener.remoteZone = "" rtpListener.remotePort = thRes.Ports[0] - rtpListener.trackID = track.ID + rtpListener.trackID = trackID rtpListener.streamType = StreamTypeRTP cct.udpRTPListener = rtpListener rtcpListener.remoteIP = cc.nconn.RemoteAddr().(*net.TCPAddr).IP rtcpListener.remoteZone = "" rtcpListener.remotePort = thRes.Ports[1] - rtcpListener.trackID = track.ID + rtcpListener.trackID = trackID rtcpListener.streamType = StreamTypeRTCP cct.udpRTCPListener = rtcpListener @@ -1402,10 +1406,10 @@ func (cc *ClientConn) doSetup( cct.rtcpSender = rtcpsender.New(clockRate) } - cc.streamBaseURL = track.BaseURL + cc.streamBaseURL = baseURL cc.protocol = &proto - cc.tracks[track.ID] = cct + cc.tracks[trackID] = cct return res, nil } @@ -1415,6 +1419,7 @@ func (cc *ClientConn) doSetup( // if rtpPort and rtcpPort are zero, they are chosen automatically. func (cc *ClientConn) Setup( mode headers.TransportMode, + baseURL *base.URL, track *Track, rtpPort int, rtcpPort int) (*base.Response, error) { @@ -1422,6 +1427,7 @@ func (cc *ClientConn) Setup( select { case cc.setup <- setupReq{ mode: mode, + baseURL: baseURL, track: track, rtpPort: rtpPort, rtcpPort: rtcpPort, diff --git a/examples/client-publish-options/main.go b/examples/client-publish-options/main.go index aa08bd0c..d89d80c5 100644 --- a/examples/client-publish-options/main.go +++ b/examples/client-publish-options/main.go @@ -68,7 +68,7 @@ func main() { } // write RTP frames - err = conn.WriteFrame(track.ID, gortsplib.StreamTypeRTP, buf[:n]) + err = conn.WriteFrame(0, gortsplib.StreamTypeRTP, buf[:n]) if err != nil { panic(err) } diff --git a/examples/client-publish-pause/main.go b/examples/client-publish-pause/main.go index 3ac0c6e9..9e605d6f 100644 --- a/examples/client-publish-pause/main.go +++ b/examples/client-publish-pause/main.go @@ -64,7 +64,7 @@ func main() { } // write RTP frames - err = conn.WriteFrame(track.ID, gortsplib.StreamTypeRTP, buf[:n]) + err = conn.WriteFrame(0, gortsplib.StreamTypeRTP, buf[:n]) if err != nil { break } diff --git a/examples/client-publish/main.go b/examples/client-publish/main.go index 6fa7e4a9..612163cf 100644 --- a/examples/client-publish/main.go +++ b/examples/client-publish/main.go @@ -56,7 +56,7 @@ func main() { } // write RTP frames - err = conn.WriteFrame(track.ID, gortsplib.StreamTypeRTP, buf[:n]) + err = conn.WriteFrame(0, gortsplib.StreamTypeRTP, buf[:n]) if err != nil { panic(err) } diff --git a/examples/client-query/main.go b/examples/client-query/main.go index 9982667d..3dbb92b3 100644 --- a/examples/client-query/main.go +++ b/examples/client-query/main.go @@ -28,7 +28,7 @@ func main() { panic(err) } - tracks, _, err := conn.Describe(u) + tracks, _, _, err := conn.Describe(u) if err != nil { panic(err) } diff --git a/examples/client-read-h264/main.go b/examples/client-read-h264/main.go index c31dd1c5..78f0db83 100644 --- a/examples/client-read-h264/main.go +++ b/examples/client-read-h264/main.go @@ -22,9 +22,9 @@ func main() { // check whether there's a H264 track h264Track := func() int { - for _, track := range conn.Tracks() { + for i, track := range conn.Tracks() { if track.IsH264() { - return track.ID + return i } } return -1 diff --git a/examples/client-read-partial/main.go b/examples/client-read-partial/main.go index 3795ea4a..b9c8c3db 100644 --- a/examples/client-read-partial/main.go +++ b/examples/client-read-partial/main.go @@ -30,7 +30,7 @@ func main() { panic(err) } - tracks, _, err := conn.Describe(u) + tracks, baseURL, _, err := conn.Describe(u) if err != nil { panic(err) } @@ -38,7 +38,7 @@ func main() { // start reading only video tracks, skipping audio or application tracks for _, t := range tracks { if t.Media.MediaName.Media == "video" { - _, err := conn.Setup(headers.TransportModePlay, t, 0, 0) + _, err := conn.Setup(headers.TransportModePlay, baseURL, t, 0, 0) if err != nil { panic(err) } diff --git a/serversession.go b/serversession.go index 6de3f316..4ac81356 100644 --- a/serversession.go +++ b/serversession.go @@ -24,7 +24,10 @@ func setupGetTrackIDPathQuery( url *base.URL, thMode *headers.TransportMode, announcedTracks []ServerSessionAnnouncedTrack, - setuppedPath *string, setuppedQuery *string) (int, string, string, error) { + setuppedPath *string, + setuppedQuery *string, + setuppedBaseURL *base.URL, +) (int, string, string, error) { pathAndQuery, ok := url.RTSPPathAndQuery() if !ok { return 0, "", "", liberrors.ErrServerInvalidPath{} @@ -63,7 +66,7 @@ func setupGetTrackIDPathQuery( } for trackID, track := range announcedTracks { - u, _ := track.track.URL() + u, _ := track.track.URL(setuppedBaseURL) if u.String() == url.String() { return trackID, *setuppedPath, *setuppedQuery, nil } @@ -126,6 +129,7 @@ type ServerSession struct { setuppedTracks map[int]ServerSessionSetuppedTrack setuppedProtocol *base.StreamProtocol setuppedDelivery *base.StreamDelivery + setuppedBaseURL *base.URL // publish setuppedStream *ServerStream // read setuppedPath *string setuppedQuery *string @@ -443,7 +447,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base }, liberrors.ErrServerContentTypeUnsupported{CT: ct} } - tracks, err := ReadTracks(req.Body, req.URL) + tracks, err := ReadTracks(req.Body) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -457,7 +461,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } for _, track := range tracks { - trackURL, err := track.URL() + trackURL, err := track.URL(req.URL) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, @@ -493,6 +497,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base ss.state = ServerSessionStatePreRecord ss.setuppedPath = &path ss.setuppedQuery = &query + ss.setuppedBaseURL = req.URL ss.announcedTracks = make([]ServerSessionAnnouncedTrack, len(tracks)) for trackID, track := range tracks { @@ -530,7 +535,7 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base } trackID, path, query, err := setupGetTrackIDPathQuery(req.URL, inTH.Mode, - ss.announcedTracks, ss.setuppedPath, ss.setuppedQuery) + ss.announcedTracks, ss.setuppedPath, ss.setuppedQuery, ss.setuppedBaseURL) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, diff --git a/track.go b/track.go index 7ab9f4f7..07cd111b 100644 --- a/track.go +++ b/track.go @@ -14,15 +14,9 @@ import ( "github.com/aler9/gortsplib/pkg/sdp" ) -// Track is a track available in a certain URL. +// Track is a RTSP track. type Track struct { - // base URL - BaseURL *base.URL - - // id - ID int - - // codec and info in SDP format + // attributes in SDP format Media *psdp.MediaDescription } @@ -36,8 +30,8 @@ func (t *Track) hasControlAttribute() bool { } // URL returns the track url. -func (t *Track) URL() (*base.URL, error) { - if t.BaseURL == nil { +func (t *Track) URL(baseURL *base.URL) (*base.URL, error) { + if baseURL == nil { return nil, fmt.Errorf("empty base url") } @@ -52,7 +46,7 @@ func (t *Track) URL() (*base.URL, error) { // no control attribute, use base URL if controlAttr == "" { - return t.BaseURL, nil + return baseURL, nil } // control attribute contains an absolute path @@ -63,8 +57,8 @@ func (t *Track) URL() (*base.URL, error) { } // copy host and credentials - ur.Host = t.BaseURL.Host - ur.User = t.BaseURL.User + ur.Host = baseURL.Host + ur.User = baseURL.User return ur, nil } @@ -72,7 +66,7 @@ func (t *Track) URL() (*base.URL, error) { // insert the control attribute at the end of the url // if there's a query, insert it after the query // otherwise insert it after the path - strURL := t.BaseURL.String() + strURL := baseURL.String() if controlAttr[0] != '?' && !strings.HasSuffix(strURL, "/") { strURL += "/" } @@ -337,7 +331,7 @@ func (t *Track) ExtractDataAAC() ([]byte, error) { type Tracks []*Track // ReadTracks decodes tracks from SDP. -func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) { +func ReadTracks(byts []byte) (Tracks, error) { desc := sdp.SessionDescription{} err := desc.Unmarshal(byts) if err != nil { @@ -348,9 +342,7 @@ func ReadTracks(byts []byte, baseURL *base.URL) (Tracks, error) { for i, media := range desc.MediaDescriptions { tracks[i] = &Track{ - BaseURL: baseURL, - ID: i, - Media: media, + Media: media, } } diff --git a/track_test.go b/track_test.go index 3354129d..223b9a4c 100644 --- a/track_test.go +++ b/track_test.go @@ -107,10 +107,9 @@ func TestTrackURL(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - tracks, err := ReadTracks(ca.sdp, nil) + tracks, err := ReadTracks(ca.sdp) require.NoError(t, err) - tracks[0].BaseURL = ca.baseURL - ur, err := tracks[0].URL() + ur, err := tracks[0].URL(ca.baseURL) require.NoError(t, err) require.Equal(t, ca.ur, ur) }) @@ -183,7 +182,7 @@ func TestTrackClockRate(t *testing.T) { }, } { t.Run(ca.name, func(t *testing.T) { - tracks, err := ReadTracks(ca.sdp, nil) + tracks, err := ReadTracks(ca.sdp) require.NoError(t, err) clockRate, err := tracks[0].ClockRate()