diff --git a/auth-client.go b/auth-client.go index 1f02fef7..e90157f0 100644 --- a/auth-client.go +++ b/auth-client.go @@ -4,6 +4,7 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "net/url" "strings" ) @@ -61,11 +62,11 @@ func NewAuthClient(header []string, user string, pass string) (*AuthClient, erro // GenerateHeader generates an Authorization Header that allows to authenticate a request with // the given method and path. -func (ac *AuthClient) GenerateHeader(method Method, path string) []string { +func (ac *AuthClient) GenerateHeader(method Method, ur *url.URL) []string { ha1 := md5Hex(ac.user + ":" + ac.realm + ":" + ac.pass) - ha2 := md5Hex(string(method) + ":" + path) + ha2 := md5Hex(string(method) + ":" + ur.String()) response := md5Hex(ha1 + ":" + ac.nonce + ":" + ha2) return []string{fmt.Sprintf("Digest username=\"%s\", realm=\"%s\", nonce=\"%s\", uri=\"%s\", response=\"%s\"", - ac.user, ac.realm, ac.nonce, path, response)} + ac.user, ac.realm, ac.nonce, ur.String(), response)} } diff --git a/auth-server.go b/auth-server.go index 61da50a0..9ed2917d 100644 --- a/auth-server.go +++ b/auth-server.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/url" ) // AuthServer is an object that helps a server validating the credentials of a client. @@ -35,7 +36,7 @@ func (as *AuthServer) GenerateHeader() []string { // ValidateHeader validates the Authorization header sent by a client after receiving the // WWW-Authenticate header provided by GenerateHeader(). -func (as *AuthServer) ValidateHeader(header []string, method Method, path string) error { +func (as *AuthServer) ValidateHeader(header []string, method Method, ur *url.URL) error { if len(header) != 1 { return fmt.Errorf("Authorization header not provided") } @@ -82,12 +83,12 @@ func (as *AuthServer) ValidateHeader(header []string, method Method, path string return fmt.Errorf("wrong username") } - if inUri != path { - return fmt.Errorf("wrong uri") + if inUri != ur.String() { + return fmt.Errorf("wrong url") } ha1 := md5Hex(as.user + ":" + as.realm + ":" + as.pass) - ha2 := md5Hex(string(method) + ":" + path) + ha2 := md5Hex(string(method) + ":" + ur.String()) response := md5Hex(ha1 + ":" + as.nonce + ":" + ha2) if inResponse != response { diff --git a/auth-server_test.go b/auth-server_test.go index 49579b12..e408efdc 100644 --- a/auth-server_test.go +++ b/auth-server_test.go @@ -1,6 +1,7 @@ package gortsplib import ( + "net/url" "testing" "github.com/stretchr/testify/require" @@ -12,8 +13,10 @@ func TestAuthClientServer(t *testing.T) { ac, err := NewAuthClient(wwwAuthenticate, "testuser", "testpass") require.NoError(t, err) - authorization := ac.GenerateHeader(ANNOUNCE, "rtsp://myhost/mypath") + authorization := ac.GenerateHeader(ANNOUNCE, + &url.URL{Scheme: "rtsp", Host: "myhost", Path: "mypath"}) - err = as.ValidateHeader(authorization, ANNOUNCE, "rtsp://myhost/mypath") + err = as.ValidateHeader(authorization, ANNOUNCE, + &url.URL{Scheme: "rtsp", Host: "myhost", Path: "mypath"}) require.NoError(t, err) } diff --git a/request.go b/request.go index a5eacf2c..ceaba909 100644 --- a/request.go +++ b/request.go @@ -3,6 +3,7 @@ package gortsplib import ( "bufio" "fmt" + "net/url" ) const ( @@ -32,7 +33,7 @@ const ( // Request is a RTSP request. type Request struct { Method Method - Url string + Url *url.URL Header Header Content []byte } @@ -46,7 +47,7 @@ func readRequest(br *bufio.Reader) (*Request, error) { } req.Method = Method(byts[:len(byts)-1]) - if len(req.Method) == 0 { + if req.Method == "" { return nil, fmt.Errorf("empty method") } @@ -54,12 +55,22 @@ func readRequest(br *bufio.Reader) (*Request, error) { if err != nil { return nil, err } - req.Url = string(byts[:len(byts)-1]) + rawUrl := string(byts[:len(byts)-1]) - if len(req.Url) == 0 { + if rawUrl == "" { return nil, fmt.Errorf("empty url") } + ur, err := url.Parse(rawUrl) + if err != nil { + return nil, fmt.Errorf("unable to parse url '%s'", rawUrl) + } + req.Url = ur + + if req.Url.Scheme != "rtsp" { + return nil, fmt.Errorf("invalid url scheme '%s'", req.Url.Scheme) + } + byts, err = readBytesLimited(br, '\r', _MAX_PROTOCOL_LENGTH) if err != nil { return nil, err @@ -89,7 +100,7 @@ func readRequest(br *bufio.Reader) (*Request, error) { } func (req *Request) write(bw *bufio.Writer) error { - _, err := bw.Write([]byte(string(req.Method) + " " + req.Url + " " + _RTSP_PROTO + "\r\n")) + _, err := bw.Write([]byte(string(req.Method) + " " + req.Url.String() + " " + _RTSP_PROTO + "\r\n")) if err != nil { return err } diff --git a/request_test.go b/request_test.go index 8f75a126..95dc55b3 100644 --- a/request_test.go +++ b/request_test.go @@ -3,6 +3,7 @@ package gortsplib import ( "bufio" "bytes" + "net/url" "testing" "github.com/stretchr/testify/require" @@ -22,7 +23,7 @@ var casesRequest = []struct { "\r\n"), &Request{ Method: "OPTIONS", - Url: "rtsp://example.com/media.mp4", + Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ "CSeq": []string{"1"}, "Require": []string{"implicit-play"}, @@ -37,7 +38,7 @@ var casesRequest = []struct { "\r\n"), &Request{ Method: "DESCRIBE", - Url: "rtsp://example.com/media.mp4", + Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ "CSeq": []string{"2"}, }, @@ -65,7 +66,7 @@ var casesRequest = []struct { "m=video 2232 RTP/AVP 31\n"), &Request{ Method: "ANNOUNCE", - Url: "rtsp://example.com/media.mp4", + Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ "CSeq": []string{"7"}, "Date": []string{"23 Jan 1997 15:35:06 GMT"}, @@ -99,7 +100,7 @@ var casesRequest = []struct { "jitter\n"), &Request{ Method: "GET_PARAMETER", - Url: "rtsp://example.com/media.mp4", + Url: &url.URL{Scheme: "rtsp", Host: "example.com", Path: "/media.mp4"}, Header: Header{ "CSeq": []string{"9"}, "Content-Type": []string{"text/parameters"},