server: prevent wrong OnSetup / OnDescribe usage (#732)

This commit is contained in:
Alessandro Ros
2025-03-23 16:42:29 +01:00
committed by GitHub
parent 304c38bb60
commit 8c6495c33b
3 changed files with 242 additions and 231 deletions

View File

@@ -872,15 +872,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
Description: &desc,
})
if res.StatusCode != base.StatusOK {
return res, err
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStatePreRecord
ss.setuppedPath = path
ss.setuppedQuery = query
ss.announcedDesc = &desc
}
ss.state = ServerSessionStatePreRecord
ss.setuppedPath = path
ss.setuppedQuery = query
ss.announcedDesc = &desc
return res, err
case base.Setup:
@@ -1017,121 +1015,132 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
}
}
if res.StatusCode != base.StatusOK {
return res, err
if ss.state == ServerSessionStatePreRecord && stream != nil {
panic("stream must be nil when handling publishers")
}
var medi *description.Media
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
medi = findMediaByTrackID(stream.Desc.Medias, trackID)
default: // record
medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL)
}
if res.StatusCode == base.StatusOK {
var medi *description.Media
if medi == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerMediaNotFound{}
}
switch ss.state {
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
if stream == nil {
panic("stream cannot be nil when StatusCode is StatusOK")
}
if _, ok := ss.setuppedMedias[medi]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerMediaAlreadySetup{}
}
medi = findMediaByTrackID(stream.Desc.Medias, trackID)
default: // record
medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL)
}
ss.setuppedTransport = &transport
if ss.state == ServerSessionStateInitial {
err = stream.readerAdd(ss,
inTH.ClientPorts,
)
if err != nil {
if medi == nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}, liberrors.ErrServerMediaNotFound{}
}
ss.state = ServerSessionStatePrePlay
ss.setuppedPath = path
ss.setuppedQuery = query
ss.setuppedStream = stream
}
th := headers.Transport{}
if ss.state == ServerSessionStatePrePlay {
ssrc, ok := stream.localSSRC(medi)
if ok {
th.SSRC = &ssrc
}
}
if res.Header == nil {
res.Header = make(base.Header)
}
sm := &serverSessionMedia{
ss: ss,
media: medi,
onPacketRTCP: func(_ rtcp.Packet) {},
}
sm.initialize()
switch transport {
case TransportUDP:
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
if _, ok := ss.setuppedMedias[medi]; ok {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, liberrors.ErrServerMediaAlreadySetup{}
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
ss.setuppedTransport = &transport
if ss.state == ServerSessionStateInitial {
err = stream.readerAdd(ss,
inTH.ClientPorts,
)
if err != nil {
return &base.Response{
StatusCode: base.StatusBadRequest,
}, err
}
ss.state = ServerSessionStatePrePlay
ss.setuppedPath = path
ss.setuppedQuery = query
ss.setuppedStream = stream
}
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
th := headers.Transport{}
case TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
if ss.state == ServerSessionStatePrePlay {
if stream != ss.setuppedStream {
panic("stream cannot be different than the one returned in previous OnSetup call")
}
default: // TCP
if inTH.InterleavedIDs != nil {
sm.tcpChannel = inTH.InterleavedIDs[0]
} else {
sm.tcpChannel = ss.findFreeChannelPair()
ssrc, ok := stream.localSSRC(medi)
if ok {
th.SSRC = &ssrc
}
}
th.Protocol = headers.TransportProtocolTCP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
}
if res.Header == nil {
res.Header = make(base.Header)
}
if ss.setuppedMedias == nil {
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
}
ss.setuppedMedias[medi] = sm
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
sm := &serverSessionMedia{
ss: ss,
media: medi,
onPacketRTCP: func(_ rtcp.Packet) {},
}
sm.initialize()
res.Header["Transport"] = th.Marshal()
switch transport {
case TransportUDP:
sm.udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTPReadPort,
}
sm.udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(),
Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort,
}
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.ClientPorts = inTH.ClientPorts
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
case TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP
de := headers.TransportDeliveryMulticast
th.Delivery = &de
v := uint(127)
th.TTL = &v
d := stream.medias[medi].multicastWriter.ip()
th.Destination = &d
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
default: // TCP
if inTH.InterleavedIDs != nil {
sm.tcpChannel = inTH.InterleavedIDs[0]
} else {
sm.tcpChannel = ss.findFreeChannelPair()
}
th.Protocol = headers.TransportProtocolTCP
de := headers.TransportDeliveryUnicast
th.Delivery = &de
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
}
if ss.setuppedMedias == nil {
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
}
ss.setuppedMedias[medi] = sm
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
res.Header["Transport"] = th.Marshal()
}
return res, err
@@ -1166,64 +1175,61 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
Query: query,
})
if res.StatusCode != base.StatusOK {
if res.StatusCode == base.StatusOK {
if ss.state != ServerSessionStatePlay {
ss.state = ServerSessionStatePlay
v := ss.s.timeNow().Unix()
ss.udpLastPacketTime = &v
ss.timeDecoder = &rtptime.GlobalDecoder2{}
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
}
if *ss.setuppedTransport == TransportTCP {
ss.tcpFrame = &base.InterleavedFrame{}
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
}
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.startWriter()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
ss.setuppedStream.readerSetActive(ss)
rtpInfo, ok := generateRTPInfo(
ss.s.timeNow(),
ss.setuppedMediasOrdered,
ss.setuppedStream,
ss.setuppedPath,
req.URL)
if ok {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["RTP-Info"] = rtpInfo.Marshal()
}
}
} else {
if ss.state != ServerSessionStatePlay &&
*ss.setuppedTransport != TransportUDPMulticast {
ss.destroyWriter()
}
return res, err
}
if ss.state == ServerSessionStatePlay {
return res, err
}
ss.state = ServerSessionStatePlay
v := ss.s.timeNow().Unix()
ss.udpLastPacketTime = &v
ss.timeDecoder = &rtptime.GlobalDecoder2{}
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
}
if *ss.setuppedTransport == TransportTCP {
ss.tcpFrame = &base.InterleavedFrame{}
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
}
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.startWriter()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
ss.setuppedStream.readerSetActive(ss)
rtpInfo, ok := generateRTPInfo(
ss.s.timeNow(),
ss.setuppedMediasOrdered,
ss.setuppedStream,
ss.setuppedPath,
req.URL)
if ok {
if res.Header == nil {
res.Header = make(base.Header)
}
res.Header["RTP-Info"] = rtpInfo.Marshal()
}
return res, err
@@ -1260,38 +1266,37 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
Query: query,
})
if res.StatusCode != base.StatusOK {
if res.StatusCode == base.StatusOK {
ss.state = ServerSessionStateRecord
v := ss.s.timeNow().Unix()
ss.udpLastPacketTime = &v
ss.timeDecoder = &rtptime.GlobalDecoder2{}
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
}
if *ss.setuppedTransport == TransportTCP {
ss.tcpFrame = &base.InterleavedFrame{}
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
}
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.startWriter()
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
} else {
ss.destroyWriter()
return res, err
}
ss.state = ServerSessionStateRecord
v := ss.s.timeNow().Unix()
ss.udpLastPacketTime = &v
ss.timeDecoder = &rtptime.GlobalDecoder2{}
ss.timeDecoder.Initialize()
for _, sm := range ss.setuppedMedias {
sm.start()
}
if *ss.setuppedTransport == TransportTCP {
ss.tcpFrame = &base.InterleavedFrame{}
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
}
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
ss.startWriter()
default: // TCP
ss.tcpConn = sc
err = switchReadFuncError{true}
// startWriter() is called by ServerConn, through chAsyncStartWriter,
// after the response has been sent
}
return res, err
@@ -1317,50 +1322,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
Query: query,
})
if res.StatusCode != base.StatusOK {
return res, err
}
if res.StatusCode == base.StatusOK {
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
ss.destroyWriter()
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
ss.destroyWriter()
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
for _, sm := range ss.setuppedMedias {
sm.stop()
}
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
if ss.setuppedStream != nil {
ss.setuppedStream.readerSetInactive(ss)
}
case ServerSessionStateRecord:
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
for _, sm := range ss.setuppedMedias {
sm.stop()
}
ss.state = ServerSessionStatePreRecord
ss.timeDecoder = nil
switch ss.state {
case ServerSessionStatePlay:
ss.state = ServerSessionStatePrePlay
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
case TransportUDPMulticast:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
case ServerSessionStateRecord:
switch *ss.setuppedTransport {
case TransportUDP:
ss.udpCheckStreamTimer = emptyTimer()
default: // TCP
err = switchReadFuncError{false}
ss.tcpConn = nil
}
ss.state = ServerSessionStatePreRecord
}
}
}