optimize code (#878)

* remove unused code

* initialize UDP listeners and SRTP before initializing medias

* make rtpSender and rtpReceiver available before PLAY / RECORD

* use writerMutex to protect writer only
This commit is contained in:
Alessandro Ros
2025-09-05 23:11:51 +02:00
committed by GitHub
parent 1bc89661eb
commit 65da49ffc0
9 changed files with 413 additions and 469 deletions

216
client.go
View File

@@ -41,6 +41,47 @@ const (
clientUserAgent = "gortsplib" clientUserAgent = "gortsplib"
) )
func generateLocalSSRCs(existing []uint32, formats []format.Format) (map[uint8]uint32, error) {
ret := make(map[uint8]uint32)
for _, forma := range formats {
for {
ssrc, err := randUint32()
if err != nil {
return nil, err
}
if ssrc != 0 && !slices.Contains(existing, ssrc) {
existing = append(existing, ssrc)
ret[forma.PayloadType()] = ssrc
break
}
}
}
return ret, nil
}
func ssrcsMapToList(m map[uint8]uint32) []uint32 {
ret := make([]uint32, len(m))
n := 0
for _, el := range m {
ret[n] = el
n++
}
return ret
}
func clientExtractExistingSSRCs(setuppedMedias map[*description.Media]*clientMedia) []uint32 {
var ret []uint32
for _, media := range setuppedMedias {
for _, forma := range media.formats {
ret = append(ret, forma.localSSRC)
}
}
return ret
}
// avoid an int64 overflow and preserve resolution by splitting division into two parts: // avoid an int64 overflow and preserve resolution by splitting division into two parts:
// first add the integer part, then the decimal part. // first add the integer part, then the decimal part.
func multiplyAndDivide(v, m, d time.Duration) time.Duration { func multiplyAndDivide(v, m, d time.Duration) time.Duration {
@@ -134,16 +175,16 @@ func announceDataPickLocalSSRC(
am *clientAnnounceDataMedia, am *clientAnnounceDataMedia,
data map[*description.Media]*clientAnnounceDataMedia, data map[*description.Media]*clientAnnounceDataMedia,
) (uint32, error) { ) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc var existing []uint32 //nolint:prealloc
for _, am := range data { for _, am := range data {
for _, af := range am.formats { for _, af := range am.formats {
takenSSRCs = append(takenSSRCs, af.localSSRC) existing = append(existing, af.localSSRC)
} }
} }
for _, af := range am.formats { for _, af := range am.formats {
takenSSRCs = append(takenSSRCs, af.localSSRC) existing = append(existing, af.localSSRC)
} }
for { for {
@@ -152,7 +193,7 @@ func announceDataPickLocalSSRC(
return 0, err return 0, err
} }
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { if ssrc != 0 && !slices.Contains(existing, ssrc) {
return ssrc, nil return ssrc, nil
} }
} }
@@ -214,9 +255,9 @@ func prepareForAnnounce(
n++ n++
} }
// use a dummy Context. // create a temporary Context.
// Context is needed to extract ROC, but since client has not started streaming, // Context is needed to extract ROC, but since client has not started streaming,
// ROC is always zero, therefore a dummy Context can be used. // ROC is always zero, therefore a temporary Context can be used.
srtpCtx := &wrappedSRTPContext{ srtpCtx := &wrappedSRTPContext{
key: announceDataMedia.srtpOutKey, key: announceDataMedia.srtpOutKey,
ssrcs: ssrcs, ssrcs: ssrcs,
@@ -527,8 +568,8 @@ type Client struct {
keepAlivePeriod time.Duration keepAlivePeriod time.Duration
keepAliveTimer *time.Timer keepAliveTimer *time.Timer
closeError error closeError error
writer *asyncProcessor
writerMutex sync.RWMutex writerMutex sync.RWMutex
writer *asyncProcessor
reader *clientReader reader *clientReader
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
mustClose bool mustClose bool
@@ -1618,20 +1659,41 @@ func (c *Client) doSetup(
} }
} }
cm := &clientMedia{ var localSSRCs map[uint8]uint32
c: c,
media: medi, if c.state == clientStatePreRecord {
secure: isSecure(th.Profile), localSSRCs = make(map[uint8]uint32)
} for forma, data := range c.announceData[medi].formats {
err = cm.initialize() localSSRCs[forma] = data.localSSRC
if err != nil { }
return nil, err } else {
localSSRCs, err = generateLocalSSRCs(
clientExtractExistingSSRCs(c.setuppedMedias),
medi.Formats,
)
if err != nil {
return nil, err
}
} }
var udpRTPListener *clientUDPListener
var udpRTCPListener *clientUDPListener
var tcpChannel int
var srtpInCtx *wrappedSRTPContext
var srtpOutCtx *wrappedSRTPContext
defer func() {
if udpRTPListener != nil {
udpRTPListener.close()
}
if udpRTCPListener != nil {
udpRTCPListener.close()
}
}()
switch transport { switch transport {
case TransportUDP, TransportUDPMulticast: case TransportUDP, TransportUDPMulticast:
if c.Scheme == "rtsps" && !isSecure(th.Profile) { if c.Scheme == "rtsps" && !isSecure(th.Profile) {
cm.close()
return nil, fmt.Errorf("unable to setup secure UDP") return nil, fmt.Errorf("unable to setup secure UDP")
} }
@@ -1640,29 +1702,27 @@ func (c *Client) doSetup(
if transport == TransportUDP { if transport == TransportUDP {
if (rtpPort == 0 && rtcpPort != 0) || if (rtpPort == 0 && rtcpPort != 0) ||
(rtpPort != 0 && rtcpPort == 0) { (rtpPort != 0 && rtcpPort == 0) {
cm.close()
return nil, liberrors.ErrClientUDPPortsZero{} return nil, liberrors.ErrClientUDPPortsZero{}
} }
if rtpPort != 0 && rtcpPort != (rtpPort+1) { if rtpPort != 0 && rtcpPort != (rtpPort+1) {
cm.close()
return nil, liberrors.ErrClientUDPPortsNotConsecutive{} return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
} }
err = cm.createUDPListeners( udpRTPListener, udpRTCPListener, err = createUDPListenerPair(
c,
false, false,
nil, nil,
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
) )
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
v1 := headers.TransportDeliveryUnicast v1 := headers.TransportDeliveryUnicast
th.Delivery = &v1 th.Delivery = &v1
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()} th.ClientPorts = &[2]int{udpRTPListener.port(), udpRTCPListener.port()}
} else { } else {
v1 := headers.TransportDeliveryMulticast v1 := headers.TransportDeliveryMulticast
th.Delivery = &v1 th.Delivery = &v1
@@ -1678,7 +1738,6 @@ func (c *Client) doSetup(
mediaURL, err := medi.URL(baseURL) mediaURL, err := medi.URL(baseURL)
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
@@ -1688,7 +1747,6 @@ func (c *Client) doSetup(
if medi.IsBackChannel { if medi.IsBackChannel {
if !c.RequestBackChannels { if !c.RequestBackChannels {
cm.close()
return nil, fmt.Errorf("we are setupping a back channel but we did not request back channels") return nil, fmt.Errorf("we are setupping a back channel but we did not request back channels")
} }
@@ -1696,17 +1754,30 @@ func (c *Client) doSetup(
} }
if isSecure(th.Profile) { if isSecure(th.Profile) {
ssrcs := make([]uint32, len(cm.formats)) var srtpOutKey []byte
n := 0
for _, cf := range cm.formats { if c.state == clientStatePreRecord {
ssrcs[n] = cf.localSSRC srtpOutKey = c.announceData[medi].srtpOutKey
n++ } else {
srtpOutKey = make([]byte, srtpKeyLength)
_, err = rand.Read(srtpOutKey)
if err != nil {
return nil, err
}
}
srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcsMapToList(localSSRCs),
}
err = srtpOutCtx.initialize()
if err != nil {
return nil, err
} }
var mikeyMsg *mikey.Message var mikeyMsg *mikey.Message
mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx) mikeyMsg, err = mikeyGenerate(srtpOutCtx)
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
@@ -1716,7 +1787,6 @@ func (c *Client) doSetup(
MikeyMessage: mikeyMsg, MikeyMessage: mikeyMsg,
}.Marshal() }.Marshal()
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
@@ -1729,13 +1799,10 @@ func (c *Client) doSetup(
Header: header, Header: header,
}, false) }, false)
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
if res.StatusCode != base.StatusOK { if res.StatusCode != base.StatusOK {
cm.close()
// switch transport automatically // switch transport automatically
if res.StatusCode == base.StatusUnsupportedTransport && if res.StatusCode == base.StatusUnsupportedTransport &&
c.setuppedTransport == nil && c.Transport == nil { c.setuppedTransport == nil && c.Transport == nil {
@@ -1753,15 +1820,12 @@ func (c *Client) doSetup(
var thRes headers.Transport var thRes headers.Transport
err = thRes.Unmarshal(res.Header["Transport"]) err = thRes.Unmarshal(res.Header["Transport"])
if err != nil { if err != nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err} return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err}
} }
switch transport { switch transport {
case TransportUDP, TransportUDPMulticast: case TransportUDP, TransportUDPMulticast:
if thRes.Protocol == headers.TransportProtocolTCP { if thRes.Protocol == headers.TransportProtocolTCP {
cm.close()
// switch transport automatically // switch transport automatically
if c.setuppedTransport == nil && c.Transport == nil { if c.setuppedTransport == nil && c.Transport == nil {
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
@@ -1790,14 +1854,12 @@ func (c *Client) doSetup(
switch transport { switch transport {
case TransportUDP: case TransportUDP:
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
} }
serverPortsValid := thRes.ServerPorts != nil && !isAnyPort(thRes.ServerPorts[0]) && !isAnyPort(thRes.ServerPorts[1]) serverPortsValid := thRes.ServerPorts != nil && !isAnyPort(thRes.ServerPorts[0]) && !isAnyPort(thRes.ServerPorts[1])
if (c.state == clientStatePreRecord || !c.AnyPortEnable) && !serverPortsValid { if (c.state == clientStatePreRecord || !c.AnyPortEnable) && !serverPortsValid {
cm.close()
return nil, liberrors.ErrClientServerPortsNotProvided{} return nil, liberrors.ErrClientServerPortsNotProvided{}
} }
@@ -1810,41 +1872,38 @@ func (c *Client) doSetup(
if serverPortsValid { if serverPortsValid {
if !c.AnyPortEnable { if !c.AnyPortEnable {
cm.udpRTPListener.readPort = thRes.ServerPorts[0] udpRTPListener.readPort = thRes.ServerPorts[0]
} }
cm.udpRTPListener.writeAddr = &net.UDPAddr{ udpRTPListener.writeAddr = &net.UDPAddr{
IP: remoteIP, IP: remoteIP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[0], Port: thRes.ServerPorts[0],
} }
} }
cm.udpRTPListener.readIP = remoteIP udpRTPListener.readIP = remoteIP
if serverPortsValid { if serverPortsValid {
if !c.AnyPortEnable { if !c.AnyPortEnable {
cm.udpRTCPListener.readPort = thRes.ServerPorts[1] udpRTCPListener.readPort = thRes.ServerPorts[1]
} }
cm.udpRTCPListener.writeAddr = &net.UDPAddr{ udpRTCPListener.writeAddr = &net.UDPAddr{
IP: remoteIP, IP: remoteIP,
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
Port: thRes.ServerPorts[1], Port: thRes.ServerPorts[1],
} }
} }
cm.udpRTCPListener.readIP = remoteIP udpRTCPListener.readIP = remoteIP
case TransportUDPMulticast: case TransportUDPMulticast:
if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast { if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
} }
if thRes.Ports == nil { if thRes.Ports == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoPorts{} return nil, liberrors.ErrClientTransportHeaderNoPorts{}
} }
if thRes.Destination == nil { if thRes.Destination == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoDestination{} return nil, liberrors.ErrClientTransportHeaderNoDestination{}
} }
@@ -1858,72 +1917,65 @@ func (c *Client) doSetup(
var intf *net.Interface var intf *net.Interface
intf, err = interfaceOfConn(c.nconn) intf, err = interfaceOfConn(c.nconn)
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
err = cm.createUDPListeners( udpRTPListener, udpRTCPListener, err = createUDPListenerPair(
c,
true, true,
intf, intf,
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
) )
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
cm.udpRTPListener.readIP = remoteIP udpRTPListener.readIP = remoteIP
cm.udpRTPListener.readPort = thRes.Ports[0] udpRTPListener.readPort = thRes.Ports[0]
cm.udpRTPListener.writeAddr = &net.UDPAddr{ udpRTPListener.writeAddr = &net.UDPAddr{
IP: remoteIP, IP: remoteIP,
Port: thRes.Ports[0], Port: thRes.Ports[0],
} }
cm.udpRTCPListener.readIP = remoteIP udpRTCPListener.readIP = remoteIP
cm.udpRTCPListener.readPort = thRes.Ports[1] udpRTCPListener.readPort = thRes.Ports[1]
cm.udpRTCPListener.writeAddr = &net.UDPAddr{ udpRTCPListener.writeAddr = &net.UDPAddr{
IP: remoteIP, IP: remoteIP,
Port: thRes.Ports[1], Port: thRes.Ports[1],
} }
case TransportTCP: case TransportTCP:
if thRes.Protocol != headers.TransportProtocolTCP { if thRes.Protocol != headers.TransportProtocolTCP {
cm.close()
return nil, liberrors.ErrClientServerRequestedUDP{} return nil, liberrors.ErrClientServerRequestedUDP{}
} }
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
} }
if thRes.InterleavedIDs == nil { if thRes.InterleavedIDs == nil {
cm.close()
return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{} return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{}
} }
if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] { if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] {
cm.close()
return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{} return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{}
} }
if c.isChannelPairInUse(thRes.InterleavedIDs[0]) { if c.isChannelPairInUse(thRes.InterleavedIDs[0]) {
cm.close()
return &base.Response{ return &base.Response{
StatusCode: base.StatusBadRequest, StatusCode: base.StatusBadRequest,
}, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{} }, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{}
} }
cm.tcpChannel = thRes.InterleavedIDs[0] tcpChannel = thRes.InterleavedIDs[0]
} }
if thRes.Profile != th.Profile { if thRes.Profile != th.Profile {
cm.close()
return nil, fmt.Errorf("returned profile does not match requested profile") return nil, fmt.Errorf("returned profile does not match requested profile")
} }
if cm.secure { if isSecure(th.Profile) {
var mikeyMsg *mikey.Message var mikeyMsg *mikey.Message
// extract key-mgmt from (in order of priority): // extract key-mgmt from (in order of priority):
@@ -1935,7 +1987,6 @@ func (c *Client) doSetup(
var keyMgmt headers.KeyMgmt var keyMgmt headers.KeyMgmt
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
mikeyMsg = keyMgmt.MikeyMessage mikeyMsg = keyMgmt.MikeyMessage
@@ -1950,13 +2001,28 @@ func (c *Client) doSetup(
return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way") return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way")
} }
cm.srtpInCtx, err = mikeyToContext(mikeyMsg) srtpInCtx, err = mikeyToContext(mikeyMsg)
if err != nil { if err != nil {
cm.close()
return nil, err return nil, err
} }
} }
cm := &clientMedia{
c: c,
media: medi,
secure: isSecure(th.Profile),
udpRTPListener: udpRTPListener,
udpRTCPListener: udpRTCPListener,
tcpChannel: tcpChannel,
localSSRCs: localSSRCs,
srtpInCtx: srtpInCtx,
srtpOutCtx: srtpOutCtx,
}
cm.initialize()
udpRTPListener = nil
udpRTCPListener = nil
if c.setuppedMedias == nil { if c.setuppedMedias == nil {
c.setuppedMedias = make(map[*description.Media]*clientMedia) c.setuppedMedias = make(map[*description.Media]*clientMedia)
} }
@@ -2309,13 +2375,6 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
default: default:
} }
c.writerMutex.RLock()
defer c.writerMutex.RUnlock()
if c.writer == nil {
return nil
}
cm := c.setuppedMedias[medi] cm := c.setuppedMedias[medi]
cf := cm.formats[pkt.PayloadType] cf := cm.formats[pkt.PayloadType]
return cf.writePacketRTP(pkt, ntp) return cf.writePacketRTP(pkt, ntp)
@@ -2329,13 +2388,6 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
default: default:
} }
c.writerMutex.RLock()
defer c.writerMutex.RUnlock()
if c.writer == nil {
return nil
}
cm := c.setuppedMedias[medi] cm := c.setuppedMedias[medi]
return cm.writePacketRTCP(pkt) return cm.writePacketRTCP(pkt)
} }

