diff --git a/connclient.go b/connclient.go index e4d928d3..e9109ee8 100644 --- a/connclient.go +++ b/connclient.go @@ -228,8 +228,6 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) { } // Options writes an OPTIONS request and reads a response. -// Since this method is not implemented by every RTSP server, the function -// does not fail if the returned code is StatusNotFound. func (c *ConnClient) Options(u *base.URL) (*base.Response, error) { err := c.checkState(map[connClientState]struct{}{ connClientStateInitial: {}, @@ -248,8 +246,8 @@ func (c *ConnClient) Options(u *base.URL) (*base.Response, error) { return nil, err } - if res.StatusCode != base.StatusOK && res.StatusCode != base.StatusNotFound { - return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) + if res.StatusCode != base.StatusOK { + return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } c.getParameterSupported = func() bool { @@ -291,36 +289,25 @@ func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) { return nil, nil, err } - switch res.StatusCode { - case base.StatusOK: - contentType, ok := res.Header["Content-Type"] - if !ok || len(contentType) != 1 { - return nil, nil, fmt.Errorf("Content-Type not provided") - } - - if contentType[0] != "application/sdp" { - return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp") - } - - tracks, err := ReadTracks(res.Content) - if err != nil { - return nil, nil, err - } - - return tracks, res, nil - - case base.StatusMovedPermanently, base.StatusFound, - base.StatusSeeOther, base.StatusNotModified, base.StatusUseProxy: - location, ok := res.Header["Location"] - if !ok || len(location) != 1 { - return nil, nil, fmt.Errorf("Location not provided") - } - - return nil, res, nil - - default: - return nil, nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) + if res.StatusCode != base.StatusOK { + return nil, res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } + + contentType, ok := res.Header["Content-Type"] + if !ok || len(contentType) != 1 { + return nil, nil, fmt.Errorf("Content-Type not provided") + } + + if contentType[0] != "application/sdp" { + return nil, nil, fmt.Errorf("wrong Content-Type, expected application/sdp") + } + + tracks, err := ReadTracks(res.Content) + if err != nil { + return nil, nil, err + } + + return tracks, res, nil } // build an URL by merging baseUrl with the control attribute from track.Media. @@ -485,7 +472,7 @@ func (c *ConnClient) Setup(u *base.URL, mode headers.TransportMode, proto base.S rtpListener.close() rtcpListener.close() } - return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) + return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } th, err := headers.ReadTransport(res.Header["Transport"]) @@ -575,7 +562,7 @@ func (c *ConnClient) Pause() (*base.Response, error) { } if res.StatusCode != base.StatusOK { - return nil, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) + return res, fmt.Errorf("bad status code: %d (%s)", res.StatusCode, res.StatusMessage) } switch c.state { diff --git a/dialer.go b/dialer.go index 766f99c5..2f42b68e 100644 --- a/dialer.go +++ b/dialer.go @@ -33,9 +33,9 @@ func DialPublish(address string, tracks Tracks) (*ConnClient, error) { // Dialer allows to initialize a ConnClient. type Dialer struct { - // (optional) the stream protocol. - // It defaults to StreamProtocolUDP. - StreamProtocol StreamProtocol + // (optional) the stream protocol (UDP or TCP). + // If nil, it is chosen automatically (first UDP, then, if it fails, TCP). + StreamProtocol *StreamProtocol // (optional) timeout of read operations. // It defaults to 10 seconds @@ -113,29 +113,53 @@ func (d Dialer) DialRead(address string) (*ConnClient, error) { return nil, err } - _, err = conn.Options(u) + res, err := conn.Options(u) if err != nil { - conn.Close() - return nil, err + // since this method is not implemented by every RTSP server, + // return only if status code is not 404 + if res == nil || res.StatusCode != base.StatusNotFound { + conn.Close() + return nil, err + } } tracks, res, err := conn.Describe(u) if err != nil { + // redirect + if res != nil && res.StatusCode >= base.StatusMovedPermanently && + res.StatusCode <= base.StatusUseProxy && + len(res.Header["Location"]) == 1 { + conn.Close() + return d.DialRead(res.Header["Location"][0]) + } + conn.Close() return nil, err } - if res.StatusCode >= base.StatusMovedPermanently && - res.StatusCode <= base.StatusUseProxy { - conn.Close() - return d.DialRead(res.Header["Location"][0]) - } + proto := func() StreamProtocol { + if d.StreamProtocol != nil { + return *d.StreamProtocol + } + return StreamProtocolUDP + }() - for _, track := range tracks { - _, err := conn.Setup(u, headers.TransportModePlay, d.StreamProtocol, track, 0, 0) + for i, track := range tracks { + res, err := conn.Setup(u, headers.TransportModePlay, proto, track, 0, 0) if err != nil { - conn.Close() - return nil, err + // switch protocol automatically + if i == 0 && d.StreamProtocol == nil && res != nil && + res.StatusCode == base.StatusUnsupportedTransport { + proto = StreamProtocolTCP + _, err := conn.Setup(u, headers.TransportModePlay, proto, track, 0, 0) + if err != nil { + conn.Close() + return nil, err + } + } else { + conn.Close() + return nil, err + } } } @@ -160,10 +184,14 @@ func (d Dialer) DialPublish(address string, tracks Tracks) (*ConnClient, error) return nil, err } - _, err = conn.Options(u) + res, err := conn.Options(u) if err != nil { - conn.Close() - return nil, err + // since this method is not implemented by every RTSP server, + // return only if status code is not 404 + if res == nil || res.StatusCode != base.StatusNotFound { + conn.Close() + return nil, err + } } _, err = conn.Announce(u, tracks) @@ -172,11 +200,29 @@ func (d Dialer) DialPublish(address string, tracks Tracks) (*ConnClient, error) return nil, err } - for _, track := range tracks { - _, err = conn.Setup(u, headers.TransportModeRecord, d.StreamProtocol, track, 0, 0) + proto := func() StreamProtocol { + if d.StreamProtocol != nil { + return *d.StreamProtocol + } + return StreamProtocolUDP + }() + + for i, track := range tracks { + res, err := conn.Setup(u, headers.TransportModeRecord, proto, track, 0, 0) if err != nil { - conn.Close() - return nil, err + // switch protocol automatically + if i == 0 && d.StreamProtocol == nil && res != nil && + res.StatusCode == base.StatusUnsupportedTransport { + proto = StreamProtocolTCP + _, err := conn.Setup(u, headers.TransportModePlay, proto, track, 0, 0) + if err != nil { + conn.Close() + return nil, err + } + } else { + conn.Close() + return nil, err + } } } diff --git a/dialer_test.go b/dialer_test.go index de7c89f4..51df989f 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -59,7 +59,7 @@ func (c *container) wait() int { return int(code) } -func TestDialReadParallel(t *testing.T) { +func TestDialRead(t *testing.T) { for _, proto := range []string{ "udp", "tcp", @@ -85,12 +85,16 @@ func TestDialReadParallel(t *testing.T) { time.Sleep(1 * time.Second) - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) @@ -115,7 +119,48 @@ func TestDialReadParallel(t *testing.T) { } } -func TestDialReadRedirectParallel(t *testing.T) { +func TestDialReadAutomaticProtocol(t *testing.T) { + cnt1, err := newContainer("rtsp-simple-server", "server", []string{ + "protocols: [tcp]\n", + }) + require.NoError(t, err) + defer cnt1.close() + + time.Sleep(1 * time.Second) + + cnt2, err := newContainer("ffmpeg", "publish", []string{ + "-re", + "-stream_loop", "-1", + "-i", "/emptyvideo.ts", + "-c", "copy", + "-f", "rtsp", + "-rtsp_transport", "tcp", + "rtsp://localhost:8554/teststream", + }) + require.NoError(t, err) + defer cnt2.close() + + time.Sleep(1 * time.Second) + + dialer := Dialer{StreamProtocol: nil} + + conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") + require.NoError(t, err) + + var firstFrame int32 + frameRecv := make(chan struct{}) + done := conn.OnFrame(func(id int, typ StreamType, content []byte) { + if atomic.SwapInt32(&firstFrame, 1) == 0 { + close(frameRecv) + } + }) + + <-frameRecv + conn.Close() + <-done +} + +func TestDialReadRedirect(t *testing.T) { cnt1, err := newContainer("rtsp-simple-server", "server", []string{ "paths:\n" + " path1:\n" + @@ -158,7 +203,7 @@ func TestDialReadRedirectParallel(t *testing.T) { <-done } -func TestDialReadPauseParallel(t *testing.T) { +func TestDialReadPause(t *testing.T) { for _, proto := range []string{ "udp", "tcp", @@ -184,12 +229,16 @@ func TestDialReadPauseParallel(t *testing.T) { time.Sleep(1 * time.Second) - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } conn, err := dialer.DialRead("rtsp://localhost:8554/teststream") require.NoError(t, err) @@ -255,12 +304,16 @@ func TestDialPublishSerial(t *testing.T) { track, err := NewTrackH264(0, sps, pps) require.NoError(t, err) - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } conn, err := dialer.DialPublish("rtsp://localhost:8554/teststream", Tracks{track}) @@ -337,12 +390,16 @@ func TestDialPublishParallel(t *testing.T) { var conn *ConnClient defer func() { conn.Close() }() - dialer := func() Dialer { - if ca.proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if ca.proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } go func() { defer close(writerDone) @@ -421,12 +478,16 @@ func TestDialPublishPauseSerial(t *testing.T) { track, err := NewTrackH264(0, sps, pps) require.NoError(t, err) - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } conn, err := dialer.DialPublish("rtsp://localhost:8554/teststream", Tracks{track}) @@ -489,12 +550,16 @@ func TestDialPublishPauseParallel(t *testing.T) { track, err := NewTrackH264(0, sps, pps) require.NoError(t, err) - dialer := func() Dialer { - if proto == "udp" { - return Dialer{} - } - return Dialer{StreamProtocol: StreamProtocolTCP} - }() + dialer := Dialer{ + StreamProtocol: func() *StreamProtocol { + if proto == "udp" { + v := StreamProtocolUDP + return &v + } + v := StreamProtocolTCP + return &v + }(), + } conn, err := dialer.DialPublish("rtsp://localhost:8554/teststream", Tracks{track}) diff --git a/examples/client-publish-options.go b/examples/client-publish-options.go index 04a9b1c0..6a975337 100644 --- a/examples/client-publish-options.go +++ b/examples/client-publish-options.go @@ -12,6 +12,7 @@ import ( ) // This example shows how to +// * set additional client options // * generate RTP/H264 frames from a file with Gstreamer // * connect to a RTSP server, announce a H264 track // * write the frames of the track @@ -42,10 +43,10 @@ func main() { panic(err) } - // Dialer allows to set additional options + // Dialer allows to set additional client options dialer := gortsplib.Dialer{ - // the stream protocol - StreamProtocol: gortsplib.StreamProtocolUDP, + // the stream protocol (UDP or TCP). If nil, it is chosen automatically + StreamProtocol: nil, // timeout of read operations ReadTimeout: 10 * time.Second, // timeout of write operations @@ -72,7 +73,7 @@ func main() { break } - // write frames to the server + // write track frames err = conn.WriteFrame(track.Id, gortsplib.StreamTypeRtp, buf[:n]) if err != nil { fmt.Printf("connection is closed (%s)\n", err) diff --git a/examples/client-publish-pause.go b/examples/client-publish-pause.go index d206d224..44127bbd 100644 --- a/examples/client-publish-pause.go +++ b/examples/client-publish-pause.go @@ -65,7 +65,7 @@ func main() { break } - // write frames to the server + // write track frames err = conn.WriteFrame(track.Id, gortsplib.StreamTypeRtp, buf[:n]) if err != nil { break diff --git a/examples/client-publish.go b/examples/client-publish.go index 2bddb389..3e6913ab 100644 --- a/examples/client-publish.go +++ b/examples/client-publish.go @@ -57,7 +57,7 @@ func main() { break } - // write frames to the server + // write track frames err = conn.WriteFrame(track.Id, gortsplib.StreamTypeRtp, buf[:n]) if err != nil { fmt.Printf("connection is closed (%s)\n", err) diff --git a/examples/client-read-options.go b/examples/client-read-options.go index 3e17f189..d72c4e7e 100644 --- a/examples/client-read-options.go +++ b/examples/client-read-options.go @@ -10,14 +10,15 @@ import ( ) // This example shows how to +// * set additional client options // * connect to a RTSP server // * read all tracks on a path func main() { - // Dialer allows to set additional options + // Dialer allows to set additional client options dialer := gortsplib.Dialer{ - // the stream protocol - StreamProtocol: gortsplib.StreamProtocolUDP, + // the stream protocol (UDP or TCP). If nil, it is chosen automatically + StreamProtocol: nil, // timeout of read operations ReadTimeout: 10 * time.Second, // timeout of write operations @@ -35,7 +36,7 @@ func main() { } defer conn.Close() - // read frames from the server + // read track frames readerDone := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) }) diff --git a/examples/client-read-pause.go b/examples/client-read-pause.go index 32d3e2ad..2293b6f1 100644 --- a/examples/client-read-pause.go +++ b/examples/client-read-pause.go @@ -24,7 +24,7 @@ func main() { defer conn.Close() for { - // read frames from the server + // read track frames readerDone := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) }) diff --git a/examples/client-read.go b/examples/client-read.go index b2700093..c297f280 100644 --- a/examples/client-read.go +++ b/examples/client-read.go @@ -20,7 +20,7 @@ func main() { } defer conn.Close() - // read frames from the server + // read track frames readerDone := conn.OnFrame(func(id int, typ gortsplib.StreamType, buf []byte) { fmt.Printf("frame from track %d, type %v: %v\n", id, typ, buf) })