mirror of
https://github.com/aler9/gortsplib
synced 2025-12-24 13:38:08 +08:00
server: prevent wrong OnSetup / OnDescribe usage (#732)
This commit is contained in:
@@ -872,15 +872,13 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
Description: &desc,
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
return res, err
|
||||
if res.StatusCode == base.StatusOK {
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
ss.setuppedPath = path
|
||||
ss.setuppedQuery = query
|
||||
ss.announcedDesc = &desc
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
ss.setuppedPath = path
|
||||
ss.setuppedQuery = query
|
||||
ss.announcedDesc = &desc
|
||||
|
||||
return res, err
|
||||
|
||||
case base.Setup:
|
||||
@@ -1017,121 +1015,132 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
}
|
||||
}
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
return res, err
|
||||
if ss.state == ServerSessionStatePreRecord && stream != nil {
|
||||
panic("stream must be nil when handling publishers")
|
||||
}
|
||||
|
||||
var medi *description.Media
|
||||
switch ss.state {
|
||||
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
|
||||
medi = findMediaByTrackID(stream.Desc.Medias, trackID)
|
||||
default: // record
|
||||
medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL)
|
||||
}
|
||||
if res.StatusCode == base.StatusOK {
|
||||
var medi *description.Media
|
||||
|
||||
if medi == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerMediaNotFound{}
|
||||
}
|
||||
switch ss.state {
|
||||
case ServerSessionStateInitial, ServerSessionStatePrePlay: // play
|
||||
if stream == nil {
|
||||
panic("stream cannot be nil when StatusCode is StatusOK")
|
||||
}
|
||||
|
||||
if _, ok := ss.setuppedMedias[medi]; ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerMediaAlreadySetup{}
|
||||
}
|
||||
medi = findMediaByTrackID(stream.Desc.Medias, trackID)
|
||||
default: // record
|
||||
medi = findMediaByURL(ss.announcedDesc.Medias, path, query, req.URL)
|
||||
}
|
||||
|
||||
ss.setuppedTransport = &transport
|
||||
|
||||
if ss.state == ServerSessionStateInitial {
|
||||
err = stream.readerAdd(ss,
|
||||
inTH.ClientPorts,
|
||||
)
|
||||
if err != nil {
|
||||
if medi == nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}, liberrors.ErrServerMediaNotFound{}
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
ss.setuppedPath = path
|
||||
ss.setuppedQuery = query
|
||||
ss.setuppedStream = stream
|
||||
}
|
||||
|
||||
th := headers.Transport{}
|
||||
|
||||
if ss.state == ServerSessionStatePrePlay {
|
||||
ssrc, ok := stream.localSSRC(medi)
|
||||
if ok {
|
||||
th.SSRC = &ssrc
|
||||
}
|
||||
}
|
||||
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
sm := &serverSessionMedia{
|
||||
ss: ss,
|
||||
media: medi,
|
||||
onPacketRTCP: func(_ rtcp.Packet) {},
|
||||
}
|
||||
sm.initialize()
|
||||
|
||||
switch transport {
|
||||
case TransportUDP:
|
||||
sm.udpRTPReadPort = inTH.ClientPorts[0]
|
||||
sm.udpRTCPReadPort = inTH.ClientPorts[1]
|
||||
|
||||
sm.udpRTPWriteAddr = &net.UDPAddr{
|
||||
IP: ss.author.ip(),
|
||||
Zone: ss.author.zone(),
|
||||
Port: sm.udpRTPReadPort,
|
||||
if _, ok := ss.setuppedMedias[medi]; ok {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrServerMediaAlreadySetup{}
|
||||
}
|
||||
|
||||
sm.udpRTCPWriteAddr = &net.UDPAddr{
|
||||
IP: ss.author.ip(),
|
||||
Zone: ss.author.zone(),
|
||||
Port: sm.udpRTCPReadPort,
|
||||
ss.setuppedTransport = &transport
|
||||
|
||||
if ss.state == ServerSessionStateInitial {
|
||||
err = stream.readerAdd(ss,
|
||||
inTH.ClientPorts,
|
||||
)
|
||||
if err != nil {
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, err
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
ss.setuppedPath = path
|
||||
ss.setuppedQuery = query
|
||||
ss.setuppedStream = stream
|
||||
}
|
||||
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
de := headers.TransportDeliveryUnicast
|
||||
th.Delivery = &de
|
||||
th.ClientPorts = inTH.ClientPorts
|
||||
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
|
||||
th := headers.Transport{}
|
||||
|
||||
case TransportUDPMulticast:
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
de := headers.TransportDeliveryMulticast
|
||||
th.Delivery = &de
|
||||
v := uint(127)
|
||||
th.TTL = &v
|
||||
d := stream.medias[medi].multicastWriter.ip()
|
||||
th.Destination = &d
|
||||
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
|
||||
if ss.state == ServerSessionStatePrePlay {
|
||||
if stream != ss.setuppedStream {
|
||||
panic("stream cannot be different than the one returned in previous OnSetup call")
|
||||
}
|
||||
|
||||
default: // TCP
|
||||
if inTH.InterleavedIDs != nil {
|
||||
sm.tcpChannel = inTH.InterleavedIDs[0]
|
||||
} else {
|
||||
sm.tcpChannel = ss.findFreeChannelPair()
|
||||
ssrc, ok := stream.localSSRC(medi)
|
||||
if ok {
|
||||
th.SSRC = &ssrc
|
||||
}
|
||||
}
|
||||
|
||||
th.Protocol = headers.TransportProtocolTCP
|
||||
de := headers.TransportDeliveryUnicast
|
||||
th.Delivery = &de
|
||||
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
|
||||
}
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
|
||||
if ss.setuppedMedias == nil {
|
||||
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
|
||||
}
|
||||
ss.setuppedMedias[medi] = sm
|
||||
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
|
||||
sm := &serverSessionMedia{
|
||||
ss: ss,
|
||||
media: medi,
|
||||
onPacketRTCP: func(_ rtcp.Packet) {},
|
||||
}
|
||||
sm.initialize()
|
||||
|
||||
res.Header["Transport"] = th.Marshal()
|
||||
switch transport {
|
||||
case TransportUDP:
|
||||
sm.udpRTPReadPort = inTH.ClientPorts[0]
|
||||
sm.udpRTCPReadPort = inTH.ClientPorts[1]
|
||||
|
||||
sm.udpRTPWriteAddr = &net.UDPAddr{
|
||||
IP: ss.author.ip(),
|
||||
Zone: ss.author.zone(),
|
||||
Port: sm.udpRTPReadPort,
|
||||
}
|
||||
|
||||
sm.udpRTCPWriteAddr = &net.UDPAddr{
|
||||
IP: ss.author.ip(),
|
||||
Zone: ss.author.zone(),
|
||||
Port: sm.udpRTCPReadPort,
|
||||
}
|
||||
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
de := headers.TransportDeliveryUnicast
|
||||
th.Delivery = &de
|
||||
th.ClientPorts = inTH.ClientPorts
|
||||
th.ServerPorts = &[2]int{sc.s.udpRTPListener.port(), sc.s.udpRTCPListener.port()}
|
||||
|
||||
case TransportUDPMulticast:
|
||||
th.Protocol = headers.TransportProtocolUDP
|
||||
de := headers.TransportDeliveryMulticast
|
||||
th.Delivery = &de
|
||||
v := uint(127)
|
||||
th.TTL = &v
|
||||
d := stream.medias[medi].multicastWriter.ip()
|
||||
th.Destination = &d
|
||||
th.Ports = &[2]int{ss.s.MulticastRTPPort, ss.s.MulticastRTCPPort}
|
||||
|
||||
default: // TCP
|
||||
if inTH.InterleavedIDs != nil {
|
||||
sm.tcpChannel = inTH.InterleavedIDs[0]
|
||||
} else {
|
||||
sm.tcpChannel = ss.findFreeChannelPair()
|
||||
}
|
||||
|
||||
th.Protocol = headers.TransportProtocolTCP
|
||||
de := headers.TransportDeliveryUnicast
|
||||
th.Delivery = &de
|
||||
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1}
|
||||
}
|
||||
|
||||
if ss.setuppedMedias == nil {
|
||||
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
|
||||
}
|
||||
ss.setuppedMedias[medi] = sm
|
||||
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
|
||||
|
||||
res.Header["Transport"] = th.Marshal()
|
||||
}
|
||||
|
||||
return res, err
|
||||
|
||||
@@ -1166,64 +1175,61 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
Query: query,
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if ss.state != ServerSessionStatePlay {
|
||||
ss.state = ServerSessionStatePlay
|
||||
|
||||
v := ss.s.timeNow().Unix()
|
||||
ss.udpLastPacketTime = &v
|
||||
|
||||
ss.timeDecoder = &rtptime.GlobalDecoder2{}
|
||||
ss.timeDecoder.Initialize()
|
||||
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.start()
|
||||
}
|
||||
|
||||
if *ss.setuppedTransport == TransportTCP {
|
||||
ss.tcpFrame = &base.InterleavedFrame{}
|
||||
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
|
||||
}
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.startWriter()
|
||||
|
||||
case TransportUDPMulticast:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
|
||||
ss.setuppedStream.readerSetActive(ss)
|
||||
|
||||
rtpInfo, ok := generateRTPInfo(
|
||||
ss.s.timeNow(),
|
||||
ss.setuppedMediasOrdered,
|
||||
ss.setuppedStream,
|
||||
ss.setuppedPath,
|
||||
req.URL)
|
||||
|
||||
if ok {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["RTP-Info"] = rtpInfo.Marshal()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if ss.state != ServerSessionStatePlay &&
|
||||
*ss.setuppedTransport != TransportUDPMulticast {
|
||||
ss.destroyWriter()
|
||||
}
|
||||
return res, err
|
||||
}
|
||||
|
||||
if ss.state == ServerSessionStatePlay {
|
||||
return res, err
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePlay
|
||||
|
||||
v := ss.s.timeNow().Unix()
|
||||
ss.udpLastPacketTime = &v
|
||||
|
||||
ss.timeDecoder = &rtptime.GlobalDecoder2{}
|
||||
ss.timeDecoder.Initialize()
|
||||
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.start()
|
||||
}
|
||||
|
||||
if *ss.setuppedTransport == TransportTCP {
|
||||
ss.tcpFrame = &base.InterleavedFrame{}
|
||||
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
|
||||
}
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.startWriter()
|
||||
|
||||
case TransportUDPMulticast:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
|
||||
ss.setuppedStream.readerSetActive(ss)
|
||||
|
||||
rtpInfo, ok := generateRTPInfo(
|
||||
ss.s.timeNow(),
|
||||
ss.setuppedMediasOrdered,
|
||||
ss.setuppedStream,
|
||||
ss.setuppedPath,
|
||||
req.URL)
|
||||
|
||||
if ok {
|
||||
if res.Header == nil {
|
||||
res.Header = make(base.Header)
|
||||
}
|
||||
res.Header["RTP-Info"] = rtpInfo.Marshal()
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -1260,38 +1266,37 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
Query: query,
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
if res.StatusCode == base.StatusOK {
|
||||
ss.state = ServerSessionStateRecord
|
||||
|
||||
v := ss.s.timeNow().Unix()
|
||||
ss.udpLastPacketTime = &v
|
||||
|
||||
ss.timeDecoder = &rtptime.GlobalDecoder2{}
|
||||
ss.timeDecoder.Initialize()
|
||||
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.start()
|
||||
}
|
||||
|
||||
if *ss.setuppedTransport == TransportTCP {
|
||||
ss.tcpFrame = &base.InterleavedFrame{}
|
||||
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
|
||||
}
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.startWriter()
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
} else {
|
||||
ss.destroyWriter()
|
||||
return res, err
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStateRecord
|
||||
|
||||
v := ss.s.timeNow().Unix()
|
||||
ss.udpLastPacketTime = &v
|
||||
|
||||
ss.timeDecoder = &rtptime.GlobalDecoder2{}
|
||||
ss.timeDecoder.Initialize()
|
||||
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.start()
|
||||
}
|
||||
|
||||
if *ss.setuppedTransport == TransportTCP {
|
||||
ss.tcpFrame = &base.InterleavedFrame{}
|
||||
ss.tcpBuffer = make([]byte, ss.s.MaxPacketSize+4)
|
||||
}
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = time.NewTimer(ss.s.checkStreamPeriod)
|
||||
ss.startWriter()
|
||||
|
||||
default: // TCP
|
||||
ss.tcpConn = sc
|
||||
err = switchReadFuncError{true}
|
||||
// startWriter() is called by ServerConn, through chAsyncStartWriter,
|
||||
// after the response has been sent
|
||||
}
|
||||
|
||||
return res, err
|
||||
@@ -1317,50 +1322,48 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
|
||||
Query: query,
|
||||
})
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
return res, err
|
||||
}
|
||||
if res.StatusCode == base.StatusOK {
|
||||
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
|
||||
ss.destroyWriter()
|
||||
|
||||
if ss.state == ServerSessionStatePlay || ss.state == ServerSessionStateRecord {
|
||||
ss.destroyWriter()
|
||||
|
||||
if ss.setuppedStream != nil {
|
||||
ss.setuppedStream.readerSetInactive(ss)
|
||||
}
|
||||
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.stop()
|
||||
}
|
||||
|
||||
ss.timeDecoder = nil
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
case TransportUDPMulticast:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
default: // TCP
|
||||
err = switchReadFuncError{false}
|
||||
ss.tcpConn = nil
|
||||
if ss.setuppedStream != nil {
|
||||
ss.setuppedStream.readerSetInactive(ss)
|
||||
}
|
||||
|
||||
case ServerSessionStateRecord:
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
default: // TCP
|
||||
err = switchReadFuncError{false}
|
||||
ss.tcpConn = nil
|
||||
for _, sm := range ss.setuppedMedias {
|
||||
sm.stop()
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
ss.timeDecoder = nil
|
||||
|
||||
switch ss.state {
|
||||
case ServerSessionStatePlay:
|
||||
ss.state = ServerSessionStatePrePlay
|
||||
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
case TransportUDPMulticast:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
default: // TCP
|
||||
err = switchReadFuncError{false}
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
|
||||
case ServerSessionStateRecord:
|
||||
switch *ss.setuppedTransport {
|
||||
case TransportUDP:
|
||||
ss.udpCheckStreamTimer = emptyTimer()
|
||||
|
||||
default: // TCP
|
||||
err = switchReadFuncError{false}
|
||||
ss.tcpConn = nil
|
||||
}
|
||||
|
||||
ss.state = ServerSessionStatePreRecord
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user