mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 23:26:54 +08:00
client: fix authentication during SETUP, preventing 401s during protocol switches (https://github.com/aler9/rtsp-simple-server/issues/392)
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
|||||||
psdp "github.com/pion/sdp/v3"
|
psdp "github.com/pion/sdp/v3"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/aler9/gortsplib/pkg/auth"
|
||||||
"github.com/aler9/gortsplib/pkg/base"
|
"github.com/aler9/gortsplib/pkg/base"
|
||||||
"github.com/aler9/gortsplib/pkg/headers"
|
"github.com/aler9/gortsplib/pkg/headers"
|
||||||
"github.com/aler9/gortsplib/pkg/rtcpsender"
|
"github.com/aler9/gortsplib/pkg/rtcpsender"
|
||||||
@@ -759,6 +760,26 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, base.Describe, req.Method)
|
require.Equal(t, base.Describe, req.Method)
|
||||||
|
|
||||||
|
v := auth.NewValidator("myuser", "mypass", nil)
|
||||||
|
|
||||||
|
err = base.Response{
|
||||||
|
StatusCode: base.StatusUnauthorized,
|
||||||
|
Header: base.Header{
|
||||||
|
"WWW-Authenticate": v.GenerateHeader(),
|
||||||
|
},
|
||||||
|
}.Write(bconn.Writer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
req, err = readRequest(bconn.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, base.Describe, req.Method)
|
||||||
|
|
||||||
|
err = v.ValidateHeader(req.Header["Authorization"],
|
||||||
|
base.Describe,
|
||||||
|
mustParseURL("rtsp://localhost:8554/teststream"),
|
||||||
|
nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
|
track, err := NewTrackH264(96, []byte("123456"), []byte("123456"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -823,11 +844,31 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
bconn = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
bconn = bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||||||
|
|
||||||
|
req, err = readRequest(bconn.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, base.Setup, req.Method)
|
||||||
|
|
||||||
|
v = auth.NewValidator("myuser", "mypass", nil)
|
||||||
|
|
||||||
|
err = base.Response{
|
||||||
|
StatusCode: base.StatusUnauthorized,
|
||||||
|
Header: base.Header{
|
||||||
|
"WWW-Authenticate": v.GenerateHeader(),
|
||||||
|
},
|
||||||
|
}.Write(bconn.Writer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
req, err = readRequest(bconn.Reader)
|
req, err = readRequest(bconn.Reader)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, base.Setup, req.Method)
|
require.Equal(t, base.Setup, req.Method)
|
||||||
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL)
|
require.Equal(t, mustParseURL("rtsp://localhost:8554/teststream/trackID=0"), req.URL)
|
||||||
|
|
||||||
|
err = v.ValidateHeader(req.Header["Authorization"],
|
||||||
|
base.Setup,
|
||||||
|
mustParseURL("rtsp://localhost:8554/teststream/trackID=0"),
|
||||||
|
nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
inTH = headers.Transport{}
|
inTH = headers.Transport{}
|
||||||
err = inTH.Read(req.Header["Transport"])
|
err = inTH.Read(req.Header["Transport"])
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -880,7 +921,7 @@ func TestClientReadAutomaticProtocol(t *testing.T) {
|
|||||||
ReadTimeout: 1 * time.Second,
|
ReadTimeout: 1 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := c.DialRead("rtsp://localhost:8554/teststream")
|
conn, err := c.DialRead("rtsp://myuser:mypass@localhost:8554/teststream")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
frameRecv := make(chan struct{})
|
frameRecv := make(chan struct{})
|
||||||
|
@@ -897,7 +897,7 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response,
|
|||||||
cc.session = sx.Session
|
cc.session = sx.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup authentication
|
// if required, send request again with authentication
|
||||||
if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && cc.sender == nil {
|
if res.StatusCode == base.StatusUnauthorized && req.URL.User != nil && cc.sender == nil {
|
||||||
pass, _ := req.URL.User.Password()
|
pass, _ := req.URL.User.Password()
|
||||||
user := req.URL.User.Username()
|
user := req.URL.User.Username()
|
||||||
@@ -908,7 +908,6 @@ func (cc *ClientConn) do(req *base.Request, skipResponse bool) (*base.Response,
|
|||||||
}
|
}
|
||||||
cc.sender = sender
|
cc.sender = sender
|
||||||
|
|
||||||
// send request again
|
|
||||||
return cc.do(req, false)
|
return cc.do(req, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1036,7 +1035,7 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
baseURL, err := func() (*base.URL, error) {
|
baseURL, err := func() (*base.URL, error) {
|
||||||
// prefer Content-Base (optional)
|
// use Content-Base
|
||||||
if cb, ok := res.Header["Content-Base"]; ok {
|
if cb, ok := res.Header["Content-Base"]; ok {
|
||||||
if len(cb) != 1 {
|
if len(cb) != 1 {
|
||||||
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
|
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
|
||||||
@@ -1047,10 +1046,13 @@ func (cc *ClientConn) doDescribe(u *base.URL) (Tracks, *base.Response, error) {
|
|||||||
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
|
return nil, fmt.Errorf("invalid Content-Base: '%v'", cb)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add credentials from URL of request
|
||||||
|
ret.User = u.User
|
||||||
|
|
||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if not provided, use DESCRIBE URL
|
// if not provided, use URL of request
|
||||||
return u, nil
|
return u, nil
|
||||||
}()
|
}()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -10,7 +10,7 @@ import (
|
|||||||
"github.com/aler9/gortsplib/pkg/headers"
|
"github.com/aler9/gortsplib/pkg/headers"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Validator allows to validate some credentials generated by a Sender.
|
// Validator allows to validate credentials generated by a Sender.
|
||||||
type Validator struct {
|
type Validator struct {
|
||||||
user string
|
user string
|
||||||
userHashed bool
|
userHashed bool
|
||||||
|
Reference in New Issue
Block a user