client: fix authentication during SETUP, preventing 401s during protocol switches (https://github.com/aler9/rtsp-simple-server/issues/392)

This commit is contained in:
aler9
2021-05-30 12:00:35 +02:00
parent 2c0d28ecb4
commit 9007f20af8
3 changed files with 49 additions and 6 deletions

View File

@@ -15,6 +15,7 @@ import (
psdp "github.com/pion/sdp/v3" psdp "github.com/pion/sdp/v3"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/aler9/gortsplib/pkg/auth"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
"github.com/aler9/gortsplib/pkg/rtcpsender" "github.com/aler9/gortsplib/pkg/rtcpsender"
@@ -759,6 +760,26 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Describe, req.Method) require.Equal(t, base.Describe, req.Method)
v := auth.NewValidator("myuser", "mypass", nil)
err = base.Response{
StatusCode: base.StatusUnauthorized,
Header: base.Header{
"WWW-Authenticate": v.GenerateHeader(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
req, err = readRequest(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Describe, req.Method)
err = v.ValidateHeader(req.Header["Authorization"],
base.Describe,
mustParseURL("rtsp://localhost:8554/teststream"),
nil)
require.NoError(t, err)
track, err := NewTrackH264(96, []byte("123456"), []byte("123456")) track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
require.NoError(t, err) require.NoError(t, err)
@@ -823,11 +844,31 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
bconn = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) bconn = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
req, err = readRequest(bconn.Reader)
require.NoError(t, err)
require.Equal(t, base.Setup, req.Method)
v = auth.NewValidator("myuser", "mypass", nil)
err = base.Response{
StatusCode: base.StatusUnauthorized,
Header: base.Header{
"WWW-Authenticate": v.GenerateHeader(),
},
}.Write(bconn.Writer)
require.NoError(t, err)
req, err = readRequest(bconn.Reader) req, err = readRequest(bconn.Reader)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.Setup, req.Method) require.Equal(t, base.Setup, req.Method)
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL) require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL)
err = v.ValidateHeader(req.Header["Authorization"],
base.Setup,
mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
nil)
require.NoError(t, err)
inTH = headers.Transport{} inTH = headers.Transport{}
err = inTH.Read(req.Header["Transport"]) err = inTH.Read(req.Header["Transport"])
require.NoError(t, err) require.NoError(t, err)
@@ -880,7 +921,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
ReadTimeout: 1 * time.Second, ReadTimeout: 1 * time.Second,
} }
conn, err := c.DialRead("rtsp://localhost:8554/teststream") conn, err := c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream")
require.NoError(t, err) require.NoError(t, err)
frameRecv := make(chan struct{}) frameRecv := make(chan struct{})

View File

@@ -897,7 +897,7 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response,
cc.session = sx.Session cc.session = sx.Session
} }
// setup authentication // if required, send request again with authentication
if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && cc.sender == nil { if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && cc.sender == nil {
pass, _ := req.URL.User.Password() pass, _ := req.URL.User.Password()
user := req.URL.User.Username() user := req.URL.User.Username()
@@ -908,7 +908,6 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response,
} }
cc.sender = sender cc.sender = sender
// send request again
return cc.do(req, false) return cc.do(req, false)
} }
@@ -1036,7 +1035,7 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) {
} }
baseURL, err := func() (*base.URL, error) { baseURL, err := func() (*base.URL, error) {
// prefer Content-Base (optional) // use Content-Base
if cb, ok := res.Header["Content-Base"]; ok { if cb, ok := res.Header["Content-Base"]; ok {
if len(cb) != 1 { if len(cb) != 1 {
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb) return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
@@ -1047,10 +1046,13 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) {
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb) return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
} }
// add credentials from URL of request
ret.User = u.User
return ret, nil return ret, nil
} }
// if not provided, use DESCRIBE URL // if not provided, use URL of request
return u, nil return u, nil
}() }()
if err != nil { if err != nil {

View File

@@ -10,7 +10,7 @@ import (
"github.com/aler9/gortsplib/pkg/headers" "github.com/aler9/gortsplib/pkg/headers"
) )
// Validator allows to validate some credentials generated by a Sender. // Validator allows to validate credentials generated by a Sender.
type Validator struct { type Validator struct {
user string user string
userHashed bool userHashed bool