use base.URL instead of url.URL

This commit is contained in:
aler9
2020-11-01 19:00:05 +01:00
parent 5945937a5f
commit b2de7dd899
9 changed files with 201 additions and 162 deletions

View File

@@ -19,9 +19,12 @@ type Client struct {
nonce string nonce string
} }
// NewClient allocates a Client. // NewClient allocates a Client with the WWW-Authenticate header provided by
// header is the WWW-Authenticate header provided by the server. // the server and a set of credentials.
func NewClient(v base.HeaderValue, user string, pass string) (*Client, error) { func NewClient(v base.HeaderValue, userinfo *url.Userinfo) (*Client, error) {
pass, _ := userinfo.Password()
user := userinfo.Username()
// prefer digest // prefer digest
if headerAuthDigest := func() string { if headerAuthDigest := func() string {
for _, vi := range v { for _, vi := range v {
@@ -83,7 +86,9 @@ func NewClient(v base.HeaderValue, user string, pass string) (*Client, error) {
// GenerateHeader generates an Authorization Header that allows to authenticate a request with // GenerateHeader generates an Authorization Header that allows to authenticate a request with
// the given method and url. // the given method and url.
func (ac *Client) GenerateHeader(method base.Method, ur *url.URL) base.HeaderValue { func (ac *Client) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue {
ur = ur.CloneWithoutCredentials()
switch ac.method { switch ac.method {
case headers.AuthBasic: case headers.AuthBasic:
response := base64.StdEncoding.EncodeToString([]byte(ac.user + ":" + ac.pass)) response := base64.StdEncoding.EncodeToString([]byte(ac.user + ":" + ac.pass))

View File

@@ -34,13 +34,13 @@ func TestAuthMethods(t *testing.T) {
authServer := NewServer("testuser", "testpass", c.methods) authServer := NewServer("testuser", "testpass", c.methods)
wwwAuthenticate := authServer.GenerateHeader() wwwAuthenticate := authServer.GenerateHeader()
ac, err := NewClient(wwwAuthenticate, "testuser", "testpass") ac, err := NewClient(wwwAuthenticate, url.UserPassword("testuser", "testpass"))
require.NoError(t, err) require.NoError(t, err)
authorization := ac.GenerateHeader(base.ANNOUNCE, authorization := ac.GenerateHeader(base.ANNOUNCE,
&url.URL{Scheme: "rtsp", Host: "myhost", Path: "mypath"}) base.MustParseURL("rtsp://myhost/mypath"))
err = authServer.ValidateHeader(authorization, base.ANNOUNCE, err = authServer.ValidateHeader(authorization, base.ANNOUNCE,
&url.URL{Scheme: "rtsp", Host: "myhost", Path: "mypath"}) base.MustParseURL("rtsp://myhost/mypath"))
require.NoError(t, err) require.NoError(t, err)
}) })
} }
@@ -51,12 +51,12 @@ func TestAuthVLC(t *testing.T) {
[]headers.AuthMethod{headers.AuthBasic, headers.AuthDigest}) []headers.AuthMethod{headers.AuthBasic, headers.AuthDigest})
wwwAuthenticate := authServer.GenerateHeader() wwwAuthenticate := authServer.GenerateHeader()
ac, err := NewClient(wwwAuthenticate, "testuser", "testpass") ac, err := NewClient(wwwAuthenticate, url.UserPassword("testuser", "testpass"))
require.NoError(t, err) require.NoError(t, err)
authorization := ac.GenerateHeader(base.ANNOUNCE, authorization := ac.GenerateHeader(base.ANNOUNCE,
&url.URL{Scheme: "rtsp", Host: "myhost", Path: "/mypath/"}) base.MustParseURL("rtsp://myhost/mypath/"))
err = authServer.ValidateHeader(authorization, base.ANNOUNCE, err = authServer.ValidateHeader(authorization, base.ANNOUNCE,
&url.URL{Scheme: "rtsp", Host: "myhost", Path: "/mypath/trackId=0"}) base.MustParseURL("rtsp://myhost/mypath/trackId=0"))
require.NoError(t, err) require.NoError(t, err)
} }

View File

@@ -5,7 +5,6 @@ import (
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/url"
"strings" "strings"
"github.com/aler9/gortsplib/base" "github.com/aler9/gortsplib/base"
@@ -66,7 +65,7 @@ func (as *Server) GenerateHeader() base.HeaderValue {
// ValidateHeader validates the Authorization header sent by a client after receiving the // ValidateHeader validates the Authorization header sent by a client after receiving the
// WWW-Authenticate header. // WWW-Authenticate header.
func (as *Server) ValidateHeader(v base.HeaderValue, method base.Method, ur *url.URL) error { func (as *Server) ValidateHeader(v base.HeaderValue, method base.Method, ur *base.URL) error {
if len(v) == 0 { if len(v) == 0 {
return fmt.Errorf("authorization header not provided") return fmt.Errorf("authorization header not provided")
} }
@@ -127,9 +126,10 @@ func (as *Server) ValidateHeader(v base.HeaderValue, method base.Method, ur *url
if *auth.URI != uri { if *auth.URI != uri {
// VLC strips the control path; do another try without the control path // VLC strips the control path; do another try without the control path
base, _, ok := base.URLGetBaseControlPath(ur) base, _, ok := ur.BaseControlPath()
if ok { if ok {
ur.Path = "/" + base + "/" ur = ur.Clone()
ur.SetPath("/" + base + "/")
uri = ur.String() uri = ur.String()
} }

View File

@@ -4,7 +4,6 @@ package base
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"net/url"
"strconv" "strconv"
) )
@@ -40,7 +39,7 @@ type Request struct {
Method Method Method Method
// request url // request url
URL *url.URL URL *URL
// map of header values // map of header values
Header Header Header Header
@@ -75,16 +74,12 @@ func (req *Request) Read(rb *bufio.Reader) error {
return fmt.Errorf("empty url") return fmt.Errorf("empty url")
} }
ur, err := url.Parse(rawUrl) ur, err := ParseURL(rawUrl)
if err != nil { if err != nil {
return fmt.Errorf("unable to parse url (%v)", rawUrl) return fmt.Errorf("unable to parse url (%v)", rawUrl)
} }
req.URL = ur req.URL = ur
if req.URL.Scheme != "rtsp" {
return fmt.Errorf("invalid url scheme (%v)", rawUrl)
}
byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength) byts, err = readBytesLimited(rb, '\r', requestMaxProtocolLength)
if err != nil { if err != nil {
return err return err
@@ -116,15 +111,7 @@ func (req *Request) Read(rb *bufio.Reader) error {
// Write writes a request. // Write writes a request.
func (req Request) Write(bw *bufio.Writer) error { func (req Request) Write(bw *bufio.Writer) error {
// remove credentials u := req.URL.CloneWithoutCredentials()
u := &url.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
}
_, err := bw.Write([]byte(string(req.Method) + " " + u.String() + " " + rtspProtocol10 + "\r\n")) _, err := bw.Write([]byte(string(req.Method) + " " + u.String() + " " + rtspProtocol10 + "\r\n"))
if err != nil { if err != nil {
return err return err

View File

@@ -3,20 +3,11 @@ package base
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"net/url"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func urlMustParse(s string) *url.URL {
u, err := url.Parse(s)
if err != nil {
panic(err)
}
return u
}
var casesRequest = []struct { var casesRequest = []struct {
name string name string
byts []byte byts []byte
@@ -31,7 +22,7 @@ var casesRequest = []struct {
"\r\n"), "\r\n"),
Request{ Request{
Method: "OPTIONS", Method: "OPTIONS",
URL: urlMustParse("rtsp://example.com/media.mp4"), URL: MustParseURL("rtsp://example.com/media.mp4"),
Header: Header{ Header: Header{
"CSeq": HeaderValue{"1"}, "CSeq": HeaderValue{"1"},
"Require": HeaderValue{"implicit-play"}, "Require": HeaderValue{"implicit-play"},
@@ -47,7 +38,7 @@ var casesRequest = []struct {
"\r\n"), "\r\n"),
Request{ Request{
Method: "DESCRIBE", Method: "DESCRIBE",
URL: urlMustParse("rtsp://example.com/media.mp4"), URL: MustParseURL("rtsp://example.com/media.mp4"),
Header: Header{ Header: Header{
"Accept": HeaderValue{"application/sdp"}, "Accept": HeaderValue{"application/sdp"},
"CSeq": HeaderValue{"2"}, "CSeq": HeaderValue{"2"},
@@ -62,7 +53,7 @@ var casesRequest = []struct {
"\r\n"), "\r\n"),
Request{ Request{
Method: "DESCRIBE", Method: "DESCRIBE",
URL: urlMustParse("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"), URL: MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"),
Header: Header{ Header: Header{
"Accept": HeaderValue{"application/sdp"}, "Accept": HeaderValue{"application/sdp"},
"CSeq": HeaderValue{"3"}, "CSeq": HeaderValue{"3"},
@@ -91,7 +82,7 @@ var casesRequest = []struct {
"m=video 2232 RTP/AVP 31\n"), "m=video 2232 RTP/AVP 31\n"),
Request{ Request{
Method: "ANNOUNCE", Method: "ANNOUNCE",
URL: urlMustParse("rtsp://example.com/media.mp4"), URL: MustParseURL("rtsp://example.com/media.mp4"),
Header: Header{ Header: Header{
"CSeq": HeaderValue{"7"}, "CSeq": HeaderValue{"7"},
"Date": HeaderValue{"23 Jan 1997 15:35:06 GMT"}, "Date": HeaderValue{"23 Jan 1997 15:35:06 GMT"},
@@ -125,7 +116,7 @@ var casesRequest = []struct {
"jitter\n"), "jitter\n"),
Request{ Request{
Method: "GET_PARAMETER", Method: "GET_PARAMETER",
URL: urlMustParse("rtsp://example.com/media.mp4"), URL: MustParseURL("rtsp://example.com/media.mp4"),
Header: Header{ Header: Header{
"CSeq": HeaderValue{"9"}, "CSeq": HeaderValue{"9"},
"Content-Type": HeaderValue{"text/parameters"}, "Content-Type": HeaderValue{"text/parameters"},

View File

@@ -1,6 +1,7 @@
package base package base
import ( import (
"fmt"
"net/url" "net/url"
"strings" "strings"
) )
@@ -14,14 +15,84 @@ func stringsReverseIndexByte(s string, c byte) int {
return -1 return -1
} }
// URLGetBasePath returns the base path of a RTSP URL. // URL is a RTSP URL.
type URL struct {
inner *url.URL
}
// ParseURL parses a RTSP URL.
func ParseURL(s string) (*URL, error) {
u, err := url.Parse(s)
if err != nil {
return nil, err
}
if u.Scheme != "rtsp" {
return nil, fmt.Errorf("wrong scheme")
}
return &URL{u}, nil
}
// MustParseURL is like ParseURL but panics in case of errors.
func MustParseURL(s string) *URL {
u, err := ParseURL(s)
if err != nil {
panic(err)
}
return u
}
// String implements fmt.Stringer.
func (u *URL) String() string {
return u.inner.String()
}
// Clone clones a URL.
func (u *URL) Clone() *URL {
return &URL{&url.URL{
Scheme: u.inner.Scheme,
Opaque: u.inner.Opaque,
User: u.inner.User,
Host: u.inner.Host,
Path: u.inner.Path,
RawPath: u.inner.RawPath,
ForceQuery: u.inner.ForceQuery,
RawQuery: u.inner.RawQuery,
}}
}
// CloneWithoutCredentials clones a URL without its credentials.
func (u *URL) CloneWithoutCredentials() *URL {
return &URL{&url.URL{
Scheme: u.inner.Scheme,
Opaque: u.inner.Opaque,
Host: u.inner.Host,
Path: u.inner.Path,
RawPath: u.inner.RawPath,
ForceQuery: u.inner.ForceQuery,
RawQuery: u.inner.RawQuery,
}}
}
// Host returns the host of a RTSP URL.
func (u *URL) Host() string {
return u.inner.Host
}
// User returns the credentials of a RTSP URL.
func (u *URL) User() *url.Userinfo {
return u.inner.User
}
// BasePath returns the base path of a RTSP URL.
// We assume that the URL doesn't contain a control path. // We assume that the URL doesn't contain a control path.
func URLGetBasePath(u *url.URL) (string, bool) { func (u *URL) BasePath() (string, bool) {
var path string var path string
if u.RawPath != "" { if u.inner.RawPath != "" {
path = u.RawPath path = u.inner.RawPath
} else { } else {
path = u.Path path = u.inner.Path
} }
// remove leading slash // remove leading slash
@@ -33,17 +104,17 @@ func URLGetBasePath(u *url.URL) (string, bool) {
return path, true return path, true
} }
// URLGetBaseControlPath returns the base path and the control path of a RTSP URL. // BaseControlPath returns the base path and the control path of a RTSP URL.
// We assume that the URL contains a control path. // We assume that the URL contains a control path.
func URLGetBaseControlPath(u *url.URL) (string, string, bool) { func (u *URL) BaseControlPath() (string, string, bool) {
var pathAndQuery string var pathAndQuery string
if u.RawPath != "" { if u.inner.RawPath != "" {
pathAndQuery = u.RawPath pathAndQuery = u.inner.RawPath
} else { } else {
pathAndQuery = u.Path pathAndQuery = u.inner.Path
} }
if u.RawQuery != "" { if u.inner.RawQuery != "" {
pathAndQuery += "?" + u.RawQuery pathAndQuery += "?" + u.inner.RawQuery
} }
// remove leading slash // remove leading slash
@@ -77,26 +148,45 @@ func URLGetBaseControlPath(u *url.URL) (string, string, bool) {
return basePath, controlPath, true return basePath, controlPath, true
} }
// URLAddControlPath adds a control path to a RTSP url. // AddControlPath adds a control path to a RTSP url.
func URLAddControlPath(u *url.URL, controlPath string) { func (u *URL) AddControlPath(controlPath string) {
// always insert the control path at the end of the url // always insert the control path at the end of the url
if u.RawQuery != "" { if u.inner.RawQuery != "" {
if !strings.HasSuffix(u.RawQuery, "/") { if !strings.HasSuffix(u.inner.RawQuery, "/") {
u.RawQuery += "/" u.inner.RawQuery += "/"
} }
u.RawQuery += controlPath u.inner.RawQuery += controlPath
} else { } else {
if u.RawPath != "" { if u.inner.RawPath != "" {
if !strings.HasSuffix(u.RawPath, "/") { if !strings.HasSuffix(u.inner.RawPath, "/") {
u.RawPath += "/" u.inner.RawPath += "/"
} }
u.RawPath += controlPath u.inner.RawPath += controlPath
} }
if !strings.HasSuffix(u.Path, "/") { if !strings.HasSuffix(u.inner.Path, "/") {
u.Path += "/" u.inner.Path += "/"
} }
u.Path += controlPath u.inner.Path += controlPath
}
}
// SetHost sets the host of a RTSP URL.
func (u *URL) SetHost(host string) {
u.inner.Host = host
}
// SetUser sets the credentials of a RTSP URL.
func (u *URL) SetUser(user *url.Userinfo) {
u.inner.User = user
}
// SetPath sets the path of a RTSP URL.
func (u *URL) SetPath(path string) {
if u.inner.RawPath != "" {
u.inner.RawPath = path
} else {
u.inner.Path = path
} }
} }

View File

@@ -1,77 +1,76 @@
package base package base
import ( import (
"net/url"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestURLGetBasePath(t *testing.T) { func TestURLBasePath(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
u *url.URL u *URL
b string b string
}{ }{
{ {
urlMustParse("rtsp://localhost:8554/teststream"), MustParseURL("rtsp://localhost:8554/teststream"),
"teststream", "teststream",
}, },
{ {
urlMustParse("rtsp://localhost:8554/test/stream"), MustParseURL("rtsp://localhost:8554/test/stream"),
"test/stream", "test/stream",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp"),
"test", "test",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp"),
"te!st", "te!st",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"),
"user=tmp&password=BagRep1!&channel=1&stream=0.sdp", "user=tmp&password=BagRep1!&channel=1&stream=0.sdp",
}, },
} { } {
b, ok := URLGetBasePath(ca.u) b, ok := ca.u.BasePath()
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, ca.b, b) require.Equal(t, ca.b, b)
} }
} }
func TestURLGetBaseControlPath(t *testing.T) { func TestURLBaseControlPath(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
u *url.URL u *URL
b string b string
c string c string
}{ }{
{ {
urlMustParse("rtsp://localhost:8554/teststream/trackID=1"), MustParseURL("rtsp://localhost:8554/teststream/trackID=1"),
"teststream", "teststream",
"trackID=1", "trackID=1",
}, },
{ {
urlMustParse("rtsp://localhost:8554/test/stream/trackID=1"), MustParseURL("rtsp://localhost:8554/test/stream/trackID=1"),
"test/stream", "test/stream",
"trackID=1", "trackID=1",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"),
"test", "test",
"trackID=1", "trackID=1",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
"te!st", "te!st",
"trackID=1", "trackID=1",
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
"user=tmp&password=BagRep1!&channel=1&stream=0.sdp", "user=tmp&password=BagRep1!&channel=1&stream=0.sdp",
"trackID=1", "trackID=1",
}, },
} { } {
b, c, ok := URLGetBaseControlPath(ca.u) b, c, ok := ca.u.BaseControlPath()
require.Equal(t, true, ok) require.Equal(t, true, ok)
require.Equal(t, ca.b, b) require.Equal(t, ca.b, b)
require.Equal(t, ca.c, c) require.Equal(t, ca.c, c)
@@ -80,31 +79,31 @@ func TestURLGetBaseControlPath(t *testing.T) {
func TestURLAddControlPath(t *testing.T) { func TestURLAddControlPath(t *testing.T) {
for _, ca := range []struct { for _, ca := range []struct {
u *url.URL u *URL
ou *url.URL ou *URL
}{ }{
{ {
urlMustParse("rtsp://localhost:8554/teststream"), MustParseURL("rtsp://localhost:8554/teststream"),
urlMustParse("rtsp://localhost:8554/teststream/trackID=1"), MustParseURL("rtsp://localhost:8554/teststream/trackID=1"),
}, },
{ {
urlMustParse("rtsp://localhost:8554/test/stream"), MustParseURL("rtsp://localhost:8554/test/stream"),
urlMustParse("rtsp://localhost:8554/test/stream/trackID=1"), MustParseURL("rtsp://localhost:8554/test/stream/trackID=1"),
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp"),
urlMustParse("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/test?user=tmp&password=BagRep1&channel=1&stream=0.sdp/trackID=1"),
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp"),
urlMustParse("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/te!st?user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
}, },
{ {
urlMustParse("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"), MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp"),
urlMustParse("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"), MustParseURL("rtsp://192.168.1.99:554/user=tmp&password=BagRep1!&channel=1&stream=0.sdp/trackID=1"),
}, },
} { } {
URLAddControlPath(ca.u, "trackID=1") ca.u.AddControlPath("trackID=1")
require.Equal(t, ca.ou, ca.u) require.Equal(t, ca.ou, ca.u)
} }
} }

View File

@@ -12,7 +12,6 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"net/url"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
@@ -85,7 +84,7 @@ type ConnClient struct {
cseq int cseq int
auth *auth.Client auth *auth.Client
state connClientState state connClientState
streamUrl *url.URL streamUrl *base.URL
streamProtocol *StreamProtocol streamProtocol *StreamProtocol
tracks map[int]*Track tracks map[int]*Track
rtcpReceivers map[int]*rtcpreceiver.RtcpReceiver rtcpReceivers map[int]*rtcpreceiver.RtcpReceiver
@@ -279,15 +278,7 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) {
// insert auth // insert auth
if c.auth != nil { if c.auth != nil {
// remove credentials req.Header["Authorization"] = c.auth.GenerateHeader(req.Method, req.URL)
u := &url.URL{
Scheme: req.URL.Scheme,
Host: req.URL.Host,
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
}
req.Header["Authorization"] = c.auth.GenerateHeader(req.Method, u)
} }
// insert cseq // insert cseq
@@ -334,9 +325,8 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) {
} }
// setup authentication // setup authentication
if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && c.auth == nil { if res.StatusCode == base.StatusUnauthorized && req.URL.User() != nil && c.auth == nil {
pass, _ := req.URL.User.Password() auth, err := auth.NewClient(res.Header["WWW-Authenticate"], req.URL.User())
auth, err := auth.NewClient(res.Header["WWW-Authenticate"], req.URL.User.Username(), pass)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to setup authentication: %s", err) return nil, fmt.Errorf("unable to setup authentication: %s", err)
} }
@@ -352,20 +342,14 @@ func (c *ConnClient) Do(req *base.Request) (*base.Response, error) {
// Options writes an OPTIONS request and reads a response, that contains // Options writes an OPTIONS request and reads a response, that contains
// the methods allowed by the server. Since this method is not implemented by // the methods allowed by the server. Since this method is not implemented by
// every RTSP server, the function does not fail if the returned code is StatusNotFound. // every RTSP server, the function does not fail if the returned code is StatusNotFound.
func (c *ConnClient) Options(u *url.URL) (*base.Response, error) { func (c *ConnClient) Options(u *base.URL) (*base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, fmt.Errorf("can't be called when reading or publishing") return nil, fmt.Errorf("can't be called when reading or publishing")
} }
res, err := c.Do(&base.Request{ res, err := c.Do(&base.Request{
Method: base.OPTIONS, Method: base.OPTIONS,
URL: &url.URL{ URL: u,
Scheme: "rtsp",
Host: u.Host,
User: u.User,
// use the stream path, otherwise some cameras do not reply
Path: u.Path,
},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
@@ -379,7 +363,7 @@ func (c *ConnClient) Options(u *url.URL) (*base.Response, error) {
} }
// Describe writes a DESCRIBE request and reads a Response. // Describe writes a DESCRIBE request and reads a Response.
func (c *ConnClient) Describe(u *url.URL) (Tracks, *base.Response, error) { func (c *ConnClient) Describe(u *base.URL) (Tracks, *base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, nil, fmt.Errorf("can't be called when reading or publishing") return nil, nil, fmt.Errorf("can't be called when reading or publishing")
} }
@@ -417,7 +401,7 @@ func (c *ConnClient) Describe(u *url.URL) (Tracks, *base.Response, error) {
} }
// build an URL by merging baseUrl with the control attribute from track.Media // build an URL by merging baseUrl with the control attribute from track.Media
func (c *ConnClient) urlForTrack(baseUrl *url.URL, mode TransportMode, track *Track) *url.URL { func (c *ConnClient) urlForTrack(baseUrl *base.URL, mode TransportMode, track *Track) *base.URL {
control := func() string { control := func() string {
// if we're reading, get control from track ID // if we're reading, get control from track ID
if mode == TransportModeRecord { if mode == TransportModeRecord {
@@ -440,35 +424,24 @@ func (c *ConnClient) urlForTrack(baseUrl *url.URL, mode TransportMode, track *Tr
// control attribute contains an absolute path // control attribute contains an absolute path
if strings.HasPrefix(control, "rtsp://") { if strings.HasPrefix(control, "rtsp://") {
newUrl, err := url.Parse(control) newUrl, err := base.ParseURL(control)
if err != nil { if err != nil {
return baseUrl return baseUrl
} }
return &url.URL{ // copy host and credentials
Scheme: "rtsp", newUrl.SetHost(baseUrl.Host())
Host: baseUrl.Host, newUrl.SetUser(baseUrl.User())
User: baseUrl.User, return newUrl
Path: newUrl.Path,
RawPath: newUrl.RawPath,
RawQuery: newUrl.RawQuery,
}
} }
// control attribute contains a relative path // control attribute contains a control path
u := &url.URL{ newUrl := baseUrl.Clone()
Scheme: "rtsp", newUrl.AddControlPath(control)
Host: baseUrl.Host, return newUrl
User: baseUrl.User,
Path: baseUrl.Path,
RawPath: baseUrl.RawPath,
RawQuery: baseUrl.RawQuery,
}
base.URLAddControlPath(u, control)
return u
} }
func (c *ConnClient) setup(u *url.URL, mode TransportMode, track *Track, ht *headers.Transport) (*base.Response, error) { func (c *ConnClient) setup(u *base.URL, mode TransportMode, track *Track, ht *headers.Transport) (*base.Response, error) {
res, err := c.Do(&base.Request{ res, err := c.Do(&base.Request{
Method: base.SETUP, Method: base.SETUP,
URL: c.urlForTrack(u, mode, track), URL: c.urlForTrack(u, mode, track),
@@ -489,7 +462,7 @@ func (c *ConnClient) setup(u *url.URL, mode TransportMode, track *Track, ht *hea
// SetupUDP writes a SETUP request and reads a Response. // SetupUDP writes a SETUP request and reads a Response.
// If rtpPort and rtcpPort are zero, they are be chosen automatically. // If rtpPort and rtcpPort are zero, they are be chosen automatically.
func (c *ConnClient) SetupUDP(u *url.URL, mode TransportMode, track *Track, rtpPort int, func (c *ConnClient) SetupUDP(u *base.URL, mode TransportMode, track *Track, rtpPort int,
rtcpPort int) (*base.Response, error) { rtcpPort int) (*base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, fmt.Errorf("can't be called when reading or publishing") return nil, fmt.Errorf("can't be called when reading or publishing")
@@ -608,7 +581,7 @@ func (c *ConnClient) SetupUDP(u *url.URL, mode TransportMode, track *Track, rtpP
} }
// SetupTCP writes a SETUP request and reads a Response. // SetupTCP writes a SETUP request and reads a Response.
func (c *ConnClient) SetupTCP(u *url.URL, mode TransportMode, track *Track) (*base.Response, error) { func (c *ConnClient) SetupTCP(u *base.URL, mode TransportMode, track *Track) (*base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, fmt.Errorf("can't be called when reading or publishing") return nil, fmt.Errorf("can't be called when reading or publishing")
} }
@@ -662,7 +635,7 @@ func (c *ConnClient) SetupTCP(u *url.URL, mode TransportMode, track *Track) (*ba
// Play writes a PLAY request and reads a Response // Play writes a PLAY request and reads a Response
// This function can be called only after SetupUDP() or SetupTCP(). // This function can be called only after SetupUDP() or SetupTCP().
func (c *ConnClient) Play(u *url.URL) (*base.Response, error) { func (c *ConnClient) Play(u *base.URL) (*base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, fmt.Errorf("can't be called when reading or publishing") return nil, fmt.Errorf("can't be called when reading or publishing")
} }
@@ -772,14 +745,8 @@ func (c *ConnClient) LoopUDP() error {
case <-keepaliveTicker.C: case <-keepaliveTicker.C:
_, err := c.Do(&base.Request{ _, err := c.Do(&base.Request{
Method: base.OPTIONS, Method: base.OPTIONS,
URL: &url.URL{ // use the stream path, otherwise some cameras do not reply
Scheme: "rtsp", URL: c.streamUrl,
Host: c.streamUrl.Host,
User: c.streamUrl.User,
// use the stream path, otherwise some cameras do not reply
Path: c.streamUrl.Path,
RawPath: c.streamUrl.RawPath,
},
SkipResponse: true, SkipResponse: true,
}) })
if err != nil { if err != nil {
@@ -811,7 +778,7 @@ func (c *ConnClient) LoopUDP() error {
} }
// Announce writes an ANNOUNCE request and reads a Response. // Announce writes an ANNOUNCE request and reads a Response.
func (c *ConnClient) Announce(u *url.URL, tracks Tracks) (*base.Response, error) { func (c *ConnClient) Announce(u *base.URL, tracks Tracks) (*base.Response, error) {
if c.streamUrl != nil { if c.streamUrl != nil {
fmt.Errorf("announce has already been sent with another url url") fmt.Errorf("announce has already been sent with another url url")
} }
@@ -838,7 +805,7 @@ func (c *ConnClient) Announce(u *url.URL, tracks Tracks) (*base.Response, error)
} }
// Record writes a RECORD request and reads a Response. // Record writes a RECORD request and reads a Response.
func (c *ConnClient) Record(u *url.URL) (*base.Response, error) { func (c *ConnClient) Record(u *base.URL) (*base.Response, error) {
if c.state != connClientStateInitial { if c.state != connClientStateInitial {
return nil, fmt.Errorf("can't be called when reading or publishing") return nil, fmt.Errorf("can't be called when reading or publishing")
} }

View File

@@ -1,17 +1,17 @@
package gortsplib package gortsplib
import ( import (
"net/url" "github.com/aler9/gortsplib/base"
) )
// DialRead connects to the address and starts reading all tracks. // DialRead connects to the address and starts reading all tracks.
func DialRead(address string, proto StreamProtocol) (*ConnClient, error) { func DialRead(address string, proto StreamProtocol) (*ConnClient, error) {
u, err := url.Parse(address) u, err := base.ParseURL(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := NewConnClient(ConnClientConf{Host: u.Host}) conn, err := NewConnClient(ConnClientConf{Host: u.Host()})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -57,12 +57,12 @@ func DialRead(address string, proto StreamProtocol) (*ConnClient, error) {
// DialPublish connects to the address and starts publishing the tracks. // DialPublish connects to the address and starts publishing the tracks.
func DialPublish(address string, proto StreamProtocol, tracks Tracks) (*ConnClient, error) { func DialPublish(address string, proto StreamProtocol, tracks Tracks) (*ConnClient, error) {
u, err := url.Parse(address) u, err := base.ParseURL(address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
conn, err := NewConnClient(ConnClientConf{Host: u.Host}) conn, err := NewConnClient(ConnClientConf{Host: u.Host()})
if err != nil { if err != nil {
return nil, err return nil, err
} }