diff --git a/connclient.go b/connclient.go index d0f8e9c7..a7ef467a 100644 --- a/connclient.go +++ b/connclient.go @@ -292,6 +292,37 @@ func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { } if res.StatusCode != base.StatusOK { + // redirect + if !c.d.RedirectDisable && + res.StatusCode >= base.StatusMovedPermanently && + res.StatusCode <= base.StatusUseProxy && + len(res.Header["Location"]) == 1 { + + c.Close() + + u, err := base.ParseURL(res.Header["Location"][0]) + if err != nil { + return nil, nil, err + } + + nc, err := c.d.Dial(u.Host) + if err != nil { + return nil, nil, err + } + *c = *nc + + res, err := c.Options(u) + if err != nil { + // since this method is not implemented by every RTSP server, + // return only if status code is not 404 + if res == nil || res.StatusCode != base.StatusNotFound { + return nil, nil, err + } + } + + return c.Describe(u) + } + return nil, res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } @@ -309,55 +340,18 @@ func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { return nil, nil, err } + for _, t := range tracks { + t.BaseUrl = u + } + return tracks, res, nil } -// build an URL by merging baseUrl with the control attribute from track.Media. -func (c *ConnClient) urlForTrack(baseUrl *base.URL, mode headers.TransportMode, track *Track) *base.URL { - control := func() string { - // if we're publishing, get control from track ID - if mode == headers.TransportModeRecord { - return "trackID=" + strconv.FormatInt(int64(track.Id), 10) - } - - // otherwise, get from media attributes - for _, attr := range track.Media.Attributes { - if attr.Key == "control" { - return attr.Value - } - } - return "" - }() - - // no control attribute, use base URL - if control == "" { - return baseUrl - } - - // control attribute contains an absolute path - if strings.HasPrefix(control, "rtsp://") { - newUrl, err := base.ParseURL(control) - if err != nil { - return baseUrl - } - - // copy host and credentials - newUrl.Host = baseUrl.Host - newUrl.User = baseUrl.User - return newUrl - } - - // control attribute contains a relative control attribute - newUrl := baseUrl.Clone() - newUrl.AddControlAttribute(control) - return newUrl -} - // Setup writes a SETUP request and reads a Response. // rtpPort and rtcpPort are used only if protocol is UDP. // if rtpPort and rtcpPort are zero, they are chosen automatically. -func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, - track *Track, rtpPort int, rtcpPort int) (*base.Response, error) { +func (c *ConnClient) Setup(mode headers.TransportMode, track *Track, + rtpPort int, rtcpPort int) (*base.Response, error) { err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, connClientStatePrePlay: {}, @@ -376,8 +370,8 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, return nil, fmt.Errorf("cannot read and publish at the same time") } - if c.streamUrl != nil && *u != *c.streamUrl { - return nil, fmt.Errorf("setup has already begun with another url") + if c.streamUrl != nil && *track.BaseUrl != *c.streamUrl { + return nil, fmt.Errorf("cannot setup tracks with different base urls") } var rtpListener *connClientUDPListener @@ -465,9 +459,18 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, transport.InterleavedIds = &[2]int{(track.Id * 2), (track.Id * 2) + 1} } + trackUrl, err := track.Url(mode) + if err != nil { + if proto == StreamProtocolUDP { + rtpListener.close() + rtcpListener.close() + } + return nil, err + } + res, err := c.Do(&base.Request{ Method: base.SETUP, - URL: c.urlForTrack(u, mode, track), + URL: trackUrl, Header: base.Header{ "Transport": transport.Write(), }, @@ -494,7 +497,7 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, v := StreamProtocolTCP c.streamProtocol = &v - return c.Setup(u, headers.TransportModePlay, track, 0, 0) + return c.Setup(headers.TransportModePlay, track, 0, 0) } return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) @@ -545,7 +548,7 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, c.rtcpSenders[track.Id] = rtcpsender.New(clockRate) } - c.streamUrl = u + c.streamUrl = track.BaseUrl c.streamProtocol = &proto c.tracks = append(c.tracks, track) diff --git a/connclientpublish.go b/connclientpublish.go index 4eae86f3..b08ded56 100644 --- a/connclientpublish.go +++ b/connclientpublish.go @@ -16,6 +16,12 @@ func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error return nil, err } + // fill id and base url + for i, t := range tracks { + t.Id = i + t.BaseUrl = u + } + res, err := c.Do(&base.Request{ Method: base.ANNOUNCE, URL: u, diff --git a/dialer.go b/dialer.go index 8991c3a6..c015ffba 100644 --- a/dialer.go +++ b/dialer.go @@ -33,31 +33,37 @@ func DialPublish(address string, tracks Tracks) (*ConnClient, error) { } // Dialer allows to initialize a ConnClient. +// All fields are optional. type Dialer struct { - // (optional) the stream protocol (UDP or TCP). + // the stream protocol (UDP or TCP). // If nil, it is chosen automatically (first UDP, then, if it fails, TCP). + // It defaults to nil. StreamProtocol *StreamProtocol - // (optional) timeout of read operations. - // It defaults to 10 seconds + // timeout of read operations. + // It defaults to 10 seconds. ReadTimeout time.Duration - // (optional) timeout of write operations. - // It defaults to 10 seconds + // timeout of write operations. + // It defaults to 10 seconds. WriteTimeout time.Duration - // (optional) read buffer count. + // disable being redirected to other servers, that can happen during Describe(). + // It defaults to false. + RedirectDisable bool + + // read buffer count. // If greater than 1, allows to pass buffers to routines different than the one // that is reading frames. - // It defaults to 1 + // It defaults to 1. ReadBufferCount int - // (optional) function used to initialize the TCP client. - // It defaults to net.DialTimeout + // function used to initialize the TCP client. + // It defaults to net.DialTimeout. DialTimeout func(network, address string, timeout time.Duration) (net.Conn, error) - // (optional) function used to initialize UDP listeners. - // It defaults to net.ListenPacket + // function used to initialize UDP listeners. + // It defaults to net.ListenPacket. ListenPacket func(network, address string) (net.PacketConn, error) } @@ -127,20 +133,12 @@ func (d Dialer) DialRead(address string) (*ConnClient, error) { tracks, res, err := conn.Describe(u) if err != nil { - // redirect - if res != nil && res.StatusCode >= base.StatusMovedPermanently && - res.StatusCode <= base.StatusUseProxy && - len(res.Header["Location"]) == 1 { - conn.Close() - return d.DialRead(res.Header["Location"][0]) - } - conn.Close() return nil, err } for _, track := range tracks { - _, err := conn.Setup(u, headers.TransportModePlay, track, 0, 0) + _, err := conn.Setup(headers.TransportModePlay, track, 0, 0) if err != nil { conn.Close() return nil, err @@ -185,7 +183,7 @@ func (d Dialer) DialPublish(address string, tracks Tracks) (*ConnClient, error) } for _, track := range tracks { - _, err := conn.Setup(u, headers.TransportModeRecord, track, 0, 0) + _, err := conn.Setup(headers.TransportModeRecord, track, 0, 0) if err != nil { conn.Close() return nil, err diff --git a/dialer_test.go b/dialer_test.go index 51df989f..e201a4c3 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -301,7 +301,7 @@ func TestDialPublishSerial(t *testing.T) { sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) - track, err := NewTrackH264(0, sps, pps) + track, err := NewTrackH264(96, sps, pps) require.NoError(t, err) dialer := Dialer{ @@ -381,7 +381,7 @@ func TestDialPublishParallel(t *testing.T) { sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) - track, err := NewTrackH264(0, sps, pps) + track, err := NewTrackH264(96, sps, pps) require.NoError(t, err) writerDone := make(chan struct{}) @@ -475,7 +475,7 @@ func TestDialPublishPauseSerial(t *testing.T) { sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) - track, err := NewTrackH264(0, sps, pps) + track, err := NewTrackH264(96, sps, pps) require.NoError(t, err) dialer := Dialer{ @@ -547,7 +547,7 @@ func TestDialPublishPauseParallel(t *testing.T) { sps, pps, err := decoder.ReadSPSPPS() require.NoError(t, err) - track, err := NewTrackH264(0, sps, pps) + track, err := NewTrackH264(96, sps, pps) require.NoError(t, err) dialer := Dialer{ diff --git a/examples/client-publish-options.go b/examples/client-publish-options.go index 36ad772a..2448ba9b 100644 --- a/examples/client-publish-options.go +++ b/examples/client-publish-options.go @@ -38,7 +38,7 @@ func main() { fmt.Println("stream connected") // create a H264 track - track, err := gortsplib.NewTrackH264(0, sps, pps) + track, err := gortsplib.NewTrackH264(96, sps, pps) if err != nil { panic(err) } diff --git a/examples/client-publish-pause.go b/examples/client-publish-pause.go index 44127bbd..8356ee69 100644 --- a/examples/client-publish-pause.go +++ b/examples/client-publish-pause.go @@ -39,7 +39,7 @@ func main() { fmt.Println("stream connected") // create a H264 track - track, err := gortsplib.NewTrackH264(0, sps, pps) + track, err := gortsplib.NewTrackH264(96, sps, pps) if err != nil { panic(err) } diff --git a/examples/client-publish.go b/examples/client-publish.go index 3e6913ab..b60e8ba1 100644 --- a/examples/client-publish.go +++ b/examples/client-publish.go @@ -36,7 +36,7 @@ func main() { fmt.Println("stream connected") // create a H264 track - track, err := gortsplib.NewTrackH264(0, sps, pps) + track, err := gortsplib.NewTrackH264(96, sps, pps) if err != nil { panic(err) } diff --git a/track.go b/track.go index 6d8f2dd2..263da335 100644 --- a/track.go +++ b/track.go @@ -10,28 +10,32 @@ import ( "github.com/notedit/rtmp/codec/aac" psdp "github.com/pion/sdp/v3" + "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/sdp" ) // Track is a track available in a certain URL. type Track struct { - // track id + // base url + BaseUrl *base.URL + + // id Id int - // track codec and info in SDP format + // codec and info in SDP format Media *psdp.MediaDescription } // NewTrackH264 initializes an H264 track. -func NewTrackH264(id int, sps []byte, pps []byte) (*Track, error) { +func NewTrackH264(payloadType uint8, sps []byte, pps []byte) (*Track, error) { spropParameterSets := base64.StdEncoding.EncodeToString(sps) + "," + base64.StdEncoding.EncodeToString(pps) profileLevelId := strings.ToUpper(hex.EncodeToString(sps[1:4])) - typ := strconv.FormatInt(int64(96+id), 10) + typ := strconv.FormatInt(int64(payloadType), 10) return &Track{ - Id: id, Media: &psdp.MediaDescription{ MediaName: psdp.MediaName{ Media: "video", @@ -55,7 +59,7 @@ func NewTrackH264(id int, sps []byte, pps []byte) (*Track, error) { } // NewTrackAAC initializes an AAC track. -func NewTrackAAC(id int, config []byte) (*Track, error) { +func NewTrackAAC(payloadType uint8, config []byte) (*Track, error) { codec, err := aac.FromMPEG4AudioConfigBytes(config) if err != nil { return nil, err @@ -77,10 +81,9 @@ func NewTrackAAC(id int, config []byte) (*Track, error) { return nil, err } - typ := strconv.FormatInt(int64(96+id), 10) + typ := strconv.FormatInt(int64(payloadType), 10) return &Track{ - Id: id, Media: &psdp.MediaDescription{ MediaName: psdp.MediaName{ Media: "audio", @@ -159,6 +162,51 @@ func (t *Track) ClockRate() (int, error) { return 0, fmt.Errorf("attribute 'rtpmap' not found") } +// Url returns the track url. +func (t *Track) Url(mode headers.TransportMode) (*base.URL, error) { + if t.BaseUrl == nil { + return nil, fmt.Errorf("empty base url") + } + + control := func() string { + // if we're publishing, get control from track ID + if mode == headers.TransportModeRecord { + return "trackID=" + strconv.FormatInt(int64(t.Id), 10) + } + + // otherwise, get from media attributes + for _, attr := range t.Media.Attributes { + if attr.Key == "control" { + return attr.Value + } + } + return "" + }() + + // no control attribute, use base URL + if control == "" { + return t.BaseUrl, nil + } + + // control attribute contains an absolute path + if strings.HasPrefix(control, "rtsp://") { + ur, err := base.ParseURL(control) + if err != nil { + return nil, err + } + + // copy host and credentials + ur.Host = t.BaseUrl.Host + ur.User = t.BaseUrl.User + return ur, nil + } + + // control attribute contains a relative control attribute + ur := t.BaseUrl.Clone() + ur.AddControlAttribute(control) + return ur, nil +} + // Tracks is a list of tracks. type Tracks []*Track