diff --git a/server_test.go b/server_test.go index 169517c3..6dc95ca9 100644 --- a/server_test.go +++ b/server_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" + "github.com/aler9/gortsplib/pkg/auth" "github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/headers" ) @@ -1251,3 +1252,62 @@ func TestServerErrorInvalidPath(t *testing.T) { }) } } + +func TestServerAuth(t *testing.T) { + authValidator := auth.NewValidator("myuser", "mypass", nil) + + s := &Server{ + Handler: &testServerHandler{ + onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { + err := authValidator.ValidateRequest(ctx.Req) + if err != nil { + return &base.Response{ + StatusCode: base.StatusUnauthorized, + Header: base.Header{ + "WWW-Authenticate": authValidator.Header(), + }, + }, nil + } + + return &base.Response{ + StatusCode: base.StatusOK, + }, nil + }, + }, + RTSPAddress: "localhost:8554", + } + + err := s.Start() + require.NoError(t, err) + defer s.Close() + + conn, err := net.Dial("tcp", "localhost:8554") + require.NoError(t, err) + defer conn.Close() + br := bufio.NewReader(conn) + + track, err := NewTrackH264(96, []byte{0x01, 0x02, 0x03, 0x04}, []byte{0x01, 0x02, 0x03, 0x04}, nil) + require.NoError(t, err) + + req := base.Request{ + Method: base.Announce, + URL: mustParseURL("rtsp://localhost:8554/teststream"), + Header: base.Header{ + "CSeq": base.HeaderValue{"1"}, + "Content-Type": base.HeaderValue{"application/sdp"}, + }, + Body: Tracks{track}.Write(false), + } + + res, err := writeReqReadRes(conn, br, req) + require.NoError(t, err) + require.Equal(t, base.StatusUnauthorized, res.StatusCode) + + sender, err := auth.NewSender(res.Header["WWW-Authenticate"], "myuser", "mypass") + require.NoError(t, err) + + sender.AddAuthorization(&req) + res, err = writeReqReadRes(conn, br, req) + require.NoError(t, err) + require.Equal(t, base.StatusOK, res.StatusCode) +} diff --git a/serverconn.go b/serverconn.go index 4c2b2f2a..2885460f 100644 --- a/serverconn.go +++ b/serverconn.go @@ -543,9 +543,11 @@ func (sc *ServerConn) handleRequestInSession( ) (*base.Response, error) { // handle directly in Session if sc.session != nil { - // the SETUP request after ANNOUNCE don't have the session ID - // since ANNOUNCE didn't provide it. - if req.Method != base.Setup || sxID != "" { + // session ID is optional in SETUP and ANNOUNCE requests, since + // client may not have received the session ID yet due to multiple reasons: + // * requests can be retries after code 301 + // * SETUP requests comes after ANNOUNCE response, that don't contain the session ID + if sxID != "" { // the connection can't communicate with two sessions at once. if sxID != sc.session.secretID { return &base.Response{ diff --git a/serversession.go b/serversession.go index e66aaf18..1fa62d8f 100644 --- a/serversession.go +++ b/serversession.go @@ -296,6 +296,7 @@ func (ss *ServerSession) run() { }.Write() } + // after a TEARDOWN, session must be unpaired with the connection. if req.req.Method != base.Teardown { returnedSession = ss }