Code refactoring for RTSP auth

This commit is contained in:
Alex X
2025-02-18 12:01:55 +03:00
parent 9e673559c4
commit 02ac3a6814
3 changed files with 9 additions and 9 deletions

View File

@@ -1,6 +1,7 @@
package rtsp package rtsp
import ( import (
"errors"
"io" "io"
"net" "net"
"net/url" "net/url"
@@ -237,7 +238,7 @@ func tcpHandler(conn *rtsp.Conn) {
}) })
if err := conn.Accept(); err != nil { if err := conn.Accept(); err != nil {
if err == rtsp.FailedAuth { if errors.Is(err, rtsp.FailedAuth) {
log.Warn().Str("remote_addr", conn.Connection.RemoteAddr).Msg("[rtsp] failed authentication") log.Warn().Str("remote_addr", conn.Connection.RemoteAddr).Msg("[rtsp] failed authentication")
} else if err != io.EOF { } else if err != io.EOF {
log.WithLevel(level).Err(err).Caller().Send() log.WithLevel(level).Err(err).Caller().Send()

View File

@@ -47,7 +47,7 @@ func (c *Conn) Accept() error {
c.Fire(req) c.Fire(req)
if !c.auth.Validate(req) { if valid, empty := c.auth.Validate(req); !valid {
res := &tcp.Response{ res := &tcp.Response{
Status: "401 Unauthorized", Status: "401 Unauthorized",
Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}}, Header: map[string][]string{"Www-Authenticate": {`Basic realm="go2rtc"`}},
@@ -56,13 +56,12 @@ func (c *Conn) Accept() error {
if err = c.WriteResponse(res); err != nil { if err = c.WriteResponse(res); err != nil {
return err return err
} }
if req.Header.Get("Authorization") != "" { if empty {
// eliminate false positive: ffmpeg sends first request without // eliminate false positive: ffmpeg sends first request without
// authorization header even if the user provides credentials // authorization header even if the user provides credentials
return FailedAuth
} else {
continue continue
} }
return FailedAuth
} }
// Receiver: OPTIONS > DESCRIBE > SETUP... > PLAY > TEARDOWN // Receiver: OPTIONS > DESCRIBE > SETUP... > PLAY > TEARDOWN

View File

@@ -85,14 +85,14 @@ func (a *Auth) Write(req *Request) {
} }
} }
func (a *Auth) Validate(req *Request) bool { func (a *Auth) Validate(req *Request) (valid, empty bool) {
if a == nil { if a == nil {
return true return true, true
} }
header := req.Header.Get("Authorization") header := req.Header.Get("Authorization")
if header == "" { if header == "" {
return false return false, true
} }
if a.Method == AuthUnknown { if a.Method == AuthUnknown {
@@ -100,7 +100,7 @@ func (a *Auth) Validate(req *Request) bool {
a.header = "Basic " + B64(a.user, a.pass) a.header = "Basic " + B64(a.user, a.pass)
} }
return header == a.header return header == a.header, false
} }
func (a *Auth) ReadNone(res *Response) bool { func (a *Auth) ReadNone(res *Response) bool {