diff --git a/client.go b/client.go index 84ca17e5..ea66314f 100644 --- a/client.go +++ b/client.go @@ -31,6 +31,26 @@ import ( "github.com/bluenviron/gortsplib/v3/pkg/url" ) +// convert an URL into an address, in particular: +// * add default port +// * handle IPv6 with or without square brackets. +// Adapted from net/http: +// https://cs.opensource.google/go/go/+/refs/tags/go1.20.5:src/net/http/transport.go;l=2747 +func canonicalAddr(u *url.URL) string { + addr := u.Hostname() + + port := u.Port() + if port == "" { + if u.Scheme == "rtsp" { + port = "554" + } else { // rtsps + port = "322" + } + } + + return net.JoinHostPort(addr, port) +} + func isAnyPort(p int) bool { return p == 0 || p == 1 } @@ -249,8 +269,7 @@ type Client struct { checkTimeoutPeriod time.Duration keepalivePeriod time.Duration - scheme string - host string + connURL *url.URL ctx context.Context ctxCancel func() state clientState @@ -383,8 +402,10 @@ func (c *Client) Start(scheme string, host string) error { ctx, ctxCancel := context.WithCancel(context.Background()) - c.scheme = scheme - c.host = host + c.connURL = &url.URL{ + Scheme: scheme, + Host: host, + } c.ctx = ctx c.ctxCancel = ctxCancel c.checkTimeoutTimer = emptyTimer() @@ -577,8 +598,7 @@ func (c *Client) checkState(allowed map[clientState]struct{}) error { func (c *Client) trySwitchingProtocol() error { c.OnTransportSwitch(fmt.Errorf("no UDP packets received, switching to TCP")) - prevScheme := c.scheme - prevHost := c.host + prevConnURL := c.connURL prevBaseURL := c.baseURL prevMedias := c.medias @@ -586,8 +606,7 @@ func (c *Client) trySwitchingProtocol() error { v := TransportTCP c.effectiveTransport = &v - c.scheme = prevScheme - c.host = prevHost + c.connURL = prevConnURL // some Hikvision cameras require a describe before a setup _, _, _, err := c.doDescribe(c.lastDescribeURL) @@ -618,15 +637,13 @@ func (c *Client) trySwitchingProtocol() error { func (c *Client) trySwitchingProtocol2(medi *media.Media, baseURL *url.URL) (*base.Response, error) { c.OnTransportSwitch(fmt.Errorf("switching to TCP because server requested it")) - prevScheme := c.scheme - prevHost := c.host + prevConnURL := c.connURL c.reset() v := TransportTCP c.effectiveTransport = &v - c.scheme = prevScheme - c.host = prevHost + c.connURL = prevConnURL // some Hikvision cameras require a describe before a setup _, _, _, err := c.doDescribe(c.lastDescribeURL) @@ -700,41 +717,30 @@ func (c *Client) playRecordStop(isClosing bool) { } func (c *Client) connOpen() error { - if c.scheme != "rtsp" && c.scheme != "rtsps" { - return fmt.Errorf("unsupported scheme '%s'", c.scheme) + if c.connURL.Scheme != "rtsp" && c.connURL.Scheme != "rtsps" { + return fmt.Errorf("unsupported scheme '%s'", c.connURL.Scheme) } - if c.scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP { + if c.connURL.Scheme == "rtsps" && c.Transport != nil && *c.Transport != TransportTCP { return fmt.Errorf("RTSPS can be used only with TCP") } - // add default port - _, _, err := net.SplitHostPort(c.host) - if err != nil { - if c.scheme == "rtsp" { - c.host = net.JoinHostPort(c.host, "554") - } else { // rtsps - c.host = net.JoinHostPort(c.host, "322") - } - } - dialCtx, dialCtxCancel := context.WithTimeout(c.ctx, c.ReadTimeout) defer dialCtxCancel() - nconn, err := c.DialContext(dialCtx, "tcp", c.host) + nconn, err := c.DialContext(dialCtx, "tcp", canonicalAddr(c.connURL)) if err != nil { return err } - if c.scheme == "rtsps" { + if c.connURL.Scheme == "rtsps" { tlsConfig := c.TLSConfig if tlsConfig == nil { tlsConfig = &tls.Config{} } - host, _, _ := net.SplitHostPort(c.host) - tlsConfig.ServerName = host + tlsConfig.ServerName = c.connURL.Hostname() nconn = tls.Client(nconn, tlsConfig) } @@ -997,8 +1003,10 @@ func (c *Client) doDescribe(u *url.URL) (media.Medias, *url.URL, *base.Response, ru.User = u.User } - c.scheme = ru.Scheme - c.host = ru.Host + c.connURL = &url.URL{ + Scheme: ru.Scheme, + Host: ru.Host, + } return c.doDescribe(ru) } @@ -1125,7 +1133,7 @@ func (c *Client) doSetup( } // always use TCP if encrypted - if c.scheme == "rtsps" { + if c.connURL.Scheme == "rtsps" { v := TransportTCP c.effectiveTransport = &v } diff --git a/client_test.go b/client_test.go index e34f1163..40810ff0 100644 --- a/client_test.go +++ b/client_test.go @@ -23,6 +23,35 @@ func mustParseURL(s string) *url.URL { return u } +func TestClientURLToAddress(t *testing.T) { + for _, ca := range []struct { + name string + url string + addr string + }{ + { + "rtsp ipv6 with port", + "rtsp://[::1]:8888/path", + "[::1]:8888", + }, + { + "rtsp ipv6 without port", + "rtsp://[::1]/path", + "[::1]:554", + }, + { + "rtsps without port", + "rtsps://2.2.2.2/path", + "2.2.2.2:322", + }, + } { + t.Run(ca.name, func(t *testing.T) { + addr := canonicalAddr(mustParseURL(ca.url)) + require.Equal(t, ca.addr, addr) + }) + } +} + func TestClientTLSSetServerName(t *testing.T) { l, err := net.Listen("tcp", "localhost:8554") require.NoError(t, err) diff --git a/pkg/url/url.go b/pkg/url/url.go index cc9d7a80..ff4601e3 100644 --- a/pkg/url/url.go +++ b/pkg/url/url.go @@ -89,3 +89,18 @@ func (u *URL) RTSPPathAndQuery() (string, bool) { return pathAndQuery, true } + +// Hostname returns u.Host, stripping any valid port number if present. +// +// If the result is enclosed in square brackets, as literal IPv6 addresses are, +// the square brackets are removed from the result. +func (u *URL) Hostname() string { + return (*url.URL)(u).Hostname() +} + +// Port returns the port part of u.Host, without the leading colon. +// +// If u.Host doesn't contain a valid numeric port, Port returns an empty string. +func (u *URL) Port() string { + return (*url.URL)(u).Port() +}