add ServerConn.VerifyCredentials() (#555)

This commit is contained in:
Alessandro Ros
2025-02-18 17:39:04 +01:00
committed by GitHub
parent 3409f00c90
commit 55556f1ecf
11 changed files with 387 additions and 25 deletions

View File

@@ -30,6 +30,7 @@ Features:
* Pause without disconnecting from the server * Pause without disconnecting from the server
* Server * Server
* Handle requests from clients * Handle requests from clients
* Validate client credentials
* Read media streams from clients ("record") * Read media streams from clients ("record")
* Read streams with the UDP or TCP transport protocol * Read streams with the UDP or TCP transport protocol
* Read TLS-encrypted streams (TCP only) * Read TLS-encrypted streams (TCP only)
@@ -94,6 +95,7 @@ Features:
* [client-record-format-vp9](examples/client-record-format-vp9/main.go) * [client-record-format-vp9](examples/client-record-format-vp9/main.go)
* [server](examples/server/main.go) * [server](examples/server/main.go)
* [server-tls](examples/server-tls/main.go) * [server-tls](examples/server-tls/main.go)
* [server-auth](examples/server-auth/main.go)
* [server-h264-save-to-disk](examples/server-h264-save-to-disk/main.go) * [server-h264-save-to-disk](examples/server-h264-save-to-disk/main.go)
* [proxy](examples/proxy/main.go) * [proxy](examples/proxy/main.go)

View File

@@ -34,6 +34,10 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/sdp" "github.com/bluenviron/gortsplib/v4/pkg/sdp"
) )
const (
clientUserAgent = "gortsplib"
)
// avoid an int64 overflow and preserve resolution by splitting division into two parts: // avoid an int64 overflow and preserve resolution by splitting division into two parts:
// first add the integer part, then the decimal part. // first add the integer part, then the decimal part.
func multiplyAndDivide(v, m, d time.Duration) time.Duration { func multiplyAndDivide(v, m, d time.Duration) time.Duration {
@@ -386,7 +390,7 @@ func (c *Client) Start(scheme string, host string) error {
return fmt.Errorf("MaxPacketSize must be less than %d", udpMaxPayloadSize) return fmt.Errorf("MaxPacketSize must be less than %d", udpMaxPayloadSize)
} }
if c.UserAgent == "" { if c.UserAgent == "" {
c.UserAgent = "gortsplib" c.UserAgent = clientUserAgent
} }
// system functions // system functions

View File

@@ -0,0 +1,195 @@
package main
import (
"log"
"sync"
"github.com/pion/rtp"
"github.com/bluenviron/gortsplib/v4"
"github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/format"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
)
// This example shows how to
// 1. create a RTSP server which accepts plain connections
// 2. allow a single client to publish a stream with TCP or UDP, if it provides credentials
// 3. allow multiple clients to read that stream with TCP, UDP or UDP-multicast, if they provide credentials
const (
// credentials required to publish the stream
publishUser = "publishuser"
publishPass = "publishpass"
// credentials required to read the stream
readUser = "readuser"
readPass = "readpass"
)
type serverHandler struct {
s *gortsplib.Server
mutex sync.Mutex
stream *gortsplib.ServerStream
publisher *gortsplib.ServerSession
}
// called when a connection is opened.
func (sh *serverHandler) OnConnOpen(ctx *gortsplib.ServerHandlerOnConnOpenCtx) {
log.Printf("conn opened")
}
// called when a connection is closed.
func (sh *serverHandler) OnConnClose(ctx *gortsplib.ServerHandlerOnConnCloseCtx) {
log.Printf("conn closed (%v)", ctx.Error)
}
// called when a session is opened.
func (sh *serverHandler) OnSessionOpen(ctx *gortsplib.ServerHandlerOnSessionOpenCtx) {
log.Printf("session opened")
}
// called when a session is closed.
func (sh *serverHandler) OnSessionClose(ctx *gortsplib.ServerHandlerOnSessionCloseCtx) {
log.Printf("session closed")
sh.mutex.Lock()
defer sh.mutex.Unlock()
// if the session is the publisher,
// close the stream and disconnect any reader.
if sh.stream != nil && ctx.Session == sh.publisher {
sh.stream.Close()
sh.stream = nil
}
}
// called when receiving a DESCRIBE request.
func (sh *serverHandler) OnDescribe(ctx *gortsplib.ServerHandlerOnDescribeCtx) (*base.Response, *gortsplib.ServerStream, error) {
log.Printf("describe request")
// Verify reader credentials.
// In case of readers, credentials have to be verified during DESCRIBE and SETUP.
ok := ctx.Conn.VerifyCredentials(ctx.Request, readUser, readPass)
if !ok {
return &base.Response{
StatusCode: base.StatusUnauthorized,
}, nil, liberrors.ErrServerAuth{}
}
sh.mutex.Lock()
defer sh.mutex.Unlock()
// no one is publishing yet
if sh.stream == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil, nil
}
// send medias that are being published to the client
return &base.Response{
StatusCode: base.StatusOK,
}, sh.stream, nil
}
// called when receiving an ANNOUNCE request.
func (sh *serverHandler) OnAnnounce(ctx *gortsplib.ServerHandlerOnAnnounceCtx) (*base.Response, error) {
log.Printf("announce request")
// Verify publisher credentials.
// In case of publishers, credentials have to be verified during ANNOUNCE.
ok := ctx.Conn.VerifyCredentials(ctx.Request, publishUser, publishPass)
if !ok {
return &base.Response{
StatusCode: base.StatusUnauthorized,
}, liberrors.ErrServerAuth{}
}
sh.mutex.Lock()
defer sh.mutex.Unlock()
// disconnect existing publisher
if sh.stream != nil {
sh.stream.Close()
sh.publisher.Close()
}
// create the stream and save the publisher
sh.stream = gortsplib.NewServerStream(sh.s, ctx.Description)
sh.publisher = ctx.Session
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
// called when receiving a SETUP request.
func (sh *serverHandler) OnSetup(ctx *gortsplib.ServerHandlerOnSetupCtx) (*base.Response, *gortsplib.ServerStream, error) {
log.Printf("setup request")
// Verify reader credentials.
// In case of readers, credentials have to be verified during DESCRIBE and SETUP.
if ctx.Session.State() == gortsplib.ServerSessionStateInitial {
ok := ctx.Conn.VerifyCredentials(ctx.Request, readUser, readPass)
if !ok {
return &base.Response{
StatusCode: base.StatusUnauthorized,
}, nil, liberrors.ErrServerAuth{}
}
}
// no one is publishing yet
if sh.stream == nil {
return &base.Response{
StatusCode: base.StatusNotFound,
}, nil, nil
}
return &base.Response{
StatusCode: base.StatusOK,
}, sh.stream, nil
}
// called when receiving a PLAY request.
func (sh *serverHandler) OnPlay(ctx *gortsplib.ServerHandlerOnPlayCtx) (*base.Response, error) {
log.Printf("play request")
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
// called when receiving a RECORD request.
func (sh *serverHandler) OnRecord(ctx *gortsplib.ServerHandlerOnRecordCtx) (*base.Response, error) {
log.Printf("record request")
// called when receiving a RTP packet
ctx.Session.OnPacketRTPAny(func(medi *description.Media, forma format.Format, pkt *rtp.Packet) {
// route the RTP packet to all readers
sh.stream.WritePacketRTP(medi, pkt)
})
return &base.Response{
StatusCode: base.StatusOK,
}, nil
}
func main() {
// configure the server
h := &serverHandler{}
h.s = &gortsplib.Server{
Handler: h,
RTSPAddress: ":8554",
UDPRTPAddress: ":8000",
UDPRTCPAddress: ":8001",
MulticastIPRange: "224.1.0.0/16",
MulticastRTPPort: 8002,
MulticastRTCPPort: 8003,
}
// start server and wait until a fatal error
log.Printf("server is ready")
panic(h.s.StartAndWait())
}

View File

@@ -53,11 +53,11 @@ func (se *Sender) AddAuthorization(req *base.Request) {
Method: se.authHeader.Method, Method: se.authHeader.Method,
} }
h.Username = se.user
if se.authHeader.Method == headers.AuthMethodBasic { if se.authHeader.Method == headers.AuthMethodBasic {
h.BasicUser = se.user
h.BasicPass = se.pass h.BasicPass = se.pass
} else { // digest } else { // digest
h.Username = se.user
h.Realm = se.authHeader.Realm h.Realm = se.authHeader.Realm
h.Nonce = se.authHeader.Nonce h.Nonce = se.authHeader.Nonce
h.URI = urStr h.URI = urStr

View File

@@ -61,7 +61,7 @@ const (
VerifyMethodDigestSHA256 VerifyMethodDigestSHA256
) )
// Verify validates a request sent by a client. // Verify verifies a request sent by a client.
func Verify( func Verify(
req *base.Request, req *base.Request,
user string, user string,
@@ -119,7 +119,7 @@ func Verify(
} }
case auth.Method == headers.AuthMethodBasic && contains(methods, VerifyMethodBasic): case auth.Method == headers.AuthMethodBasic && contains(methods, VerifyMethodBasic):
if auth.BasicUser != user { if auth.Username != user {
return fmt.Errorf("authentication failed") return fmt.Errorf("authentication failed")
} }

View File

@@ -13,11 +13,16 @@ type Authorization struct {
// authentication method // authentication method
Method AuthMethod Method AuthMethod
// username
Username string
// //
// Basic authentication fields // Basic authentication fields
// //
// user // user
//
// Deprecated: replaced by Username.
BasicUser string BasicUser string
// password // password
@@ -27,9 +32,6 @@ type Authorization struct {
// Digest authentication fields // Digest authentication fields
// //
// username
Username string
// realm // realm
Realm string Realm string
@@ -89,7 +91,8 @@ func (h *Authorization) Unmarshal(v base.HeaderValue) error {
return fmt.Errorf("invalid value") return fmt.Errorf("invalid value")
} }
h.BasicUser, h.BasicPass = tmp2[0], tmp2[1] h.Username, h.BasicPass = tmp2[0], tmp2[1]
h.BasicUser = h.Username
} else { // digest } else { // digest
kvs, err := keyValParse(v0, ',') kvs, err := keyValParse(v0, ',')
if err != nil { if err != nil {
@@ -149,8 +152,11 @@ func (h *Authorization) Unmarshal(v base.HeaderValue) error {
// Marshal encodes an Authorization header. // Marshal encodes an Authorization header.
func (h Authorization) Marshal() base.HeaderValue { func (h Authorization) Marshal() base.HeaderValue {
if h.Method == AuthMethodBasic { if h.Method == AuthMethodBasic {
if h.BasicUser != "" {
h.Username = h.BasicUser
}
return base.HeaderValue{"Basic " + return base.HeaderValue{"Basic " +
base64.StdEncoding.EncodeToString([]byte(h.BasicUser+":"+h.BasicPass))} base64.StdEncoding.EncodeToString([]byte(h.Username+":"+h.BasicPass))}
} }
ret := "Digest " + ret := "Digest " +

View File

@@ -24,6 +24,7 @@ var casesAuthorization = []struct {
base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="}, base.HeaderValue{"Basic bXl1c2VyOm15cGFzcw=="},
Authorization{ Authorization{
Method: AuthMethodBasic, Method: AuthMethodBasic,
Username: "myuser",
BasicUser: "myuser", BasicUser: "myuser",
BasicPass: "mypass", BasicPass: "mypass",
}, },

View File

@@ -270,3 +270,13 @@ func (ErrServerInvalidSetupPath) Error() string {
"This typically happens when VLC fails a request, and then switches to an " + "This typically happens when VLC fails a request, and then switches to an " +
"unsupported RTSP dialect" "unsupported RTSP dialect"
} }
// ErrServerAuth is an error that can be returned by a server.
// If a client did not provide credentials, it will be asked for
// credentials instead of being kicked out.
type ErrServerAuth struct{}
// Error implements the error interface.
func (e ErrServerAuth) Error() string {
return "authentication error"
}

View File

@@ -9,10 +9,16 @@ import (
"sync" "sync"
"time" "time"
"github.com/bluenviron/gortsplib/v4/pkg/auth"
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
const (
serverHeader = "gortsplib"
serverAuthRealm = "ipcam"
)
func extractPort(address string) (int, error) { func extractPort(address string) (int, error) {
_, tmp, err := net.SplitHostPort(address) _, tmp, err := net.SplitHostPort(address)
if err != nil { if err != nil {
@@ -88,6 +94,9 @@ type Server struct {
MaxPacketSize int MaxPacketSize int
// disable automatic RTCP sender reports. // disable automatic RTCP sender reports.
DisableRTCPSenderReports bool DisableRTCPSenderReports bool
// authentication methods.
// It defaults to plain and digest+MD5.
AuthMethods []auth.VerifyMethod
// //
// handler (optional) // handler (optional)
@@ -156,6 +165,11 @@ func (s *Server) Start() error {
} else if s.MaxPacketSize > udpMaxPayloadSize { } else if s.MaxPacketSize > udpMaxPayloadSize {
return fmt.Errorf("MaxPacketSize must be less than %d", udpMaxPayloadSize) return fmt.Errorf("MaxPacketSize must be less than %d", udpMaxPayloadSize)
} }
if len(s.AuthMethods) == 0 {
// disable VerifyMethodDigestSHA256 unless explicitly set
// since it prevents FFmpeg from authenticating
s.AuthMethods = []auth.VerifyMethod{auth.VerifyMethodBasic, auth.VerifyMethodDigestMD5}
}
// system functions // system functions
if s.Listen == nil { if s.Listen == nil {

View File

@@ -10,10 +10,12 @@ import (
"strings" "strings"
"time" "time"
"github.com/bluenviron/gortsplib/v4/pkg/auth"
"github.com/bluenviron/gortsplib/v4/pkg/base" "github.com/bluenviron/gortsplib/v4/pkg/base"
"github.com/bluenviron/gortsplib/v4/pkg/bytecounter" "github.com/bluenviron/gortsplib/v4/pkg/bytecounter"
"github.com/bluenviron/gortsplib/v4/pkg/conn" "github.com/bluenviron/gortsplib/v4/pkg/conn"
"github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
@@ -46,6 +48,12 @@ func serverSideDescription(d *description.Session) *description.Session {
return out return out
} }
func credentialsProvided(req *base.Request) bool {
var auth headers.Authorization
err := auth.Unmarshal(req.Header["Authorization"])
return err == nil && auth.Username != ""
}
type readReq struct { type readReq struct {
req *base.Request req *base.Request
res chan error res chan error
@@ -64,6 +72,7 @@ type ServerConn struct {
conn *conn.Conn conn *conn.Conn
session *ServerSession session *ServerSession
reader *serverConnReader reader *serverConnReader
authNonce string
// in // in
chRemoveSession chan *ServerSession chRemoveSession chan *ServerSession
@@ -137,6 +146,48 @@ func (sc *ServerConn) Stats() *StatsConn {
} }
} }
// VerifyCredentials verifies credentials provided by the user.
func (sc *ServerConn) VerifyCredentials(
req *base.Request,
expectedUser string,
expectedPass string,
) bool {
// we do not support using an empty string as user
// since it interferes with credentialsProvided()
if expectedUser == "" {
return false
}
if sc.authNonce == "" {
n, err := auth.GenerateNonce()
if err != nil {
return false
}
sc.authNonce = n
}
err := auth.Verify(
req,
expectedUser,
expectedPass,
sc.s.AuthMethods,
serverAuthRealm,
sc.authNonce)
return (err == nil)
}
func (sc *ServerConn) handleAuthError(req *base.Request, res *base.Response) error {
// if credentials have not been provided, clear error and send the WWW-Authenticate header.
if !credentialsProvided(req) {
res.Header["WWW-Authenticate"] = auth.GenerateWWWAuthenticate(sc.s.AuthMethods, serverAuthRealm, sc.authNonce)
return nil
}
// if credentials have been provided (and are wrong), close the connection.
return liberrors.ErrServerAuth{}
}
func (sc *ServerConn) ip() net.IP { func (sc *ServerConn) ip() net.IP {
return sc.remoteAddr.IP return sc.remoteAddr.IP
} }
@@ -386,14 +437,20 @@ func (sc *ServerConn) handleRequestOuter(req *base.Request) error {
res.Header = make(base.Header) res.Header = make(base.Header)
} }
// handle auth errors
var eerr1 liberrors.ErrServerAuth
if errors.As(err, &eerr1) {
err = sc.handleAuthError(req, res)
}
// add cseq // add cseq
var eerr liberrors.ErrServerCSeqMissing var eerr2 liberrors.ErrServerCSeqMissing
if !errors.As(err, &eerr) { if !errors.As(err, &eerr2) {
res.Header["CSeq"] = req.Header["CSeq"] res.Header["CSeq"] = req.Header["CSeq"]
} }
// add server // add server
res.Header["Server"] = base.HeaderValue{"gortsplib"} res.Header["Server"] = base.HeaderValue{serverHeader}
if h, ok := sc.s.Handler.(ServerHandlerOnResponse); ok { if h, ok := sc.s.Handler.(ServerHandlerOnResponse); ok {
h.OnResponse(sc, res) h.OnResponse(sc, res)

View File

@@ -3,6 +3,7 @@ package gortsplib
import ( import (
"fmt" "fmt"
"net" "net"
"net/http"
"testing" "testing"
"time" "time"
@@ -13,6 +14,7 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/conn" "github.com/bluenviron/gortsplib/v4/pkg/conn"
"github.com/bluenviron/gortsplib/v4/pkg/description" "github.com/bluenviron/gortsplib/v4/pkg/description"
"github.com/bluenviron/gortsplib/v4/pkg/headers" "github.com/bluenviron/gortsplib/v4/pkg/headers"
"github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
var serverCert = []byte(`-----BEGIN CERTIFICATE----- var serverCert = []byte(`-----BEGIN CERTIFICATE-----
@@ -1035,20 +1037,16 @@ func TestServerSessionTeardown(t *testing.T) {
} }
func TestServerAuth(t *testing.T) { func TestServerAuth(t *testing.T) {
nonce, err := auth.GenerateNonce() for _, method := range []string{"all", "basic", "digest_md5", "digest_sha256"} {
require.NoError(t, err) t.Run(method, func(t *testing.T) {
s := &Server{ s := &Server{
Handler: &testServerHandler{ Handler: &testServerHandler{
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) { onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
err2 := auth.Verify(ctx.Request, "myuser", "mypass", nil, "IPCAM", nonce) ok := ctx.Conn.VerifyCredentials(ctx.Request, "myuser", "mypass")
if err2 != nil { if !ok {
return &base.Response{ //nolint:nilerr return &base.Response{
StatusCode: base.StatusUnauthorized, StatusCode: base.StatusUnauthorized,
Header: base.Header{ }, liberrors.ErrServerAuth{}
"WWW-Authenticate": auth.GenerateWWWAuthenticate(nil, "IPCAM", nonce),
},
}, nil
} }
return &base.Response{ return &base.Response{
@@ -1057,9 +1055,22 @@ func TestServerAuth(t *testing.T) {
}, },
}, },
RTSPAddress: "localhost:8554", RTSPAddress: "localhost:8554",
AuthMethods: func() []auth.VerifyMethod {
switch method {
case "basic":
return []auth.VerifyMethod{auth.VerifyMethodBasic}
case "digest_md5":
return []auth.VerifyMethod{auth.VerifyMethodDigestMD5}
case "digest_sha256":
return []auth.VerifyMethod{auth.VerifyMethodDigestSHA256}
}
return nil
}(),
} }
err = s.Start() err := s.Start()
require.NoError(t, err) require.NoError(t, err)
defer s.Close() defer s.Close()
@@ -1091,4 +1102,66 @@ func TestServerAuth(t *testing.T) {
res, err = writeReqReadRes(conn, req) res, err = writeReqReadRes(conn, req)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, base.StatusOK, res.StatusCode) require.Equal(t, base.StatusOK, res.StatusCode)
})
}
}
func TestServerAuthFail(t *testing.T) {
s := &Server{
Handler: &testServerHandler{
onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) {
require.EqualError(t, ctx.Error, "authentication error")
},
onAnnounce: func(ctx *ServerHandlerOnAnnounceCtx) (*base.Response, error) {
ok := ctx.Conn.VerifyCredentials(ctx.Request, "myuser2", "mypass2")
if !ok {
return &base.Response{
StatusCode: http.StatusUnauthorized,
}, liberrors.ErrServerAuth{}
}
return &base.Response{
StatusCode: base.StatusOK,
}, nil
},
},
RTSPAddress: "localhost:8554",
}
err := s.Start()
require.NoError(t, err)
defer s.Close()
nconn, err := net.Dial("tcp", "localhost:8554")
require.NoError(t, err)
defer nconn.Close()
conn := conn.NewConn(nconn)
medias := []*description.Media{testH264Media}
req := base.Request{
Method: base.Announce,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"Content-Type": base.HeaderValue{"application/sdp"},
},
Body: mediasToSDP(medias),
}
res, err := writeReqReadRes(conn, req)
require.NoError(t, err)
require.Equal(t, base.StatusUnauthorized, res.StatusCode)
sender, err := auth.NewSender(res.Header["WWW-Authenticate"], "myuser", "mypass")
require.NoError(t, err)
sender.AddAuthorization(&req)
res, err = writeReqReadRes(conn, req)
require.NoError(t, err)
require.Equal(t, base.StatusUnauthorized, res.StatusCode)
_, err = writeReqReadRes(conn, req)
require.Error(t, err)
} }