mirror of
https://github.com/aler9/gortsplib
synced 2025-09-30 21:12:18 +08:00
server: do not allow a connection to communicate with multiple sessions
This commit is contained in:
@@ -992,7 +992,8 @@ func TestServerErrorInvalidSession(t *testing.T) {
|
||||
Method: method,
|
||||
URL: mustParseURL("rtsp://localhost:8554/teststream"),
|
||||
Header: base.Header{
|
||||
"CSeq": base.HeaderValue{"1"},
|
||||
"CSeq": base.HeaderValue{"1"},
|
||||
"Session": base.HeaderValue{"ABC"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
114
serverconn.go
114
serverconn.go
@@ -39,9 +39,8 @@ type ServerConn struct {
|
||||
ctxCancel func()
|
||||
remoteAddr *net.TCPAddr
|
||||
br *bufio.Reader
|
||||
sessions map[string]*ServerSession
|
||||
session *ServerSession
|
||||
readFunc func(readRequest chan readReq) error
|
||||
tcpSession *ServerSession
|
||||
|
||||
// in
|
||||
sessionRemove chan *ServerSession
|
||||
@@ -110,7 +109,6 @@ func (sc *ServerConn) run() {
|
||||
}
|
||||
|
||||
sc.br = bufio.NewReaderSize(sc.conn, serverReadBufferSize)
|
||||
sc.sessions = make(map[string]*ServerSession)
|
||||
|
||||
readRequest := make(chan readReq)
|
||||
readErr := make(chan error)
|
||||
@@ -127,7 +125,9 @@ func (sc *ServerConn) run() {
|
||||
return err
|
||||
|
||||
case ss := <-sc.sessionRemove:
|
||||
delete(sc.sessions, ss.secretID)
|
||||
if sc.session == ss {
|
||||
sc.session = nil
|
||||
}
|
||||
|
||||
case <-sc.ctx.Done():
|
||||
return liberrors.ErrServerTerminated{}
|
||||
@@ -140,10 +140,10 @@ func (sc *ServerConn) run() {
|
||||
sc.conn.Close()
|
||||
<-readDone
|
||||
|
||||
for _, ss := range sc.sessions {
|
||||
if sc.session != nil {
|
||||
select {
|
||||
case ss.connRemove <- sc:
|
||||
case <-ss.ctx.Done():
|
||||
case sc.session.connRemove <- sc:
|
||||
case <-sc.session.ctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
@@ -211,14 +211,14 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
sc.conn.SetReadDeadline(time.Time{})
|
||||
|
||||
select {
|
||||
case sc.tcpSession.startWriter <- struct{}{}:
|
||||
case <-sc.tcpSession.ctx.Done():
|
||||
case sc.session.startWriter <- struct{}{}:
|
||||
case <-sc.session.ctx.Done():
|
||||
}
|
||||
|
||||
var tcpReadBuffer *multibuffer.MultiBuffer
|
||||
var processFunc func(int, bool, []byte)
|
||||
|
||||
if sc.tcpSession.state == ServerSessionStatePlay {
|
||||
if sc.session.state == ServerSessionStatePlay {
|
||||
// when playing, tcpReadBuffer is only used to receive RTCP receiver reports,
|
||||
// that are much smaller than RTP packets and are sent at a fixed interval.
|
||||
// decrease RAM consumption by allocating less buffers.
|
||||
@@ -234,7 +234,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
Session: sc.session,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
@@ -256,7 +256,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTP); ok {
|
||||
h.OnPacketRTP(&ServerHandlerOnPacketRTPCtx{
|
||||
Session: sc.tcpSession,
|
||||
Session: sc.session,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
@@ -270,7 +270,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
if h, ok := sc.s.Handler.(ServerHandlerOnPacketRTCP); ok {
|
||||
for _, pkt := range packets {
|
||||
h.OnPacketRTCP(&ServerHandlerOnPacketRTCPCtx{
|
||||
Session: sc.tcpSession,
|
||||
Session: sc.session,
|
||||
TrackID: trackID,
|
||||
Packet: pkt,
|
||||
})
|
||||
@@ -284,7 +284,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
var frame base.InterleavedFrame
|
||||
|
||||
for {
|
||||
if sc.tcpSession.state == ServerSessionStateRecord {
|
||||
if sc.session.state == ServerSessionStateRecord {
|
||||
sc.conn.SetReadDeadline(time.Now().Add(sc.s.ReadTimeout))
|
||||
}
|
||||
|
||||
@@ -304,7 +304,7 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
||||
}
|
||||
|
||||
// forward frame only if it has been set up
|
||||
if trackID, ok := sc.tcpSession.tcpTracksByChannel[channel]; ok {
|
||||
if trackID, ok := sc.session.tcpTracksByChannel[channel]; ok {
|
||||
processFunc(trackID, isRTP, frame.Payload)
|
||||
}
|
||||
|
||||
@@ -334,18 +334,8 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
|
||||
sxID := getSessionID(req.Header)
|
||||
|
||||
// the connection can't communicate with another session
|
||||
// if it's receiving or sending TCP frames.
|
||||
if sc.tcpSession != nil &&
|
||||
sxID != sc.tcpSession.secretID {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerLinkedToOtherSession{}
|
||||
}
|
||||
|
||||
switch req.Method {
|
||||
case base.Options:
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
@@ -440,25 +430,32 @@ func (sc *ServerConn) handleRequest(req *base.Request) (*base.Response, error) {
|
||||
}
|
||||
|
||||
case base.Play:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
if sxID != "" {
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPlay); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
}
|
||||
|
||||
case base.Record:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
if sxID != "" {
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnRecord); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
}
|
||||
|
||||
case base.Pause:
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
if sxID != "" {
|
||||
if _, ok := sc.s.Handler.(ServerHandlerOnPause); ok {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
}
|
||||
|
||||
case base.Teardown:
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
if sxID != "" {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
|
||||
case base.GetParameter:
|
||||
// handle request in session
|
||||
if sxID != "" {
|
||||
return sc.handleRequestInSession(sxID, req, false)
|
||||
}
|
||||
@@ -544,28 +541,34 @@ func (sc *ServerConn) handleRequestInSession(
|
||||
req *base.Request,
|
||||
create bool,
|
||||
) (*base.Response, error) {
|
||||
// if the session is already linked to this conn, communicate directly with it
|
||||
if sxID != "" {
|
||||
if ss, ok := sc.sessions[sxID]; ok {
|
||||
cres := make(chan sessionRequestRes)
|
||||
sreq := sessionRequestReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: create,
|
||||
res: cres,
|
||||
}
|
||||
// handle directly in Session
|
||||
if sc.session != nil {
|
||||
// the connection can't communicate with two sessions at once.
|
||||
if sxID != sc.session.secretID {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerLinkedToOtherSession{}
|
||||
}
|
||||
|
||||
select {
|
||||
case ss.request <- sreq:
|
||||
res := <-cres
|
||||
return res.res, res.err
|
||||
cres := make(chan sessionRequestRes)
|
||||
sreq := sessionRequestReq{
|
||||
sc: sc,
|
||||
req: req,
|
||||
id: sxID,
|
||||
create: create,
|
||||
res: cres,
|
||||
}
|
||||
|
||||
case <-ss.ctx.Done():
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerTerminated{}
|
||||
}
|
||||
select {
|
||||
case sc.session.request <- sreq:
|
||||
res := <-cres
|
||||
sc.session = res.ss
|
||||
return res.res, res.err
|
||||
|
||||
case <-sc.session.ctx.Done():
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerTerminated{}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -582,10 +585,7 @@ func (sc *ServerConn) handleRequestInSession(
|
||||
select {
|
||||
case sc.s.sessionRequest <- sreq:
|
||||
res := <-cres
|
||||
if res.ss != nil {
|
||||
sc.sessions[res.ss.secretID] = res.ss
|
||||
}
|
||||
|
||||
sc.session = res.ss
|
||||
return res.res, res.err
|
||||
|
||||
case <-sc.s.ctx.Done():
|
||||
|
@@ -279,6 +279,7 @@ func (ss *ServerSession) run() {
|
||||
|
||||
res, err := ss.handleRequest(req.sc, req.req)
|
||||
|
||||
var returnedSession *ServerSession
|
||||
if err == nil || err == errSwitchReadFunc {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
@@ -291,6 +292,10 @@ func (ss *ServerSession) run() {
|
||||
return &v
|
||||
}(),
|
||||
}.Write()
|
||||
|
||||
if req.req.Method != base.Teardown {
|
||||
returnedSession = ss
|
||||
}
|
||||
}
|
||||
|
||||
savedMethod := req.req.Method
|
||||
@@ -298,11 +303,11 @@ func (ss *ServerSession) run() {
|
||||
req.res <- sessionRequestRes{
|
||||
res: res,
|
||||
err: err,
|
||||
ss: ss,
|
||||
ss: returnedSession,
|
||||
}
|
||||
|
||||
if (err == nil || err == errSwitchReadFunc) && savedMethod == base.Teardown {
|
||||
return liberrors.ErrServerSessionTeardown{}
|
||||
return liberrors.ErrServerSessionTeardown{Author: req.sc.NetConn().RemoteAddr()}
|
||||
}
|
||||
|
||||
case sc := <-ss.connRemove:
|
||||
@@ -871,7 +876,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
ss.tcpConn.tcpSession = ss
|
||||
|
||||
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
|
||||
err = errSwitchReadFunc
|
||||
@@ -997,7 +1001,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
ss.tcpConn.tcpSession = ss
|
||||
|
||||
ss.tcpConn.readFunc = ss.tcpConn.readFuncTCP
|
||||
err = errSwitchReadFunc
|
||||
@@ -1067,7 +1070,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
|
||||
err = errSwitchReadFunc
|
||||
|
||||
ss.tcpConn.tcpSession = nil
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
|
||||
@@ -1087,7 +1089,6 @@ func (ss *ServerSession) handleRequest(sc *ServerConn, req *base.Request) (*base
|
||||
ss.tcpConn.readFunc = ss.tcpConn.readFuncStandard
|
||||
err = errSwitchReadFunc
|
||||
|
||||
ss.tcpConn.tcpSession = nil
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user