auth: simplify

This commit is contained in:
aler9
2021-05-30 12:52:46 +02:00
parent 9007f20af8
commit d07e93f245
7 changed files with 59 additions and 51 deletions

View File

@@ -765,7 +765,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
err = base.Response{ err = base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,
Header: base.Header{ Header: base.Header{
"WWW-Authenticate": v.GenerateHeader(), "WWW-Authenticate": v.Header(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -774,10 +774,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
err = v.ValidateHeader(req.Header["Authorization"], err = v.ValidateRequest(req, nil)
base.Describe,
mustParseURL("rtsp://localhost:8554/teststream"),
nil)
require.NoError(t, err) require.NoError(t, err)
track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
@@ -853,7 +850,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
err = base.Response{ err = base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,
Header: base.Header{ Header: base.Header{
"WWW-Authenticate": v.GenerateHeader(), "WWW-Authenticate": v.Header(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -863,10 +860,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL)
err = v.ValidateHeader(req.Header["Authorization"], err = v.ValidateRequest(req, nil)
base.Setup,
mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
nil)
require.NoError(t, err) require.NoError(t, err)
inTH = headers.Transport{} inTH = headers.Transport{}

View File

@@ -135,7 +135,7 @@ func TestClientAuth(t *testing.T) {
err = base.Response{ err = base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,
Header: base.Header{ Header: base.Header{
"WWW-Authenticate": v.GenerateHeader(), "WWW-Authenticate": v.Header(),
}, },
}.Write(bconn.Writer) }.Write(bconn.Writer)
require.NoError(t, err) require.NoError(t, err)
@@ -144,7 +144,7 @@ func TestClientAuth(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
err = v.ValidateHeader(req.Header["Authorization"], base.Describe, req.URL, nil) err = v.ValidateRequest(req, nil)
require.NoError(t, err) require.NoError(t, err)
track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))

View File

@@ -804,21 +804,17 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response,
req.Header = make(base.Header) req.Header = make(base.Header)
} }
// add session
if cc.session != "" { if cc.session != "" {
req.Header["Session"] = base.HeaderValue{cc.session} req.Header["Session"] = base.HeaderValue{cc.session}
} }
// add auth
if cc.sender != nil { if cc.sender != nil {
req.Header["Authorization"] = cc.sender.GenerateHeader(req.Method, req.URL) cc.sender.AddAuthorization(req)
} }
// add cseq
cc.cseq++ cc.cseq++
req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(cc.cseq), 10)} req.Header["CSeq"] = base.HeaderValue{strconv.FormatInt(int64(cc.cseq), 10)}
// add user agent
req.Header["User-Agent"] = base.HeaderValue{"gortsplib"} req.Header["User-Agent"] = base.HeaderValue{"gortsplib"}
if cc.c.OnRequest != nil { if cc.c.OnRequest != nil {

View File

@@ -47,7 +47,7 @@ func TestAuth(t *testing.T) {
t.Run(c1.name+"_"+conf, func(t *testing.T) { t.Run(c1.name+"_"+conf, func(t *testing.T) {
va := NewValidator("testuser", "testpass", c1.methods) va := NewValidator("testuser", "testpass", c1.methods)
wwwAuthenticate := va.GenerateHeader() wwwAuthenticate := va.Header()
se, err := NewSender(wwwAuthenticate, se, err := NewSender(wwwAuthenticate,
func() string { func() string {
@@ -63,16 +63,21 @@ func TestAuth(t *testing.T) {
return "testpass" return "testpass"
}()) }())
require.NoError(t, err) require.NoError(t, err)
authorization := se.GenerateHeader(base.Announce,
mustParseURL(func() string { req := &base.Request{
Method: base.Announce,
URL: mustParseURL(func() string {
if conf == "wrongurl" { if conf == "wrongurl" {
return "rtsp://myhost/my1path" return "rtsp://myhost/my1path"
} }
return "rtsp://myhost/mypath" return "rtsp://myhost/mypath"
}())) }()),
}
se.AddAuthorization(req)
err = va.ValidateHeader(authorization, base.Announce, req.URL = mustParseURL("rtsp://myhost/mypath")
mustParseURL("rtsp://myhost/mypath"), nil)
err = va.ValidateRequest(req, nil)
if conf != "nofail" { if conf != "nofail" {
require.Error(t, err) require.Error(t, err)
@@ -101,13 +106,18 @@ func TestAuthVLC(t *testing.T) {
va := NewValidator("testuser", "testpass", va := NewValidator("testuser", "testpass",
[]headers.AuthMethod{headers.AuthBasic, headers.AuthDigest}) []headers.AuthMethod{headers.AuthBasic, headers.AuthDigest})
se, err := NewSender(va.GenerateHeader(), "testuser", "testpass") se, err := NewSender(va.Header(), "testuser", "testpass")
require.NoError(t, err) require.NoError(t, err)
authorization := se.GenerateHeader(base.Announce,
mustParseURL(ca.clientURL))
err = va.ValidateHeader(authorization, base.Announce, req := &base.Request{
mustParseURL(ca.serverURL), mustParseURL(ca.clientURL)) Method: base.Announce,
URL: mustParseURL(ca.clientURL),
}
se.AddAuthorization(req)
req.URL = mustParseURL(ca.serverURL)
err = va.ValidateRequest(req, mustParseURL(ca.clientURL))
require.NoError(t, err) require.NoError(t, err)
} }
} }
@@ -123,7 +133,7 @@ func TestAuthHashed(t *testing.T) {
"sha256:E9JJ8stBJ7QM+nV4ZoUCeHk/gU3tPFh/5YieiJp6n2w=", "sha256:E9JJ8stBJ7QM+nV4ZoUCeHk/gU3tPFh/5YieiJp6n2w=",
[]headers.AuthMethod{headers.AuthBasic, headers.AuthDigest}) []headers.AuthMethod{headers.AuthBasic, headers.AuthDigest})
va, err := NewSender(se.GenerateHeader(), va, err := NewSender(se.Header(),
func() string { func() string {
if conf == "wronguser" { if conf == "wronguser" {
return "test1user" return "test1user"
@@ -137,11 +147,14 @@ func TestAuthHashed(t *testing.T) {
return "testpass" return "testpass"
}()) }())
require.NoError(t, err) require.NoError(t, err)
authorization := va.GenerateHeader(base.Announce,
mustParseURL("rtsp://myhost/mypath"))
err = se.ValidateHeader(authorization, base.Announce, req := &base.Request{
mustParseURL("rtsp://myhost/mypath"), nil) Method: base.Announce,
URL: mustParseURL("rtsp://myhost/mypath"),
}
va.AddAuthorization(req)
err = se.ValidateRequest(req, nil)
if conf != "nofail" { if conf != "nofail" {
require.Error(t, err) require.Error(t, err)

View File

@@ -81,10 +81,9 @@ func NewSender(v base.HeaderValue, user string, pass string) (*Sender, error) {
return nil, fmt.Errorf("no authentication methods available") return nil, fmt.Errorf("no authentication methods available")
} }
// GenerateHeader generates an Authorization Header that allows to authenticate a request with // AddAuthorization adds the Authorization header to a Request.
// the given method and url. func (se *Sender) AddAuthorization(req *base.Request) {
func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderValue { urStr := req.URL.CloneWithoutCredentials().String()
urStr := ur.CloneWithoutCredentials().String()
h := headers.Authorization{ h := headers.Authorization{
Method: se.method, Method: se.method,
@@ -97,7 +96,7 @@ func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderVa
default: // headers.AuthDigest default: // headers.AuthDigest
response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" + response := md5Hex(md5Hex(se.user+":"+se.realm+":"+se.pass) + ":" +
se.nonce + ":" + md5Hex(string(method)+":"+urStr)) se.nonce + ":" + md5Hex(string(req.Method)+":"+urStr))
h.DigestValues = headers.Authenticate{ h.DigestValues = headers.Authenticate{
Method: headers.AuthDigest, Method: headers.AuthDigest,
@@ -109,5 +108,9 @@ func (se *Sender) GenerateHeader(method base.Method, ur *base.URL) base.HeaderVa
} }
} }
return h.Write() if req.Header == nil {
req.Header = make(base.Header)
}
req.Header["Authorization"] = h.Write()
} }

View File

@@ -60,9 +60,9 @@ func NewValidator(user string, pass string, methods []headers.AuthMethod) *Valid
} }
} }
// GenerateHeader generates the WWW-Authenticate header needed by a client to // Header generates the WWW-Authenticate header needed by a client to
// authenticate. // authenticate.
func (va *Validator) GenerateHeader() base.HeaderValue { func (va *Validator) Header() base.HeaderValue {
var ret base.HeaderValue var ret base.HeaderValue
for _, m := range va.methods { for _, m := range va.methods {
switch m { switch m {
@@ -83,15 +83,11 @@ func (va *Validator) GenerateHeader() base.HeaderValue {
return ret return ret
} }
// ValidateHeader validates the Authorization header sent by a client after receiving the // ValidateRequest validates a request sent by a client.
// WWW-Authenticate header. func (va *Validator) ValidateRequest(req *base.Request,
func (va *Validator) ValidateHeader(
v base.HeaderValue,
method base.Method,
ur *base.URL,
altURL *base.URL) error { altURL *base.URL) error {
var auth headers.Authorization var auth headers.Authorization
err := auth.Read(v) err := auth.Read(req.Header["Authorization"])
if err != nil { if err != nil {
return err return err
} }
@@ -151,7 +147,7 @@ func (va *Validator) ValidateHeader(
return fmt.Errorf("wrong username") return fmt.Errorf("wrong username")
} }
urlString := ur.String() urlString := req.URL.String()
if *auth.DigestValues.URI != urlString { if *auth.DigestValues.URI != urlString {
// do another try with the alternative URL // do another try with the alternative URL
@@ -165,7 +161,7 @@ func (va *Validator) ValidateHeader(
} }
response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) + response := md5Hex(md5Hex(va.user+":"+va.realm+":"+va.pass) +
":" + va.nonce + ":" + md5Hex(string(method)+":"+urlString)) ":" + va.nonce + ":" + md5Hex(string(req.Method)+":"+urlString))
if *auth.DigestValues.Response != response { if *auth.DigestValues.Response != response {
return fmt.Errorf("wrong response") return fmt.Errorf("wrong response")

View File

@@ -58,7 +58,13 @@ func TestValidatorErrors(t *testing.T) {
t.Run(ca.name, func(t *testing.T) { t.Run(ca.name, func(t *testing.T) {
va := NewValidator("myuser", "mypass", nil) va := NewValidator("myuser", "mypass", nil)
va.nonce = "abcde" va.nonce = "abcde"
err := va.ValidateHeader(ca.hv, base.Describe, nil, nil) err := va.ValidateRequest(&base.Request{
Method: base.Describe,
URL: nil,
Header: base.Header{
"Authorization": ca.hv,
},
}, nil)
require.Equal(t, ca.err, err.Error()) require.Equal(t, ca.err, err.Error())
}) })
} }