View File

@@ -1,7 +1,6 @@
package gortsplib package gortsplib
import ( import (
"slices"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -14,37 +13,12 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/rtpsender" "github.com/bluenviron/gortsplib/v4/pkg/rtpsender"
) )
func clientPickLocalSSRC(cf *clientFormat) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, cm := range cf.cm.c.setuppedMedias {
for _, cf := range cm.formats {
takenSSRCs = append(takenSSRCs, cf.localSSRC)
}
}
for _, cf := range cf.cm.formats {
takenSSRCs = append(takenSSRCs, cf.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
}
type clientFormat struct { type clientFormat struct {
cm *clientMedia cm *clientMedia
format format.Format format format.Format
localSSRC uint32
onPacketRTP OnPacketRTPFunc onPacketRTP OnPacketRTPFunc
localSSRC uint32
rtpReceiver *rtpreceiver.Receiver // play rtpReceiver *rtpreceiver.Receiver // play
rtpSender *rtpsender.Sender // record or back channel rtpSender *rtpsender.Sender // record or back channel
writePacketRTPInQueue func([]byte) error writePacketRTPInQueue func([]byte) error
@@ -53,32 +27,18 @@ type clientFormat struct {
rtpPacketsLost *uint64 rtpPacketsLost *uint64
} }
func (cf *clientFormat) initialize() error { func (cf *clientFormat) initialize() {
if cf.cm.c.state == clientStatePreRecord {
cf.localSSRC = cf.cm.c.announceData[cf.cm.media].formats[cf.format.PayloadType()].localSSRC
} else {
var err error
cf.localSSRC, err = clientPickLocalSSRC(cf)
if err != nil {
return err
}
}
cf.rtpPacketsReceived = new(uint64) cf.rtpPacketsReceived = new(uint64)
cf.rtpPacketsSent = new(uint64) cf.rtpPacketsSent = new(uint64)
cf.rtpPacketsLost = new(uint64) cf.rtpPacketsLost = new(uint64)
return nil
}
func (cf *clientFormat) start() {
if cf.cm.udpRTPListener != nil { if cf.cm.udpRTPListener != nil {
cf.writePacketRTPInQueue = cf.writePacketRTPInQueueUDP cf.writePacketRTPInQueue = cf.writePacketRTPInQueueUDP
} else { } else {
cf.writePacketRTPInQueue = cf.writePacketRTPInQueueTCP cf.writePacketRTPInQueue = cf.writePacketRTPInQueueTCP
} }
if cf.cm.c.state == clientStateRecord || cf.cm.media.IsBackChannel { if cf.cm.c.state == clientStatePreRecord || cf.cm.media.IsBackChannel {
cf.rtpSender = &rtpsender.Sender{ cf.rtpSender = &rtpsender.Sender{
ClockRate: cf.format.ClockRate(), ClockRate: cf.format.ClockRate(),
Period: cf.cm.c.senderReportPeriod, Period: cf.cm.c.senderReportPeriod,
@@ -110,7 +70,7 @@ func (cf *clientFormat) start() {
} }
} }
func (cf *clientFormat) stop() { func (cf *clientFormat) close() {
if cf.rtpReceiver != nil { if cf.rtpReceiver != nil {
cf.rtpReceiver.Close() cf.rtpReceiver.Close()
cf.rtpReceiver = nil cf.rtpReceiver = nil
@@ -178,6 +138,13 @@ func (cf *clientFormat) writePacketRTP(pkt *rtp.Packet, ntp time.Time) error {
buf = encr buf = encr
} }
cf.cm.c.writerMutex.RLock()
defer cf.cm.c.writerMutex.RUnlock()
if cf.cm.c.writer == nil {
return nil
}
ok := cf.cm.c.writer.push(func() error { ok := cf.cm.c.writer.push(func() error {
return cf.writePacketRTPInQueue(buf) return cf.writePacketRTPInQueue(buf)
}) })

View File

@@ -1,7 +1,6 @@
package gortsplib package gortsplib
import ( import (
"crypto/rand"
"fmt" "fmt"
"net" "net"
"strconv" "strconv"
@@ -15,18 +14,87 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
type clientMedia struct { func createUDPListenerPair(
c *Client c *Client,
media *description.Media multicast bool,
secure bool multicastInterface *net.Interface,
rtpAddress string,
rtcpAddress string,
) (*clientUDPListener, *clientUDPListener, error) {
if rtpAddress != ":0" {
l1 := &clientUDPListener{
c: c,
multicast: multicast,
multicastInterface: multicastInterface,
address: rtpAddress,
}
err := l1.initialize()
if err != nil {
return nil, nil, err
}
l2 := &clientUDPListener{
c: c,
multicast: multicast,
multicastInterface: multicastInterface,
address: rtcpAddress,
}
err = l2.initialize()
if err != nil {
l1.close()
return nil, nil, err
}
return l1, l2, nil
}
// pick two consecutive ports in range 65535-10000
// RTP port must be even and RTCP port odd
for {
v, err := randInRange((65535 - 10000) / 2)
if err != nil {
return nil, nil, err
}
rtpPort := v*2 + 10000
rtcpPort := rtpPort + 1
l1 := &clientUDPListener{
c: c,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
}
err = l1.initialize()
if err != nil {
continue
}
l2 := &clientUDPListener{
c: c,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
}
err = l2.initialize()
if err != nil {
l2.close()
continue
}
return l1, l2, nil
}
}
type clientMedia struct {
c *Client
media *description.Media
secure bool
udpRTPListener *clientUDPListener
udpRTCPListener *clientUDPListener
tcpChannel int
localSSRCs map[uint8]uint32
srtpInCtx *wrappedSRTPContext
srtpOutCtx *wrappedSRTPContext
srtpOutCtx *wrappedSRTPContext
srtpInCtx *wrappedSRTPContext
onPacketRTCP OnPacketRTCPFunc onPacketRTCP OnPacketRTCPFunc
formats map[uint8]*clientFormat formats map[uint8]*clientFormat
tcpChannel int
udpRTPListener *clientUDPListener
udpRTCPListener *clientUDPListener
writePacketRTCPInQueue func([]byte) error writePacketRTCPInQueue func([]byte) error
bytesReceived *uint64 bytesReceived *uint64
bytesSent *uint64 bytesSent *uint64
@@ -36,7 +104,7 @@ type clientMedia struct {
rtcpPacketsInError *uint64 rtcpPacketsInError *uint64
} }
func (cm *clientMedia) initialize() error { func (cm *clientMedia) initialize() {
cm.onPacketRTCP = func(rtcp.Packet) {} cm.onPacketRTCP = func(rtcp.Packet) {}
cm.bytesReceived = new(uint64) cm.bytesReceived = new(uint64)
cm.bytesSent = new(uint64) cm.bytesSent = new(uint64)
@@ -51,123 +119,19 @@ func (cm *clientMedia) initialize() error {
f := &clientFormat{ f := &clientFormat{
cm: cm, cm: cm,
format: forma, format: forma,
localSSRC: cm.localSSRCs[forma.PayloadType()],
onPacketRTP: func(*rtp.Packet) {}, onPacketRTP: func(*rtp.Packet) {},
} }
err := f.initialize() f.initialize()
if err != nil {
return err
}
cm.formats[forma.PayloadType()] = f cm.formats[forma.PayloadType()] = f
} }
if cm.secure {
var srtpOutKey []byte
if cm.c.state == clientStatePreRecord {
srtpOutKey = cm.c.announceData[cm.media].srtpOutKey
} else {
srtpOutKey = make([]byte, srtpKeyLength)
_, err := rand.Read(srtpOutKey)
if err != nil {
return err
}
}
ssrcs := make([]uint32, len(cm.formats))
n := 0
for _, cf := range cm.formats {
ssrcs[n] = cf.localSSRC
n++
}
cm.srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcs,
}
err := cm.srtpOutCtx.initialize()
if err != nil {
return err
}
}
return nil
} }
func (cm *clientMedia) close() { func (cm *clientMedia) close() {
if cm.udpRTPListener != nil { cm.stop()
cm.udpRTPListener.close()
cm.udpRTCPListener.close()
}
}
func (cm *clientMedia) createUDPListeners( for _, ct := range cm.formats {
multicast bool, ct.close()
multicastInterface *net.Interface,
rtpAddress string,
rtcpAddress string,
) error {
if rtpAddress != ":0" {
l1 := &clientUDPListener{
c: cm.c,
multicast: multicast,
multicastInterface: multicastInterface,
address: rtpAddress,
}
err := l1.initialize()
if err != nil {
return err
}
l2 := &clientUDPListener{
c: cm.c,
multicast: multicast,
multicastInterface: multicastInterface,
address: rtcpAddress,
}
err = l2.initialize()
if err != nil {
l1.close()
return err
}
cm.udpRTPListener, cm.udpRTCPListener = l1, l2
return nil
}
// pick two consecutive ports in range 65535-10000
// RTP port must be even and RTCP port odd
for {
v, err := randInRange((65535 - 10000) / 2)
if err != nil {
return err
}
rtpPort := v*2 + 10000
rtcpPort := rtpPort + 1
cm.udpRTPListener = &clientUDPListener{
c: cm.c,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
}
err = cm.udpRTPListener.initialize()
if err != nil {
cm.udpRTPListener = nil
continue
}
cm.udpRTCPListener = &clientUDPListener{
c: cm.c,
address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
}
err = cm.udpRTCPListener.initialize()
if err != nil {
cm.udpRTPListener.close()
cm.udpRTPListener = nil
cm.udpRTCPListener = nil
continue
}
return nil
} }
} }
@@ -198,10 +162,6 @@ func (cm *clientMedia) start() {
} }
} }
for _, ct := range cm.formats {
ct.start()
}
if cm.udpRTPListener != nil { if cm.udpRTPListener != nil {
cm.udpRTPListener.start() cm.udpRTPListener.start()
cm.udpRTCPListener.start() cm.udpRTCPListener.start()
@@ -213,10 +173,6 @@ func (cm *clientMedia) stop() {
cm.udpRTPListener.stop() cm.udpRTPListener.stop()
cm.udpRTCPListener.stop() cm.udpRTCPListener.stop()
} }
for _, ct := range cm.formats {
ct.stop()
}
} }
func (cm *clientMedia) findFormatByRemoteSSRC(ssrc uint32) *clientFormat { func (cm *clientMedia) findFormatByRemoteSSRC(ssrc uint32) *clientFormat {
@@ -460,6 +416,13 @@ func (cm *clientMedia) writePacketRTCP(pkt rtcp.Packet) error {
buf = encr buf = encr
} }
cm.c.writerMutex.RLock()
defer cm.c.writerMutex.RUnlock()
if cm.c.writer == nil {
return nil
}
ok := cm.c.writer.push(func() error { ok := cm.c.writer.push(func() error {
return cm.writePacketRTCPInQueue(buf) return cm.writePacketRTCPInQueue(buf)
}) })

View File

@@ -3,6 +3,7 @@ package gortsplib
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/rand"
"fmt" "fmt"
"log" "log"
"net" "net"
@@ -31,6 +32,16 @@ import (
type readFunc func([]byte) bool type readFunc func([]byte) bool
func serverSessionExtractExistingSSRCs(medias map[*description.Media]*serverSessionMedia) []uint32 {
var ret []uint32
for _, media := range medias {
for _, forma := range media.formats {
ret = append(ret, forma.localSSRC)
}
}
return ret
}
func isSecure(profile headers.TransportProfile) bool { func isSecure(profile headers.TransportProfile) bool {
return profile == headers.TransportProfileSAVP return profile == headers.TransportProfileSAVP
} }
@@ -433,8 +444,8 @@ type ServerSession struct {
announcedDesc *description.Session // record announcedDesc *description.Session // record
udpLastPacketTime *int64 // record udpLastPacketTime *int64 // record
udpCheckStreamTimer *time.Timer udpCheckStreamTimer *time.Timer
writer *asyncProcessor
writerMutex sync.RWMutex writerMutex sync.RWMutex
writer *asyncProcessor
timeDecoder *rtptime.GlobalDecoder2 timeDecoder *rtptime.GlobalDecoder2
tcpFrame *base.InterleavedFrame tcpFrame *base.InterleavedFrame
tcpBuffer []byte tcpBuffer []byte
@@ -818,7 +829,7 @@ func (ss *ServerSession) run() {
} }
for _, sm := range ss.setuppedMedias { for _, sm := range ss.setuppedMedias {
sm.stop() sm.close()
} }
if ss.writer != nil { if ss.writer != nil {
@@ -1290,37 +1301,76 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
res.Header = make(base.Header) res.Header = make(base.Header)
} }
sm := &serverSessionMedia{ var localSSRCs map[uint8]uint32
ss: ss,
media: medi, if ss.state == ServerSessionStatePreRecord || medi.IsBackChannel {
srtpInCtx: srtpInCtx, localSSRCs, err = generateLocalSSRCs(
onPacketRTCP: func(_ rtcp.Packet) {}, serverSessionExtractExistingSSRCs(ss.setuppedMedias),
medi.Formats,
)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
} else {
localSSRCs = make(map[uint8]uint32)
for forma, data := range ss.setuppedStream.medias[medi].formats {
localSSRCs[forma] = data.localSSRC
}
} }
err = sm.initialize()
if err != nil { var srtpOutCtx *wrappedSRTPContext
return &base.Response{
StatusCode: base.StatusInternalServerError, if ss.s.TLSConfig != nil {
}, err if ss.state == ServerSessionStatePreRecord || medi.IsBackChannel {
srtpOutKey := make([]byte, srtpKeyLength)
_, err = rand.Read(srtpOutKey)
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcsMapToList(localSSRCs),
}
err = srtpOutCtx.initialize()
if err != nil {
return &base.Response{
StatusCode: base.StatusInternalServerError,
}, err
}
} else {
srtpOutCtx = ss.setuppedStream.medias[medi].srtpOutCtx
}
} }
var udpRTPReadPort int
var udpRTPWriteAddr *net.UDPAddr
var udpRTCPReadPort int
var udpRTCPWriteAddr *net.UDPAddr
var tcpChannel int
switch transport { switch transport {
case TransportUDP, TransportUDPMulticast: case TransportUDP, TransportUDPMulticast:
th.Protocol = headers.TransportProtocolUDP th.Protocol = headers.TransportProtocolUDP
if transport == TransportUDP { if transport == TransportUDP {
sm.udpRTPReadPort = inTH.ClientPorts[0] udpRTPReadPort = inTH.ClientPorts[0]
sm.udpRTCPReadPort = inTH.ClientPorts[1] udpRTCPReadPort = inTH.ClientPorts[1]
sm.udpRTPWriteAddr = &net.UDPAddr{ udpRTPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(), IP: ss.author.ip(),
Zone: ss.author.zone(), Zone: ss.author.zone(),
Port: sm.udpRTPReadPort, Port: udpRTPReadPort,
} }
sm.udpRTCPWriteAddr = &net.UDPAddr{ udpRTCPWriteAddr = &net.UDPAddr{
IP: ss.author.ip(), IP: ss.author.ip(),
Zone: ss.author.zone(), Zone: ss.author.zone(),
Port: sm.udpRTCPReadPort, Port: udpRTCPReadPort,
} }
de := headers.TransportDeliveryUnicast de := headers.TransportDeliveryUnicast
@@ -1341,32 +1391,41 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) (
th.Protocol = headers.TransportProtocolTCP th.Protocol = headers.TransportProtocolTCP
if inTH.InterleavedIDs != nil { if inTH.InterleavedIDs != nil {
sm.tcpChannel = inTH.InterleavedIDs[0] tcpChannel = inTH.InterleavedIDs[0]
} else { } else {
sm.tcpChannel = ss.findFreeChannelPair() tcpChannel = ss.findFreeChannelPair()
} }
de := headers.TransportDeliveryUnicast de := headers.TransportDeliveryUnicast
th.Delivery = &de th.Delivery = &de
th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1} th.InterleavedIDs = &[2]int{tcpChannel, tcpChannel + 1}
} }
sm := &serverSessionMedia{
ss: ss,
media: medi,
localSSRCs: localSSRCs,
srtpInCtx: srtpInCtx,
srtpOutCtx: srtpOutCtx,
udpRTPReadPort: udpRTPReadPort,
udpRTPWriteAddr: udpRTPWriteAddr,
udpRTCPReadPort: udpRTCPReadPort,
udpRTCPWriteAddr: udpRTCPWriteAddr,
tcpChannel: tcpChannel,
onPacketRTCP: func(_ rtcp.Packet) {},
}
sm.initialize()
if ss.setuppedMedias == nil { if ss.setuppedMedias == nil {
ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia) ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia)
} }
ss.setuppedMedias[medi] = sm ss.setuppedMedias[medi] = sm
ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm) ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm)
res.Header["Transport"] = th.Marshal() res.Header["Transport"] = th.Marshal()
if isSecure(inTH.Profile) { if isSecure(inTH.Profile) {
ssrcs := make([]uint32, len(sm.formats))
n := 0
for _, sf := range sm.formats {
ssrcs[n] = sf.localSSRC
n++
}
var mk *mikey.Message var mk *mikey.Message
mk, err = mikeyGenerate(sm.srtpOutCtx) mk, err = mikeyGenerate(sm.srtpOutCtx)
if err != nil { if err != nil {

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"log" "log"
"slices"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -14,37 +13,12 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/rtpreceiver" "github.com/bluenviron/gortsplib/v4/pkg/rtpreceiver"
) )
func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, sm := range sf.sm.ss.setuppedMedias {
for _, sf := range sm.formats {
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
}
for _, sf := range sf.sm.formats {
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
}
type serverSessionFormat struct { type serverSessionFormat struct {
sm *serverSessionMedia sm *serverSessionMedia
format format.Format format format.Format
localSSRC uint32
onPacketRTP OnPacketRTPFunc onPacketRTP OnPacketRTPFunc
localSSRC uint32
rtpReceiver *rtpreceiver.Receiver rtpReceiver *rtpreceiver.Receiver
writePacketRTPInQueue func([]byte) error writePacketRTPInQueue func([]byte) error
rtpPacketsReceived *uint64 rtpPacketsReceived *uint64
@@ -52,25 +26,11 @@ type serverSessionFormat struct {
rtpPacketsLost *uint64 rtpPacketsLost *uint64
} }
func (sf *serverSessionFormat) initialize() error { func (sf *serverSessionFormat) initialize() {
if sf.sm.ss.state == ServerSessionStatePreRecord || sf.sm.media.IsBackChannel {
var err error
sf.localSSRC, err = serverSessionPickLocalSSRC(sf)
if err != nil {
return err
}
} else {
sf.localSSRC = sf.sm.ss.setuppedStream.medias[sf.sm.media].formats[sf.format.PayloadType()].localSSRC
}
sf.rtpPacketsReceived = new(uint64) sf.rtpPacketsReceived = new(uint64)
sf.rtpPacketsSent = new(uint64) sf.rtpPacketsSent = new(uint64)
sf.rtpPacketsLost = new(uint64) sf.rtpPacketsLost = new(uint64)
return nil
}
func (sf *serverSessionFormat) start() {
udp := *sf.sm.ss.setuppedTransport == TransportUDP || *sf.sm.ss.setuppedTransport == TransportUDPMulticast udp := *sf.sm.ss.setuppedTransport == TransportUDP || *sf.sm.ss.setuppedTransport == TransportUDPMulticast
if udp { if udp {
@@ -79,7 +39,7 @@ func (sf *serverSessionFormat) start() {
sf.writePacketRTPInQueue = sf.writePacketRTPInQueueTCP sf.writePacketRTPInQueue = sf.writePacketRTPInQueueTCP
} }
if sf.sm.ss.state == ServerSessionStateRecord || sf.sm.media.IsBackChannel { if sf.sm.ss.state == ServerSessionStatePreRecord || sf.sm.media.IsBackChannel {
sf.rtpReceiver = &rtpreceiver.Receiver{ sf.rtpReceiver = &rtpreceiver.Receiver{
ClockRate: sf.format.ClockRate(), ClockRate: sf.format.ClockRate(),
LocalSSRC: &sf.localSSRC, LocalSSRC: &sf.localSSRC,
@@ -99,7 +59,7 @@ func (sf *serverSessionFormat) start() {
} }
} }
func (sf *serverSessionFormat) stop() { func (sf *serverSessionFormat) close() {
if sf.rtpReceiver != nil { if sf.rtpReceiver != nil {
sf.rtpReceiver.Close() sf.rtpReceiver.Close()
sf.rtpReceiver = nil sf.rtpReceiver = nil

View File

@@ -1,7 +1,6 @@
package gortsplib package gortsplib
import ( import (
"crypto/rand"
"fmt" "fmt"
"log" "log"
"net" "net"
@@ -16,17 +15,18 @@ import (
) )
type serverSessionMedia struct { type serverSessionMedia struct {
ss *ServerSession ss *ServerSession
media *description.Media media *description.Media
srtpInCtx *wrappedSRTPContext localSSRCs map[uint8]uint32
onPacketRTCP OnPacketRTCPFunc srtpInCtx *wrappedSRTPContext
srtpOutCtx *wrappedSRTPContext
udpRTPReadPort int
udpRTPWriteAddr *net.UDPAddr
udpRTCPReadPort int
udpRTCPWriteAddr *net.UDPAddr
tcpChannel int
onPacketRTCP OnPacketRTCPFunc
srtpOutCtx *wrappedSRTPContext
tcpChannel int
udpRTPReadPort int
udpRTPWriteAddr *net.UDPAddr
udpRTCPReadPort int
udpRTCPWriteAddr *net.UDPAddr
formats map[uint8]*serverSessionFormat // record only formats map[uint8]*serverSessionFormat // record only
writePacketRTCPInQueue func([]byte) error writePacketRTCPInQueue func([]byte) error
bytesReceived *uint64 bytesReceived *uint64
@@ -37,7 +37,7 @@ type serverSessionMedia struct {
rtcpPacketsInError *uint64 rtcpPacketsInError *uint64
} }
func (sm *serverSessionMedia) initialize() error { func (sm *serverSessionMedia) initialize() {
sm.bytesReceived = new(uint64) sm.bytesReceived = new(uint64)
sm.bytesSent = new(uint64) sm.bytesSent = new(uint64)
sm.rtpPacketsInError = new(uint64) sm.rtpPacketsInError = new(uint64)
@@ -51,54 +51,23 @@ func (sm *serverSessionMedia) initialize() error {
f := &serverSessionFormat{ f := &serverSessionFormat{
sm: sm, sm: sm,
format: forma, format: forma,
localSSRC: sm.localSSRCs[forma.PayloadType()],
onPacketRTP: func(*rtp.Packet) {}, onPacketRTP: func(*rtp.Packet) {},
} }
err := f.initialize() f.initialize()
if err != nil {
return err
}
sm.formats[forma.PayloadType()] = f sm.formats[forma.PayloadType()] = f
} }
}
if sm.ss.s.TLSConfig != nil { func (sm *serverSessionMedia) close() {
if sm.ss.state == ServerSessionStatePreRecord || sm.media.IsBackChannel { sm.stop()
srtpOutKey := make([]byte, srtpKeyLength)
_, err := rand.Read(srtpOutKey)
if err != nil {
return err
}
ssrcs := make([]uint32, len(sm.formats)) for _, forma := range sm.formats {
n := 0 forma.close()
for _, cf := range sm.formats {
ssrcs[n] = cf.localSSRC
n++
}
sm.srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcs,
}
err = sm.srtpOutCtx.initialize()
if err != nil {
return err
}
} else {
streamMedia := sm.ss.setuppedStream.medias[sm.media]
sm.srtpOutCtx = streamMedia.srtpOutCtx
}
} }
return nil
} }
func (sm *serverSessionMedia) start() error { func (sm *serverSessionMedia) start() error {
// allocate udpRTCPReceiver before udpRTCPListener
// otherwise udpRTCPReceiver.LastSSRC() cannot be called.
for _, sf := range sm.formats {
sf.start()
}
switch *sm.ss.setuppedTransport { switch *sm.ss.setuppedTransport {
case TransportUDP, TransportUDPMulticast: case TransportUDP, TransportUDPMulticast:
sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP
@@ -168,10 +137,6 @@ func (sm *serverSessionMedia) stop() {
sm.ss.s.udpRTPListener.removeClient(sm.ss.author.ip(), sm.udpRTPReadPort) sm.ss.s.udpRTPListener.removeClient(sm.ss.author.ip(), sm.udpRTPReadPort)
sm.ss.s.udpRTCPListener.removeClient(sm.ss.author.ip(), sm.udpRTCPReadPort) sm.ss.s.udpRTCPListener.removeClient(sm.ss.author.ip(), sm.udpRTCPReadPort)
} }
for _, sf := range sm.formats {
sf.stop()
}
} }
func (sm *serverSessionMedia) findFormatByRemoteSSRC(ssrc uint32) *serverSessionFormat { func (sm *serverSessionMedia) findFormatByRemoteSSRC(ssrc uint32) *serverSessionFormat {

View File

@@ -1,6 +1,7 @@
package gortsplib package gortsplib
import ( import (
"crypto/rand"
"fmt" "fmt"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -14,6 +15,16 @@ import (
"github.com/bluenviron/gortsplib/v4/pkg/liberrors" "github.com/bluenviron/gortsplib/v4/pkg/liberrors"
) )
func serverStreamExtractExistingSSRCs(medias map[*description.Media]*serverStreamMedia) []uint32 {
var ret []uint32
for _, media := range medias {
for _, forma := range media.formats {
ret = append(ret, forma.localSSRC)
}
}
return ret
}
// NewServerStream allocates a ServerStream. // NewServerStream allocates a ServerStream.
// //
// Deprecated: replaced by ServerStream.Initialize(). // Deprecated: replaced by ServerStream.Initialize().
@@ -56,20 +67,53 @@ func (st *ServerStream) Initialize() error {
st.activeUnicastReaders = make(map[*ServerSession]struct{}) st.activeUnicastReaders = make(map[*ServerSession]struct{})
st.medias = make(map[*description.Media]*serverStreamMedia, len(st.Desc.Medias)) st.medias = make(map[*description.Media]*serverStreamMedia, len(st.Desc.Medias))
for i, medi := range st.Desc.Medias { for i, medi := range st.Desc.Medias {
sm := &serverStreamMedia{ localSSRCs, err := generateLocalSSRCs(
st: st, serverStreamExtractExistingSSRCs(st.medias),
media: medi, medi.Formats,
trackID: i, )
}
err := sm.initialize()
if err != nil { if err != nil {
for _, medi := range st.Desc.Medias[:i] { for _, sm := range st.medias {
st.medias[medi].close() sm.close()
} }
return err return err
} }
var srtpOutCtx *wrappedSRTPContext
if st.Server.TLSConfig != nil {
srtpOutKey := make([]byte, srtpKeyLength)
_, err = rand.Read(srtpOutKey)
if err != nil {
for _, sm := range st.medias {
sm.close()
}
return err
}
srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcsMapToList(localSSRCs),
}
err = srtpOutCtx.initialize()
if err != nil {
for _, sm := range st.medias {
sm.close()
}
return err
}
}
sm := &serverStreamMedia{
st: st,
media: medi,
trackID: i,
localSSRCs: localSSRCs,
srtpOutCtx: srtpOutCtx,
}
sm.initialize()
st.medias[medi] = sm st.medias[medi] = sm
} }

View File

@@ -2,7 +2,6 @@ package gortsplib
import ( import (
"crypto/rand" "crypto/rand"
"slices"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -22,47 +21,16 @@ func randUint32() (uint32, error) {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil
} }
func serverStreamPickLocalSSRC(sf *serverStreamFormat) (uint32, error) {
var takenSSRCs []uint32 //nolint:prealloc
for _, sm := range sf.sm.st.medias {
for _, sf := range sm.formats {
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
}
for _, sf := range sf.sm.formats {
takenSSRCs = append(takenSSRCs, sf.localSSRC)
}
for {
ssrc, err := randUint32()
if err != nil {
return 0, err
}
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
return ssrc, nil
}
}
}
type serverStreamFormat struct { type serverStreamFormat struct {
sm *serverStreamMedia sm *serverStreamMedia
format format.Format format format.Format
localSSRC uint32
localSSRC uint32
rtpSender *rtpsender.Sender rtpSender *rtpsender.Sender
rtpPacketsSent *uint64 rtpPacketsSent *uint64
} }
func (sf *serverStreamFormat) initialize() error { func (sf *serverStreamFormat) initialize() {
var err error
sf.localSSRC, err = serverStreamPickLocalSSRC(sf)
if err != nil {
return err
}
sf.rtpPacketsSent = new(uint64) sf.rtpPacketsSent = new(uint64)
sf.rtpSender = &rtpsender.Sender{ sf.rtpSender = &rtpsender.Sender{
@@ -76,8 +44,6 @@ func (sf *serverStreamFormat) initialize() error {
}, },
} }
sf.rtpSender.Initialize() sf.rtpSender.Initialize()
return nil
} }
func (sf *serverStreamFormat) close() { func (sf *serverStreamFormat) close() {

View File

@@ -1,7 +1,6 @@
package gortsplib package gortsplib
import ( import (
"crypto/rand"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
@@ -10,64 +9,33 @@ import (
) )
type serverStreamMedia struct { type serverStreamMedia struct {
st *ServerStream st *ServerStream
media *description.Media media *description.Media
trackID int trackID int
localSSRCs map[uint8]uint32
srtpOutCtx *wrappedSRTPContext
srtpOutCtx *wrappedSRTPContext
formats map[uint8]*serverStreamFormat formats map[uint8]*serverStreamFormat
multicastWriter *serverMulticastWriter multicastWriter *serverMulticastWriter
bytesSent *uint64 bytesSent *uint64
rtcpPacketsSent *uint64 rtcpPacketsSent *uint64
} }
func (sm *serverStreamMedia) initialize() error { func (sm *serverStreamMedia) initialize() {
sm.bytesSent = new(uint64) sm.bytesSent = new(uint64)
sm.rtcpPacketsSent = new(uint64) sm.rtcpPacketsSent = new(uint64)
sm.formats = make(map[uint8]*serverStreamFormat) sm.formats = make(map[uint8]*serverStreamFormat)
for i, forma := range sm.media.Formats { for _, forma := range sm.media.Formats {
sf := &serverStreamFormat{ sf := &serverStreamFormat{
sm: sm, sm: sm,
format: forma, format: forma,
localSSRC: sm.localSSRCs[forma.PayloadType()],
} }
err := sf.initialize() sf.initialize()
if err != nil {
for _, forma := range sm.media.Formats[:i] {
sm.formats[forma.PayloadType()].close()
}
return err
}
sm.formats[forma.PayloadType()] = sf sm.formats[forma.PayloadType()] = sf
} }
if sm.st.Server.TLSConfig != nil {
srtpOutKey := make([]byte, srtpKeyLength)
_, err := rand.Read(srtpOutKey)
if err != nil {
return err
}
ssrcs := make([]uint32, len(sm.formats))
n := 0
for _, cf := range sm.formats {
ssrcs[n] = cf.localSSRC
n++
}
sm.srtpOutCtx = &wrappedSRTPContext{
key: srtpOutKey,
ssrcs: ssrcs,
}
err = sm.srtpOutCtx.initialize()
if err != nil {
return err
}
}
return nil
} }
func (sm *serverStreamMedia) close() { func (sm *serverStreamMedia) close() {