diff --git a/client_read_test.go b/client_read_test.go index d28cb2be..af77e7d3 100644 --- a/client_read_test.go +++ b/client_read_test.go @@ -765,7 +765,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { err = base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ - "WWW-Authenticate": v.GenerateHeader(), + "WWW-Authenticate": v.Header(), }, }.Write(bconn.Writer) require.NoError(t, err) @@ -774,10 +774,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Describe, req.Method) - err = v.ValidateHeader(req.Header["Authorization"], - base.Describe, - mustParseURL("rtsp://localhost:8554/teststream"), - nil) + err = v.ValidateRequest(req, nil) require.NoError(t, err) track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) @@ -853,7 +850,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { err = base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ - "WWW-Authenticate": v.GenerateHeader(), + "WWW-Authenticate": v.Header(), }, }.Write(bconn.Writer) require.NoError(t, err) @@ -863,10 +860,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) { require.Equal(t, base.Setup, req.Method) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) - err = v.ValidateHeader(req.Header["Authorization"], - base.Setup, - mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), - nil) + err = v.ValidateRequest(req, nil) require.NoError(t, err) inTH = headers.Transport{} diff --git a/client_test.go b/client_test.go index 41c0fa2d..fc403b51 100644 --- a/client_test.go +++ b/client_test.go @@ -135,7 +135,7 @@ func TestClientAuth(t *testing.T) { err = base.Response{ StatusCode: base.StatusUnauthorized, Header: base.Header{ - "WWW-Authenticate": v.GenerateHeader(), + "WWW-Authenticate": v.Header(), }, }.Write(bconn.Writer) require.NoError(t, err) @@ -144,7 +144,7 @@ func TestClientAuth(t *testing.T) { require.NoError(t, err) require.Equal(t, base.Describe, req.Method) - err = v.ValidateHeader(req.Header["Authorization"], base.Describe, req.URL, nil) + err = v.ValidateRequest(req, nil) require.NoError(t, err) track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) diff --git a/clientconn.go b/clientconn.go index 44c05217..4b48de19 100644 --- a/clientconn.go +++ b/clientconn.go @@ -804,21 +804,17 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response, req.Header = make(base.Header) } - // add session if cc.session != "" { req.Header["Session"] = base.HeaderValue{cc.session} } - // add auth if cc.sender != nil { - req.Header["Authorization"] = cc.sender.GenerateHeader(req.Method, req.URL) + cc.sender.AddAuthorization(req) } - // add cseq cc.cseq++ req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(cc.cseq), 10)} - // add user agent req.Header["User-Agent"] = base.HeaderValue{"gortsplib"} if cc.c.OnRequest != nil { diff --git a/pkg/auth/package_test.go b/pkg/auth/package_test.go index b0d86256..795fafec 100644 --- a/pkg/auth/package_test.go +++ b/pkg/auth/package_test.go @@ -47,7 +47,7 @@ func TestAuth(t *testing.T) { t.Run(c1.name+"_"+conf, func(t *testing.T) { va := NewValidator("testuser", "testpass", c1.methods) - wwwAuthenticate := va.GenerateHeader() + wwwAuthenticate := va.Header() se, err := NewSender(wwwAuthenticate, func() string { @@ -63,16 +63,21 @@ func TestAuth(t *testing.T) { return "testpass" }()) require.NoError(t, err) - authorization := se.GenerateHeader(base.Announce, - mustParseURL(func() string { + + req := &base.Request{ + Method: base.Announce, + URL: mustParseURL(func() string { if conf == "wrongurl" { return "rtsp://myhost/my1path" } return "rtsp://myhost/mypath" - }())) + }()), + } + se.AddAuthorization(req) - err = va.ValidateHeader(authorization, base.Announce, - mustParseURL("rtsp://myhost/mypath"), nil) + req.URL = mustParseURL("rtsp://myhost/mypath") + + err = va.ValidateRequest(req, nil) if conf != "nofail" { require.Error(t, err) @@ -101,13 +106,18 @@ func TestAuthVLC(t *testing.T) { va := NewValidator("testuser", "testpass", []headers.AuthMethod{headers.AuthBasic, headers.AuthDigest}) - se, err := NewSender(va.GenerateHeader(), "testuser", "testpass") + se, err := NewSender(va.Header(), "testuser", "testpass") require.NoError(t, err) - authorization := se.GenerateHeader(base.Announce, - mustParseURL(ca.clientURL)) - err = va.ValidateHeader(authorization, base.Announce, - mustParseURL(ca.serverURL), mustParseURL(ca.clientURL)) + req := &base.Request{ + Method: base.Announce, + URL: mustParseURL(ca.clientURL), + } + se.AddAuthorization(req) + + req.URL = mustParseURL(ca.serverURL) + + err = va.ValidateRequest(req, mustParseURL(ca.clientURL)) require.NoError(t, err) } } @@ -123,7 +133,7 @@ func TestAuthHashed(t *testing.T) { "sha256:E9JJ8stBJ7QM+nV4ZoUCeHk/gU3tPFh/5YieiJp6n2w=", []headers.AuthMethod{headers.AuthBasic, headers.AuthDigest}) - va, err := NewSender(se.GenerateHeader(), + va, err := NewSender(se.Header(), func() string { if conf == "wronguser" { return "test1user" @@ -137,11 +147,14 @@ func TestAuthHashed(t *testing.T) { return "testpass" }()) require.NoError(t, err) - authorization := va.GenerateHeader(base.Announce, - mustParseURL("rtsp://myhost/mypath")) - err = se.ValidateHeader(authorization, base.Announce, - mustParseURL("rtsp://myhost/mypath"), nil) + req := &base.Request{ + Method: base.Announce, + URL: mustParseURL("rtsp://myhost/mypath"), + } + va.AddAuthorization(req) + + err = se.ValidateRequest(req, nil) if conf != "nofail" { require.Error(t, err) diff --git a/pkg/auth/sender.go b/pkg/auth/sender.go index b8276270..5b3ad1be 100644 --- a/pkg/auth/sender.go +++ b/pkg/auth/sender.go @@ -81,10 +81,9 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) { return nil, fmt.Errorf("no authentication methods available") } -// GenerateHeader generates an Authorization Header that allows to authenticate a request with -// the given method and url. -func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue { - urStr := ur.CloneWithoutCredentials().String() +// AddAuthorization adds the Authorization header to a Request. +func (se *Sender) AddAuthorization(req *base.Request) { + urStr := req.URL.CloneWithoutCredentials().String() h := headers.Authorization{ Method: se.method, @@ -97,7 +96,7 @@ func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderVa default: // headers.AuthDigest response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" + - se.nonce + ":" + md5Hex(string(method)+":"+urStr)) + se.nonce + ":" + md5Hex(string(req.Method)+":"+urStr)) h.DigestValues = headers.Authenticate{ Method: headers.AuthDigest, @@ -109,5 +108,9 @@ func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderVa } } - return h.Write() + if req.Header == nil { + req.Header = make(base.Header) + } + + req.Header["Authorization"] = h.Write() } diff --git a/pkg/auth/validator.go b/pkg/auth/validator.go index ea8bb488..ce0bd5a9 100644 --- a/pkg/auth/validator.go +++ b/pkg/auth/validator.go @@ -60,9 +60,9 @@ func NewValidator(user string, pass string, methods []headers.AuthMethod) *Valid } } -// GenerateHeader generates the WWW-Authenticate header needed by a client to +// Header generates the WWW-Authenticate header needed by a client to // authenticate. -func (va *Validator) GenerateHeader() base.HeaderValue { +func (va *Validator) Header() base.HeaderValue { var ret base.HeaderValue for _, m := range va.methods { switch m { @@ -83,15 +83,11 @@ func (va *Validator) GenerateHeader() base.HeaderValue { return ret } -// ValidateHeader validates the Authorization header sent by a client after receiving the -// WWW-Authenticate header. -func (va *Validator) ValidateHeader( - v base.HeaderValue, - method base.Method, - ur *base.URL, +// ValidateRequest validates a request sent by a client. +func (va *Validator) ValidateRequest(req *base.Request, altURL *base.URL) error { var auth headers.Authorization - err := auth.Read(v) + err := auth.Read(req.Header["Authorization"]) if err != nil { return err } @@ -151,7 +147,7 @@ func (va *Validator) ValidateHeader( return fmt.Errorf("wrong username") } - urlString := ur.String() + urlString := req.URL.String() if *auth.DigestValues.URI != urlString { // do another try with the alternative URL @@ -165,7 +161,7 @@ func (va *Validator) ValidateHeader( } response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) + - ":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString)) + ":" + va.nonce + ":" + md5Hex(string(req.Method)+":"+urlString)) if *auth.DigestValues.Response != response { return fmt.Errorf("wrong response") diff --git a/pkg/auth/validator_test.go b/pkg/auth/validator_test.go index c169a9b1..596d863a 100644 --- a/pkg/auth/validator_test.go +++ b/pkg/auth/validator_test.go @@ -58,7 +58,13 @@ func TestValidatorErrors(t *testing.T) { t.Run(ca.name, func(t *testing.T) { va := NewValidator("myuser", "mypass", nil) va.nonce = "abcde" - err := va.ValidateHeader(ca.hv, base.Describe, nil, nil) + err := va.ValidateRequest(&base.Request{ + Method: base.Describe, + URL: nil, + Header: base.Header{ + "Authorization": ca.hv, + }, + }, nil) require.Equal(t, ca.err, err.Error()) }) }