server: do not allow a connection to communicate with multiple sessions

This commit is contained in:
aler9
2022-02-19 18:52:05 +01:00
parent eae1e120f1
commit 7dd4842fc0
3 changed files with 66 additions and 64 deletions

View File

@@ -992,7 +992,8 @@ func TestServerErrorInvalidSession(t *testing.T) {
Method: method,
URL: mustParseURL("rtsp://localhost:8554/teststream"),
Header: base.Header{
"CSeq": base.HeaderValue{"1"},
"CSeq": base.HeaderValue{"1"},
"Session": base.HeaderValue{"ABC"},
},
})
require.NoError(t, err)

View File

@@ -39,9 +39,8 @@ type ServerConn struct {
ctxCancel func()
remoteAddr *net.TCPAddr
br *bufio.Reader
sessions map[string]*ServerSession
session *ServerSession
readFunc func(readRequest chan readReq) error
tcpSession *ServerSession
// in
sessionRemove chan *ServerSession
@@ -110,7 +109,6 @@ func (sc *ServerConn) run() {
}
sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize)
sc.sessions = make(map[string]*ServerSession)
readRequest := make(chan readReq)
readErr := make(chan error)
@@ -127,7 +125,9 @@ func (sc *ServerConn) run() {
return err
case ss := <-sc.sessionRemove:
delete(sc.sessions, ss.secretID)
if sc.session == ss {
sc.session = nil
}
case <-sc.ctx.Done():
return liberrors.ErrServerTerminated{}
@@ -140,10 +140,10 @@ func (sc *ServerConn) run() {
sc.conn.Close()
<-readDone
for _, ss := range sc.sessions {
if sc.session != nil {
select {
case ss.connRemove <- sc:
case <-ss.ctx.Done():
case sc.session.connRemove <- sc:
case <-sc.session.ctx.Done():
}
}
@@ -211,14 +211,14 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
sc.conn.SetReadDeadline(time.Time{})
select {
case sc.tcpSession.startWriter <- struct{}{}:
case <-sc.tcpSession.ctx.Done():
case sc.session.startWriter <- struct{}{}:
case <-sc.session.ctx.Done():
}
var tcpReadBuffer *multibuffer.MultiBuffer
var processFunc func(int, bool, []byte)
if sc.tcpSession.state == ServerSessionStatePlay {
if sc.session.state == ServerSessionStatePlay {
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
// that are much smaller than RTP packets and are sent at a fixed interval.
// decrease RAM consumption by allocating less buffers.
@@ -234,7 +234,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
Session: sc.session,
TrackID: trackID,
Packet: pkt,
})
@@ -256,7 +256,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
Session: sc.tcpSession,
Session: sc.session,
TrackID: trackID,
Packet: pkt,
})
@@ -270,7 +270,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
for _, pkt := range packets {
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
Session: sc.tcpSession,
Session: sc.session,
TrackID: trackID,
Packet: pkt,
})
@@ -284,7 +284,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
var frame base.InterleavedFrame
for {
if sc.tcpSession.state == ServerSessionStateRecord {
if sc.session.state == ServerSessionStateRecord {
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
}
@@ -304,7 +304,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
}
// forward frame only if it has been set up
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok {
processFunc(trackID, isRTP, frame.Payload)
}
@@ -334,18 +334,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
sxID := getSessionID(req.Header)
// the connection can't communicate with another session
// if it's receiving or sending TCP frames.
if sc.tcpSession != nil &&
sxID != sc.tcpSession.secretID {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerLinkedToOtherSession{}
}
switch req.Method {
case base.Options:
// handle request in session
if sxID != "" {
return sc.handleRequestInSession(sxID, req, false)
}
@@ -440,25 +430,32 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
}
case base.Play:
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
return sc.handleRequestInSession(sxID, req, false)
if sxID != "" {
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
return sc.handleRequestInSession(sxID, req, false)
}
}
case base.Record:
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
return sc.handleRequestInSession(sxID, req, false)
if sxID != "" {
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
return sc.handleRequestInSession(sxID, req, false)
}
}
case base.Pause:
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
return sc.handleRequestInSession(sxID, req, false)
if sxID != "" {
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
return sc.handleRequestInSession(sxID, req, false)
}
}
case base.Teardown:
return sc.handleRequestInSession(sxID, req, false)
if sxID != "" {
return sc.handleRequestInSession(sxID, req, false)
}
case base.GetParameter:
// handle request in session
if sxID != "" {
return sc.handleRequestInSession(sxID, req, false)
}
@@ -544,28 +541,34 @@ func (sc *ServerConn) handleRequestInSession(
req *base.Request,
create bool,
) (*base.Response, error) {
// if the session is already linked to this conn, communicate directly with it
if sxID != "" {
if ss, ok := sc.sessions[sxID]; ok {
cres := make(chan sessionRequestRes)
sreq := sessionRequestReq{
sc: sc,
req: req,
id: sxID,
create: create,
res: cres,
}
// handle directly in Session
if sc.session != nil {
// the connection can't communicate with two sessions at once.
if sxID != sc.session.secretID {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerLinkedToOtherSession{}
}
select {
case ss.request <- sreq:
res := <-cres
return res.res, res.err
cres := make(chan sessionRequestRes)
sreq := sessionRequestReq{
sc: sc,
req: req,
id: sxID,
create: create,
res: cres,
}
case <-ss.ctx.Done():
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTerminated{}
}
select {
case sc.session.request <- sreq:
res := <-cres
sc.session = res.ss
return res.res, res.err
case <-sc.session.ctx.Done():
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerTerminated{}
}
}
@@ -582,10 +585,7 @@ func (sc *ServerConn) handleRequestInSession(
select {
case sc.s.sessionRequest <- sreq:
res := <-cres
if res.ss != nil {
sc.sessions[res.ss.secretID] = res.ss
}
sc.session = res.ss
return res.res, res.err
case <-sc.s.ctx.Done():

View File

@@ -279,6 +279,7 @@ func (ss *ServerSession) run() {
res, err := ss.handleRequest(req.sc, req.req)
var returnedSession *ServerSession
if err == nil || err == errSwitchReadFunc {
if res.Header == nil {
res.Header = make(base.Header)
@@ -291,6 +292,10 @@ func (ss *ServerSession) run() {
return &v
}(),
}.Write()
if req.req.Method != base.Teardown {
returnedSession = ss
}
}
savedMethod := req.req.Method
@@ -298,11 +303,11 @@ func (ss *ServerSession) run() {
req.res <- sessionRequestRes{
res: res,
err: err,
ss: ss,
ss: returnedSession,
}
if (err == nil || err == errSwitchReadFunc) && savedMethod == base.Teardown {
return liberrors.ErrServerSessionTeardown{}
return liberrors.ErrServerSessionTeardown{Author: req.sc.NetConn().RemoteAddr()}
}
case sc := <-ss.connRemove:
@@ -871,7 +876,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
default: // TCP
ss.tcpConn = sc
ss.tcpConn.tcpSession = ss
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc
@@ -997,7 +1001,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
default: // TCP
ss.tcpConn = sc
ss.tcpConn.tcpSession = ss
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
err = errSwitchReadFunc
@@ -1067,7 +1070,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
err = errSwitchReadFunc
ss.tcpConn.tcpSession = nil
ss.tcpConn = nil
}
@@ -1087,7 +1089,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
err = errSwitchReadFunc
ss.tcpConn.tcpSession = nil
ss.tcpConn = nil
}
}