diff --git a/client.go b/client.go index 3af98b2c..f3ab785e 100644 --- a/client.go +++ b/client.go @@ -586,14 +586,12 @@ func (c *Client) trySwitchingProtocol() error { prevScheme := c.scheme prevHost := c.host prevBaseURL := c.baseURL - prevUseGetParameter := c.useGetParameter prevMedias := c.medias c.reset() v := TransportTCP c.effectiveTransport = &v - c.useGetParameter = prevUseGetParameter c.scheme = prevScheme c.host = prevHost @@ -623,6 +621,26 @@ func (c *Client) trySwitchingProtocol() error { return nil } +func (c *Client) trySwitchingProtocol2(medi *media.Media, baseURL *url.URL) (*base.Response, error) { + prevScheme := c.scheme + prevHost := c.host + + c.reset() + + v := TransportTCP + c.effectiveTransport = &v + c.scheme = prevScheme + c.host = prevHost + + // some Hikvision cameras require a describe before a setup + _, _, _, err := c.doDescribe(c.lastDescribeURL) + if err != nil { + return nil, err + } + + return c.doSetup(medi, baseURL, 0, 0) +} + func (c *Client) playRecordStart() { // stop connCloser c.connCloserStop() @@ -1128,7 +1146,7 @@ func (c *Client) doSetup( c.effectiveTransport = &v } - transport := func() Transport { + requestedTransport := func() Transport { // transport set by previous Setup() or trySwitchingProtocol() if c.effectiveTransport != nil { return *c.effectiveTransport @@ -1154,7 +1172,7 @@ func (c *Client) doSetup( cm := newClientMedia(c) - switch transport { + switch requestedTransport { case TransportUDP: if (rtpPort == 0 && rtcpPort != 0) || (rtpPort != 0 && rtcpPort == 0) { @@ -1233,7 +1251,22 @@ func (c *Client) doSetup( return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err} } - switch transport { + switch requestedTransport { + case TransportUDP, TransportUDPMulticast: + if thRes.Protocol == headers.TransportProtocolTCP { + cm.close() + + // switch transport automatically + if c.effectiveTransport == nil && + c.Transport == nil { + return c.trySwitchingProtocol2(medi, baseURL) + } + + return nil, liberrors.ErrClientServerRequestedTCP{} + } + } + + switch requestedTransport { case TransportUDP: if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { cm.close() @@ -1318,6 +1351,10 @@ func (c *Client) doSetup( } case TransportTCP: + if thRes.Protocol != headers.TransportProtocolTCP { + return nil, liberrors.ErrClientServerRequestedUDP{} + } + if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } @@ -1353,7 +1390,7 @@ func (c *Client) doSetup( cm.setMedia(medi) c.baseURL = baseURL - c.effectiveTransport = &transport + c.effectiveTransport = &requestedTransport if mode == headers.TransportModePlay { c.state = clientStatePrePlay diff --git a/client_play_test.go b/client_play_test.go index 0640357f..3800e9cb 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -1059,6 +1059,169 @@ func TestClientPlayAutomaticProtocol(t *testing.T) { <-packetRecv }) + t.Run("switch after tcp response", func(t *testing.T) { + l, err := net.Listen("tcp", "localhost:8554") + require.NoError(t, err) + defer l.Close() + + serverDone := make(chan struct{}) + defer func() { <-serverDone }() + go func() { + defer close(serverDone) + + medias := media.Medias{testH264Media} + medias.SetControls() + + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + co := conn.NewConn(nconn) + + req, err := co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, + }, + Body: mustMarshalSDP(medias.Marshal(false)), + }) + require.NoError(t, err) + + req, err = co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + ServerPorts: &[2]int{12312, 12313}, + }.Marshal(), + }, + }) + require.NoError(t, err) + + _, err = co.ReadRequest() + require.Error(t, err) + }() + + func() { + nconn, err := l.Accept() + require.NoError(t, err) + defer nconn.Close() + co := conn.NewConn(nconn) + + req, err := co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Options, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Public": base.HeaderValue{strings.Join([]string{ + string(base.Describe), + string(base.Setup), + string(base.Play), + }, ", ")}, + }, + }) + require.NoError(t, err) + + req, err = co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Describe, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Content-Type": base.HeaderValue{"application/sdp"}, + "Content-Base": base.HeaderValue{"rtsp://localhost:8554/teststream/"}, + }, + Body: mustMarshalSDP(medias.Marshal(false)), + }) + require.NoError(t, err) + + req, err = co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Setup, req.Method) + + var inTH headers.Transport + err = inTH.Unmarshal(req.Header["Transport"]) + require.NoError(t, err) + require.Equal(t, headers.TransportProtocolTCP, inTH.Protocol) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + Header: base.Header{ + "Transport": headers.Transport{ + Protocol: headers.TransportProtocolTCP, + Delivery: func() *headers.TransportDelivery { + v := headers.TransportDeliveryUnicast + return &v + }(), + InterleavedIDs: &[2]int{0, 1}, + }.Marshal(), + }, + }) + require.NoError(t, err) + + req, err = co.ReadRequest() + require.NoError(t, err) + require.Equal(t, base.Play, req.Method) + + err = co.WriteResponse(&base.Response{ + StatusCode: base.StatusOK, + }) + require.NoError(t, err) + + err = co.WriteInterleavedFrame(&base.InterleavedFrame{ + Channel: 0, + Payload: testRTPPacketMarshaled, + }, make([]byte, 1024)) + require.NoError(t, err) + }() + }() + + packetRecv := make(chan struct{}) + + c := Client{} + err = readAll(&c, "rtsp://localhost:8554/teststream", + func(medi *media.Media, forma format.Format, pkt *rtp.Packet) { + close(packetRecv) + }) + require.NoError(t, err) + defer c.Close() + + <-packetRecv + }) + t.Run("switch after timeout", func(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) diff --git a/pkg/liberrors/client.go b/pkg/liberrors/client.go index edc2072f..5e839ac0 100644 --- a/pkg/liberrors/client.go +++ b/pkg/liberrors/client.go @@ -107,6 +107,22 @@ func (e ErrClientTransportHeaderInvalid) Error() string { return fmt.Sprintf("invalid transport header: %v", e.Err) } +// ErrClientServerRequestedTCP is an error that can be returned by a client. +type ErrClientServerRequestedTCP struct{} + +// Error implements the error interface. +func (e ErrClientServerRequestedTCP) Error() string { + return "server wants to use the TCP transport protocol" +} + +// ErrClientServerRequestedUDP is an error that can be returned by a client. +type ErrClientServerRequestedUDP struct{} + +// Error implements the error interface. +func (e ErrClientServerRequestedUDP) Error() string { + return "server wants to use the UDP transport protocol" +} + // ErrClientTransportHeaderInvalidDelivery is an error that can be returned by a client. type ErrClientTransportHeaderInvalidDelivery struct{}