diff --git a/client.go b/client.go index a9252b33..bc9f43f2 100644 --- a/client.go +++ b/client.go @@ -41,6 +41,47 @@ const ( clientUserAgent = "gortsplib" ) +func generateLocalSSRCs(existing []uint32, formats []format.Format) (map[uint8]uint32, error) { + ret := make(map[uint8]uint32) + + for _, forma := range formats { + for { + ssrc, err := randUint32() + if err != nil { + return nil, err + } + + if ssrc != 0 && !slices.Contains(existing, ssrc) { + existing = append(existing, ssrc) + ret[forma.PayloadType()] = ssrc + break + } + } + } + + return ret, nil +} + +func ssrcsMapToList(m map[uint8]uint32) []uint32 { + ret := make([]uint32, len(m)) + n := 0 + for _, el := range m { + ret[n] = el + n++ + } + return ret +} + +func clientExtractExistingSSRCs(setuppedMedias map[*description.Media]*clientMedia) []uint32 { + var ret []uint32 + for _, media := range setuppedMedias { + for _, forma := range media.formats { + ret = append(ret, forma.localSSRC) + } + } + return ret +} + // avoid an int64 overflow and preserve resolution by splitting division into two parts: // first add the integer part, then the decimal part. func multiplyAndDivide(v, m, d time.Duration) time.Duration { @@ -134,16 +175,16 @@ func announceDataPickLocalSSRC( am *clientAnnounceDataMedia, data map[*description.Media]*clientAnnounceDataMedia, ) (uint32, error) { - var takenSSRCs []uint32 //nolint:prealloc + var existing []uint32 //nolint:prealloc for _, am := range data { for _, af := range am.formats { - takenSSRCs = append(takenSSRCs, af.localSSRC) + existing = append(existing, af.localSSRC) } } for _, af := range am.formats { - takenSSRCs = append(takenSSRCs, af.localSSRC) + existing = append(existing, af.localSSRC) } for { @@ -152,7 +193,7 @@ func announceDataPickLocalSSRC( return 0, err } - if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { + if ssrc != 0 && !slices.Contains(existing, ssrc) { return ssrc, nil } } @@ -214,9 +255,9 @@ func prepareForAnnounce( n++ } - // use a dummy Context. + // create a temporary Context. // Context is needed to extract ROC, but since client has not started streaming, - // ROC is always zero, therefore a dummy Context can be used. + // ROC is always zero, therefore a temporary Context can be used. srtpCtx := &wrappedSRTPContext{ key: announceDataMedia.srtpOutKey, ssrcs: ssrcs, @@ -527,8 +568,8 @@ type Client struct { keepAlivePeriod time.Duration keepAliveTimer *time.Timer closeError error - writer *asyncProcessor writerMutex sync.RWMutex + writer *asyncProcessor reader *clientReader timeDecoder *rtptime.GlobalDecoder2 mustClose bool @@ -1618,20 +1659,41 @@ func (c *Client) doSetup( } } - cm := &clientMedia{ - c: c, - media: medi, - secure: isSecure(th.Profile), - } - err = cm.initialize() - if err != nil { - return nil, err + var localSSRCs map[uint8]uint32 + + if c.state == clientStatePreRecord { + localSSRCs = make(map[uint8]uint32) + for forma, data := range c.announceData[medi].formats { + localSSRCs[forma] = data.localSSRC + } + } else { + localSSRCs, err = generateLocalSSRCs( + clientExtractExistingSSRCs(c.setuppedMedias), + medi.Formats, + ) + if err != nil { + return nil, err + } } + var udpRTPListener *clientUDPListener + var udpRTCPListener *clientUDPListener + var tcpChannel int + var srtpInCtx *wrappedSRTPContext + var srtpOutCtx *wrappedSRTPContext + + defer func() { + if udpRTPListener != nil { + udpRTPListener.close() + } + if udpRTCPListener != nil { + udpRTCPListener.close() + } + }() + switch transport { case TransportUDP, TransportUDPMulticast: if c.Scheme == "rtsps" && !isSecure(th.Profile) { - cm.close() return nil, fmt.Errorf("unable to setup secure UDP") } @@ -1640,29 +1702,27 @@ func (c *Client) doSetup( if transport == TransportUDP { if (rtpPort == 0 && rtcpPort != 0) || (rtpPort != 0 && rtcpPort == 0) { - cm.close() return nil, liberrors.ErrClientUDPPortsZero{} } if rtpPort != 0 && rtcpPort != (rtpPort+1) { - cm.close() return nil, liberrors.ErrClientUDPPortsNotConsecutive{} } - err = cm.createUDPListeners( + udpRTPListener, udpRTCPListener, err = createUDPListenerPair( + c, false, nil, net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), ) if err != nil { - cm.close() return nil, err } v1 := headers.TransportDeliveryUnicast th.Delivery = &v1 - th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()} + th.ClientPorts = &[2]int{udpRTPListener.port(), udpRTCPListener.port()} } else { v1 := headers.TransportDeliveryMulticast th.Delivery = &v1 @@ -1678,7 +1738,6 @@ func (c *Client) doSetup( mediaURL, err := medi.URL(baseURL) if err != nil { - cm.close() return nil, err } @@ -1688,7 +1747,6 @@ func (c *Client) doSetup( if medi.IsBackChannel { if !c.RequestBackChannels { - cm.close() return nil, fmt.Errorf("we are setupping a back channel but we did not request back channels") } @@ -1696,17 +1754,30 @@ func (c *Client) doSetup( } if isSecure(th.Profile) { - ssrcs := make([]uint32, len(cm.formats)) - n := 0 - for _, cf := range cm.formats { - ssrcs[n] = cf.localSSRC - n++ + var srtpOutKey []byte + + if c.state == clientStatePreRecord { + srtpOutKey = c.announceData[medi].srtpOutKey + } else { + srtpOutKey = make([]byte, srtpKeyLength) + _, err = rand.Read(srtpOutKey) + if err != nil { + return nil, err + } + } + + srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcsMapToList(localSSRCs), + } + err = srtpOutCtx.initialize() + if err != nil { + return nil, err } var mikeyMsg *mikey.Message - mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx) + mikeyMsg, err = mikeyGenerate(srtpOutCtx) if err != nil { - cm.close() return nil, err } @@ -1716,7 +1787,6 @@ func (c *Client) doSetup( MikeyMessage: mikeyMsg, }.Marshal() if err != nil { - cm.close() return nil, err } @@ -1729,13 +1799,10 @@ func (c *Client) doSetup( Header: header, }, false) if err != nil { - cm.close() return nil, err } if res.StatusCode != base.StatusOK { - cm.close() - // switch transport automatically if res.StatusCode == base.StatusUnsupportedTransport && c.setuppedTransport == nil && c.Transport == nil { @@ -1753,15 +1820,12 @@ func (c *Client) doSetup( var thRes headers.Transport err = thRes.Unmarshal(res.Header["Transport"]) if err != nil { - cm.close() return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err} } switch transport { case TransportUDP, TransportUDPMulticast: if thRes.Protocol == headers.TransportProtocolTCP { - cm.close() - // switch transport automatically if c.setuppedTransport == nil && c.Transport == nil { c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{}) @@ -1790,14 +1854,12 @@ func (c *Client) doSetup( switch transport { case TransportUDP: if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { - cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } serverPortsValid := thRes.ServerPorts != nil && !isAnyPort(thRes.ServerPorts[0]) && !isAnyPort(thRes.ServerPorts[1]) if (c.state == clientStatePreRecord || !c.AnyPortEnable) && !serverPortsValid { - cm.close() return nil, liberrors.ErrClientServerPortsNotProvided{} } @@ -1810,41 +1872,38 @@ func (c *Client) doSetup( if serverPortsValid { if !c.AnyPortEnable { - cm.udpRTPListener.readPort = thRes.ServerPorts[0] + udpRTPListener.readPort = thRes.ServerPorts[0] } - cm.udpRTPListener.writeAddr = &net.UDPAddr{ + udpRTPListener.writeAddr = &net.UDPAddr{ IP: remoteIP, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Port: thRes.ServerPorts[0], } } - cm.udpRTPListener.readIP = remoteIP + udpRTPListener.readIP = remoteIP if serverPortsValid { if !c.AnyPortEnable { - cm.udpRTCPListener.readPort = thRes.ServerPorts[1] + udpRTCPListener.readPort = thRes.ServerPorts[1] } - cm.udpRTCPListener.writeAddr = &net.UDPAddr{ + udpRTCPListener.writeAddr = &net.UDPAddr{ IP: remoteIP, Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone, Port: thRes.ServerPorts[1], } } - cm.udpRTCPListener.readIP = remoteIP + udpRTCPListener.readIP = remoteIP case TransportUDPMulticast: if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast { - cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } if thRes.Ports == nil { - cm.close() return nil, liberrors.ErrClientTransportHeaderNoPorts{} } if thRes.Destination == nil { - cm.close() return nil, liberrors.ErrClientTransportHeaderNoDestination{} } @@ -1858,72 +1917,65 @@ func (c *Client) doSetup( var intf *net.Interface intf, err = interfaceOfConn(c.nconn) if err != nil { - cm.close() return nil, err } - err = cm.createUDPListeners( + udpRTPListener, udpRTCPListener, err = createUDPListenerPair( + c, true, intf, net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)), net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)), ) if err != nil { - cm.close() return nil, err } - cm.udpRTPListener.readIP = remoteIP - cm.udpRTPListener.readPort = thRes.Ports[0] - cm.udpRTPListener.writeAddr = &net.UDPAddr{ + udpRTPListener.readIP = remoteIP + udpRTPListener.readPort = thRes.Ports[0] + udpRTPListener.writeAddr = &net.UDPAddr{ IP: remoteIP, Port: thRes.Ports[0], } - cm.udpRTCPListener.readIP = remoteIP - cm.udpRTCPListener.readPort = thRes.Ports[1] - cm.udpRTCPListener.writeAddr = &net.UDPAddr{ + udpRTCPListener.readIP = remoteIP + udpRTCPListener.readPort = thRes.Ports[1] + udpRTCPListener.writeAddr = &net.UDPAddr{ IP: remoteIP, Port: thRes.Ports[1], } case TransportTCP: if thRes.Protocol != headers.TransportProtocolTCP { - cm.close() return nil, liberrors.ErrClientServerRequestedUDP{} } if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast { - cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{} } if thRes.InterleavedIDs == nil { - cm.close() return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{} } if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] { - cm.close() return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{} } if c.isChannelPairInUse(thRes.InterleavedIDs[0]) { - cm.close() return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{} } - cm.tcpChannel = thRes.InterleavedIDs[0] + tcpChannel = thRes.InterleavedIDs[0] } if thRes.Profile != th.Profile { - cm.close() return nil, fmt.Errorf("returned profile does not match requested profile") } - if cm.secure { + if isSecure(th.Profile) { var mikeyMsg *mikey.Message // extract key-mgmt from (in order of priority): @@ -1935,7 +1987,6 @@ func (c *Client) doSetup( var keyMgmt headers.KeyMgmt err = keyMgmt.Unmarshal(res.Header["KeyMgmt"]) if err != nil { - cm.close() return nil, err } mikeyMsg = keyMgmt.MikeyMessage @@ -1950,13 +2001,28 @@ func (c *Client) doSetup( return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way") } - cm.srtpInCtx, err = mikeyToContext(mikeyMsg) + srtpInCtx, err = mikeyToContext(mikeyMsg) if err != nil { - cm.close() return nil, err } } + cm := &clientMedia{ + c: c, + media: medi, + secure: isSecure(th.Profile), + udpRTPListener: udpRTPListener, + udpRTCPListener: udpRTCPListener, + tcpChannel: tcpChannel, + localSSRCs: localSSRCs, + srtpInCtx: srtpInCtx, + srtpOutCtx: srtpOutCtx, + } + cm.initialize() + + udpRTPListener = nil + udpRTCPListener = nil + if c.setuppedMedias == nil { c.setuppedMedias = make(map[*description.Media]*clientMedia) } @@ -2309,13 +2375,6 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet, default: } - c.writerMutex.RLock() - defer c.writerMutex.RUnlock() - - if c.writer == nil { - return nil - } - cm := c.setuppedMedias[medi] cf := cm.formats[pkt.PayloadType] return cf.writePacketRTP(pkt, ntp) @@ -2329,13 +2388,6 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error default: } - c.writerMutex.RLock() - defer c.writerMutex.RUnlock() - - if c.writer == nil { - return nil - } - cm := c.setuppedMedias[medi] return cm.writePacketRTCP(pkt) } diff --git a/client_format.go b/client_format.go index befd91a3..068f298f 100644 --- a/client_format.go +++ b/client_format.go @@ -1,7 +1,6 @@ package gortsplib import ( - "slices" "sync/atomic" "time" @@ -14,37 +13,12 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/rtpsender" ) -func clientPickLocalSSRC(cf *clientFormat) (uint32, error) { - var takenSSRCs []uint32 //nolint:prealloc - - for _, cm := range cf.cm.c.setuppedMedias { - for _, cf := range cm.formats { - takenSSRCs = append(takenSSRCs, cf.localSSRC) - } - } - - for _, cf := range cf.cm.formats { - takenSSRCs = append(takenSSRCs, cf.localSSRC) - } - - for { - ssrc, err := randUint32() - if err != nil { - return 0, err - } - - if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { - return ssrc, nil - } - } -} - type clientFormat struct { cm *clientMedia format format.Format + localSSRC uint32 onPacketRTP OnPacketRTPFunc - localSSRC uint32 rtpReceiver *rtpreceiver.Receiver // play rtpSender *rtpsender.Sender // record or back channel writePacketRTPInQueue func([]byte) error @@ -53,32 +27,18 @@ type clientFormat struct { rtpPacketsLost *uint64 } -func (cf *clientFormat) initialize() error { - if cf.cm.c.state == clientStatePreRecord { - cf.localSSRC = cf.cm.c.announceData[cf.cm.media].formats[cf.format.PayloadType()].localSSRC - } else { - var err error - cf.localSSRC, err = clientPickLocalSSRC(cf) - if err != nil { - return err - } - } - +func (cf *clientFormat) initialize() { cf.rtpPacketsReceived = new(uint64) cf.rtpPacketsSent = new(uint64) cf.rtpPacketsLost = new(uint64) - return nil -} - -func (cf *clientFormat) start() { if cf.cm.udpRTPListener != nil { cf.writePacketRTPInQueue = cf.writePacketRTPInQueueUDP } else { cf.writePacketRTPInQueue = cf.writePacketRTPInQueueTCP } - if cf.cm.c.state == clientStateRecord || cf.cm.media.IsBackChannel { + if cf.cm.c.state == clientStatePreRecord || cf.cm.media.IsBackChannel { cf.rtpSender = &rtpsender.Sender{ ClockRate: cf.format.ClockRate(), Period: cf.cm.c.senderReportPeriod, @@ -110,7 +70,7 @@ func (cf *clientFormat) start() { } } -func (cf *clientFormat) stop() { +func (cf *clientFormat) close() { if cf.rtpReceiver != nil { cf.rtpReceiver.Close() cf.rtpReceiver = nil @@ -178,6 +138,13 @@ func (cf *clientFormat) writePacketRTP(pkt *rtp.Packet, ntp time.Time) error { buf = encr } + cf.cm.c.writerMutex.RLock() + defer cf.cm.c.writerMutex.RUnlock() + + if cf.cm.c.writer == nil { + return nil + } + ok := cf.cm.c.writer.push(func() error { return cf.writePacketRTPInQueue(buf) }) diff --git a/client_media.go b/client_media.go index 033146d1..c39a162d 100644 --- a/client_media.go +++ b/client_media.go @@ -1,7 +1,6 @@ package gortsplib import ( - "crypto/rand" "fmt" "net" "strconv" @@ -15,18 +14,87 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) -type clientMedia struct { - c *Client - media *description.Media - secure bool +func createUDPListenerPair( + c *Client, + multicast bool, + multicastInterface *net.Interface, + rtpAddress string, + rtcpAddress string, +) (*clientUDPListener, *clientUDPListener, error) { + if rtpAddress != ":0" { + l1 := &clientUDPListener{ + c: c, + multicast: multicast, + multicastInterface: multicastInterface, + address: rtpAddress, + } + err := l1.initialize() + if err != nil { + return nil, nil, err + } + + l2 := &clientUDPListener{ + c: c, + multicast: multicast, + multicastInterface: multicastInterface, + address: rtcpAddress, + } + err = l2.initialize() + if err != nil { + l1.close() + return nil, nil, err + } + + return l1, l2, nil + } + + // pick two consecutive ports in range 65535-10000 + // RTP port must be even and RTCP port odd + for { + v, err := randInRange((65535 - 10000) / 2) + if err != nil { + return nil, nil, err + } + + rtpPort := v*2 + 10000 + rtcpPort := rtpPort + 1 + + l1 := &clientUDPListener{ + c: c, + address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), + } + err = l1.initialize() + if err != nil { + continue + } + + l2 := &clientUDPListener{ + c: c, + address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), + } + err = l2.initialize() + if err != nil { + l2.close() + continue + } + + return l1, l2, nil + } +} + +type clientMedia struct { + c *Client + media *description.Media + secure bool + udpRTPListener *clientUDPListener + udpRTCPListener *clientUDPListener + tcpChannel int + localSSRCs map[uint8]uint32 + srtpInCtx *wrappedSRTPContext + srtpOutCtx *wrappedSRTPContext - srtpOutCtx *wrappedSRTPContext - srtpInCtx *wrappedSRTPContext onPacketRTCP OnPacketRTCPFunc formats map[uint8]*clientFormat - tcpChannel int - udpRTPListener *clientUDPListener - udpRTCPListener *clientUDPListener writePacketRTCPInQueue func([]byte) error bytesReceived *uint64 bytesSent *uint64 @@ -36,7 +104,7 @@ type clientMedia struct { rtcpPacketsInError *uint64 } -func (cm *clientMedia) initialize() error { +func (cm *clientMedia) initialize() { cm.onPacketRTCP = func(rtcp.Packet) {} cm.bytesReceived = new(uint64) cm.bytesSent = new(uint64) @@ -51,123 +119,19 @@ func (cm *clientMedia) initialize() error { f := &clientFormat{ cm: cm, format: forma, + localSSRC: cm.localSSRCs[forma.PayloadType()], onPacketRTP: func(*rtp.Packet) {}, } - err := f.initialize() - if err != nil { - return err - } - + f.initialize() cm.formats[forma.PayloadType()] = f } - - if cm.secure { - var srtpOutKey []byte - if cm.c.state == clientStatePreRecord { - srtpOutKey = cm.c.announceData[cm.media].srtpOutKey - } else { - srtpOutKey = make([]byte, srtpKeyLength) - _, err := rand.Read(srtpOutKey) - if err != nil { - return err - } - } - - ssrcs := make([]uint32, len(cm.formats)) - n := 0 - for _, cf := range cm.formats { - ssrcs[n] = cf.localSSRC - n++ - } - - cm.srtpOutCtx = &wrappedSRTPContext{ - key: srtpOutKey, - ssrcs: ssrcs, - } - err := cm.srtpOutCtx.initialize() - if err != nil { - return err - } - } - - return nil } func (cm *clientMedia) close() { - if cm.udpRTPListener != nil { - cm.udpRTPListener.close() - cm.udpRTCPListener.close() - } -} + cm.stop() -func (cm *clientMedia) createUDPListeners( - multicast bool, - multicastInterface *net.Interface, - rtpAddress string, - rtcpAddress string, -) error { - if rtpAddress != ":0" { - l1 := &clientUDPListener{ - c: cm.c, - multicast: multicast, - multicastInterface: multicastInterface, - address: rtpAddress, - } - err := l1.initialize() - if err != nil { - return err - } - - l2 := &clientUDPListener{ - c: cm.c, - multicast: multicast, - multicastInterface: multicastInterface, - address: rtcpAddress, - } - err = l2.initialize() - if err != nil { - l1.close() - return err - } - - cm.udpRTPListener, cm.udpRTCPListener = l1, l2 - return nil - } - - // pick two consecutive ports in range 65535-10000 - // RTP port must be even and RTCP port odd - for { - v, err := randInRange((65535 - 10000) / 2) - if err != nil { - return err - } - - rtpPort := v*2 + 10000 - rtcpPort := rtpPort + 1 - - cm.udpRTPListener = &clientUDPListener{ - c: cm.c, - address: net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)), - } - err = cm.udpRTPListener.initialize() - if err != nil { - cm.udpRTPListener = nil - continue - } - - cm.udpRTCPListener = &clientUDPListener{ - c: cm.c, - address: net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)), - } - err = cm.udpRTCPListener.initialize() - if err != nil { - cm.udpRTPListener.close() - cm.udpRTPListener = nil - cm.udpRTCPListener = nil - continue - } - - return nil + for _, ct := range cm.formats { + ct.close() } } @@ -198,10 +162,6 @@ func (cm *clientMedia) start() { } } - for _, ct := range cm.formats { - ct.start() - } - if cm.udpRTPListener != nil { cm.udpRTPListener.start() cm.udpRTCPListener.start() @@ -213,10 +173,6 @@ func (cm *clientMedia) stop() { cm.udpRTPListener.stop() cm.udpRTCPListener.stop() } - - for _, ct := range cm.formats { - ct.stop() - } } func (cm *clientMedia) findFormatByRemoteSSRC(ssrc uint32) *clientFormat { @@ -460,6 +416,13 @@ func (cm *clientMedia) writePacketRTCP(pkt rtcp.Packet) error { buf = encr } + cm.c.writerMutex.RLock() + defer cm.c.writerMutex.RUnlock() + + if cm.c.writer == nil { + return nil + } + ok := cm.c.writer.push(func() error { return cm.writePacketRTCPInQueue(buf) }) diff --git a/server_session.go b/server_session.go index dc3c2f30..68b48e46 100644 --- a/server_session.go +++ b/server_session.go @@ -3,6 +3,7 @@ package gortsplib import ( "bytes" "context" + "crypto/rand" "fmt" "log" "net" @@ -31,6 +32,16 @@ import ( type readFunc func([]byte) bool +func serverSessionExtractExistingSSRCs(medias map[*description.Media]*serverSessionMedia) []uint32 { + var ret []uint32 + for _, media := range medias { + for _, forma := range media.formats { + ret = append(ret, forma.localSSRC) + } + } + return ret +} + func isSecure(profile headers.TransportProfile) bool { return profile == headers.TransportProfileSAVP } @@ -433,8 +444,8 @@ type ServerSession struct { announcedDesc *description.Session // record udpLastPacketTime *int64 // record udpCheckStreamTimer *time.Timer - writer *asyncProcessor writerMutex sync.RWMutex + writer *asyncProcessor timeDecoder *rtptime.GlobalDecoder2 tcpFrame *base.InterleavedFrame tcpBuffer []byte @@ -818,7 +829,7 @@ func (ss *ServerSession) run() { } for _, sm := range ss.setuppedMedias { - sm.stop() + sm.close() } if ss.writer != nil { @@ -1290,37 +1301,76 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( res.Header = make(base.Header) } - sm := &serverSessionMedia{ - ss: ss, - media: medi, - srtpInCtx: srtpInCtx, - onPacketRTCP: func(_ rtcp.Packet) {}, + var localSSRCs map[uint8]uint32 + + if ss.state == ServerSessionStatePreRecord || medi.IsBackChannel { + localSSRCs, err = generateLocalSSRCs( + serverSessionExtractExistingSSRCs(ss.setuppedMedias), + medi.Formats, + ) + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + } else { + localSSRCs = make(map[uint8]uint32) + for forma, data := range ss.setuppedStream.medias[medi].formats { + localSSRCs[forma] = data.localSSRC + } } - err = sm.initialize() - if err != nil { - return &base.Response{ - StatusCode: base.StatusInternalServerError, - }, err + + var srtpOutCtx *wrappedSRTPContext + + if ss.s.TLSConfig != nil { + if ss.state == ServerSessionStatePreRecord || medi.IsBackChannel { + srtpOutKey := make([]byte, srtpKeyLength) + _, err = rand.Read(srtpOutKey) + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + + srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcsMapToList(localSSRCs), + } + err = srtpOutCtx.initialize() + if err != nil { + return &base.Response{ + StatusCode: base.StatusInternalServerError, + }, err + } + } else { + srtpOutCtx = ss.setuppedStream.medias[medi].srtpOutCtx + } } + var udpRTPReadPort int + var udpRTPWriteAddr *net.UDPAddr + var udpRTCPReadPort int + var udpRTCPWriteAddr *net.UDPAddr + var tcpChannel int + switch transport { case TransportUDP, TransportUDPMulticast: th.Protocol = headers.TransportProtocolUDP if transport == TransportUDP { - sm.udpRTPReadPort = inTH.ClientPorts[0] - sm.udpRTCPReadPort = inTH.ClientPorts[1] + udpRTPReadPort = inTH.ClientPorts[0] + udpRTCPReadPort = inTH.ClientPorts[1] - sm.udpRTPWriteAddr = &net.UDPAddr{ + udpRTPWriteAddr = &net.UDPAddr{ IP: ss.author.ip(), Zone: ss.author.zone(), - Port: sm.udpRTPReadPort, + Port: udpRTPReadPort, } - sm.udpRTCPWriteAddr = &net.UDPAddr{ + udpRTCPWriteAddr = &net.UDPAddr{ IP: ss.author.ip(), Zone: ss.author.zone(), - Port: sm.udpRTCPReadPort, + Port: udpRTCPReadPort, } de := headers.TransportDeliveryUnicast @@ -1341,32 +1391,41 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( th.Protocol = headers.TransportProtocolTCP if inTH.InterleavedIDs != nil { - sm.tcpChannel = inTH.InterleavedIDs[0] + tcpChannel = inTH.InterleavedIDs[0] } else { - sm.tcpChannel = ss.findFreeChannelPair() + tcpChannel = ss.findFreeChannelPair() } de := headers.TransportDeliveryUnicast th.Delivery = &de - th.InterleavedIDs = &[2]int{sm.tcpChannel, sm.tcpChannel + 1} + th.InterleavedIDs = &[2]int{tcpChannel, tcpChannel + 1} } + sm := &serverSessionMedia{ + ss: ss, + media: medi, + localSSRCs: localSSRCs, + srtpInCtx: srtpInCtx, + srtpOutCtx: srtpOutCtx, + udpRTPReadPort: udpRTPReadPort, + udpRTPWriteAddr: udpRTPWriteAddr, + udpRTCPReadPort: udpRTCPReadPort, + udpRTCPWriteAddr: udpRTCPWriteAddr, + tcpChannel: tcpChannel, + onPacketRTCP: func(_ rtcp.Packet) {}, + } + sm.initialize() + if ss.setuppedMedias == nil { ss.setuppedMedias = make(map[*description.Media]*serverSessionMedia) } + ss.setuppedMedias[medi] = sm ss.setuppedMediasOrdered = append(ss.setuppedMediasOrdered, sm) res.Header["Transport"] = th.Marshal() if isSecure(inTH.Profile) { - ssrcs := make([]uint32, len(sm.formats)) - n := 0 - for _, sf := range sm.formats { - ssrcs[n] = sf.localSSRC - n++ - } - var mk *mikey.Message mk, err = mikeyGenerate(sm.srtpOutCtx) if err != nil { diff --git a/server_session_format.go b/server_session_format.go index dcaf8b18..e581bba3 100644 --- a/server_session_format.go +++ b/server_session_format.go @@ -2,7 +2,6 @@ package gortsplib import ( "log" - "slices" "sync/atomic" "time" @@ -14,37 +13,12 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/rtpreceiver" ) -func serverSessionPickLocalSSRC(sf *serverSessionFormat) (uint32, error) { - var takenSSRCs []uint32 //nolint:prealloc - - for _, sm := range sf.sm.ss.setuppedMedias { - for _, sf := range sm.formats { - takenSSRCs = append(takenSSRCs, sf.localSSRC) - } - } - - for _, sf := range sf.sm.formats { - takenSSRCs = append(takenSSRCs, sf.localSSRC) - } - - for { - ssrc, err := randUint32() - if err != nil { - return 0, err - } - - if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { - return ssrc, nil - } - } -} - type serverSessionFormat struct { sm *serverSessionMedia format format.Format + localSSRC uint32 onPacketRTP OnPacketRTPFunc - localSSRC uint32 rtpReceiver *rtpreceiver.Receiver writePacketRTPInQueue func([]byte) error rtpPacketsReceived *uint64 @@ -52,25 +26,11 @@ type serverSessionFormat struct { rtpPacketsLost *uint64 } -func (sf *serverSessionFormat) initialize() error { - if sf.sm.ss.state == ServerSessionStatePreRecord || sf.sm.media.IsBackChannel { - var err error - sf.localSSRC, err = serverSessionPickLocalSSRC(sf) - if err != nil { - return err - } - } else { - sf.localSSRC = sf.sm.ss.setuppedStream.medias[sf.sm.media].formats[sf.format.PayloadType()].localSSRC - } - +func (sf *serverSessionFormat) initialize() { sf.rtpPacketsReceived = new(uint64) sf.rtpPacketsSent = new(uint64) sf.rtpPacketsLost = new(uint64) - return nil -} - -func (sf *serverSessionFormat) start() { udp := *sf.sm.ss.setuppedTransport == TransportUDP || *sf.sm.ss.setuppedTransport == TransportUDPMulticast if udp { @@ -79,7 +39,7 @@ func (sf *serverSessionFormat) start() { sf.writePacketRTPInQueue = sf.writePacketRTPInQueueTCP } - if sf.sm.ss.state == ServerSessionStateRecord || sf.sm.media.IsBackChannel { + if sf.sm.ss.state == ServerSessionStatePreRecord || sf.sm.media.IsBackChannel { sf.rtpReceiver = &rtpreceiver.Receiver{ ClockRate: sf.format.ClockRate(), LocalSSRC: &sf.localSSRC, @@ -99,7 +59,7 @@ func (sf *serverSessionFormat) start() { } } -func (sf *serverSessionFormat) stop() { +func (sf *serverSessionFormat) close() { if sf.rtpReceiver != nil { sf.rtpReceiver.Close() sf.rtpReceiver = nil diff --git a/server_session_media.go b/server_session_media.go index 7b78f814..fe0c2355 100644 --- a/server_session_media.go +++ b/server_session_media.go @@ -1,7 +1,6 @@ package gortsplib import ( - "crypto/rand" "fmt" "log" "net" @@ -16,17 +15,18 @@ import ( ) type serverSessionMedia struct { - ss *ServerSession - media *description.Media - srtpInCtx *wrappedSRTPContext - onPacketRTCP OnPacketRTCPFunc + ss *ServerSession + media *description.Media + localSSRCs map[uint8]uint32 + srtpInCtx *wrappedSRTPContext + srtpOutCtx *wrappedSRTPContext + udpRTPReadPort int + udpRTPWriteAddr *net.UDPAddr + udpRTCPReadPort int + udpRTCPWriteAddr *net.UDPAddr + tcpChannel int + onPacketRTCP OnPacketRTCPFunc - srtpOutCtx *wrappedSRTPContext - tcpChannel int - udpRTPReadPort int - udpRTPWriteAddr *net.UDPAddr - udpRTCPReadPort int - udpRTCPWriteAddr *net.UDPAddr formats map[uint8]*serverSessionFormat // record only writePacketRTCPInQueue func([]byte) error bytesReceived *uint64 @@ -37,7 +37,7 @@ type serverSessionMedia struct { rtcpPacketsInError *uint64 } -func (sm *serverSessionMedia) initialize() error { +func (sm *serverSessionMedia) initialize() { sm.bytesReceived = new(uint64) sm.bytesSent = new(uint64) sm.rtpPacketsInError = new(uint64) @@ -51,54 +51,23 @@ func (sm *serverSessionMedia) initialize() error { f := &serverSessionFormat{ sm: sm, format: forma, + localSSRC: sm.localSSRCs[forma.PayloadType()], onPacketRTP: func(*rtp.Packet) {}, } - err := f.initialize() - if err != nil { - return err - } + f.initialize() sm.formats[forma.PayloadType()] = f } +} - if sm.ss.s.TLSConfig != nil { - if sm.ss.state == ServerSessionStatePreRecord || sm.media.IsBackChannel { - srtpOutKey := make([]byte, srtpKeyLength) - _, err := rand.Read(srtpOutKey) - if err != nil { - return err - } +func (sm *serverSessionMedia) close() { + sm.stop() - ssrcs := make([]uint32, len(sm.formats)) - n := 0 - for _, cf := range sm.formats { - ssrcs[n] = cf.localSSRC - n++ - } - - sm.srtpOutCtx = &wrappedSRTPContext{ - key: srtpOutKey, - ssrcs: ssrcs, - } - err = sm.srtpOutCtx.initialize() - if err != nil { - return err - } - } else { - streamMedia := sm.ss.setuppedStream.medias[sm.media] - sm.srtpOutCtx = streamMedia.srtpOutCtx - } + for _, forma := range sm.formats { + forma.close() } - - return nil } func (sm *serverSessionMedia) start() error { - // allocate udpRTCPReceiver before udpRTCPListener - // otherwise udpRTCPReceiver.LastSSRC() cannot be called. - for _, sf := range sm.formats { - sf.start() - } - switch *sm.ss.setuppedTransport { case TransportUDP, TransportUDPMulticast: sm.writePacketRTCPInQueue = sm.writePacketRTCPInQueueUDP @@ -168,10 +137,6 @@ func (sm *serverSessionMedia) stop() { sm.ss.s.udpRTPListener.removeClient(sm.ss.author.ip(), sm.udpRTPReadPort) sm.ss.s.udpRTCPListener.removeClient(sm.ss.author.ip(), sm.udpRTCPReadPort) } - - for _, sf := range sm.formats { - sf.stop() - } } func (sm *serverSessionMedia) findFormatByRemoteSSRC(ssrc uint32) *serverSessionFormat { diff --git a/server_stream.go b/server_stream.go index 3f3c7171..f1549044 100644 --- a/server_stream.go +++ b/server_stream.go @@ -1,6 +1,7 @@ package gortsplib import ( + "crypto/rand" "fmt" "sync" "sync/atomic" @@ -14,6 +15,16 @@ import ( "github.com/bluenviron/gortsplib/v4/pkg/liberrors" ) +func serverStreamExtractExistingSSRCs(medias map[*description.Media]*serverStreamMedia) []uint32 { + var ret []uint32 + for _, media := range medias { + for _, forma := range media.formats { + ret = append(ret, forma.localSSRC) + } + } + return ret +} + // NewServerStream allocates a ServerStream. // // Deprecated: replaced by ServerStream.Initialize(). @@ -56,20 +67,53 @@ func (st *ServerStream) Initialize() error { st.activeUnicastReaders = make(map[*ServerSession]struct{}) st.medias = make(map[*description.Media]*serverStreamMedia, len(st.Desc.Medias)) + for i, medi := range st.Desc.Medias { - sm := &serverStreamMedia{ - st: st, - media: medi, - trackID: i, - } - err := sm.initialize() + localSSRCs, err := generateLocalSSRCs( + serverStreamExtractExistingSSRCs(st.medias), + medi.Formats, + ) if err != nil { - for _, medi := range st.Desc.Medias[:i] { - st.medias[medi].close() + for _, sm := range st.medias { + sm.close() } return err } + var srtpOutCtx *wrappedSRTPContext + + if st.Server.TLSConfig != nil { + srtpOutKey := make([]byte, srtpKeyLength) + _, err = rand.Read(srtpOutKey) + if err != nil { + for _, sm := range st.medias { + sm.close() + } + return err + } + + srtpOutCtx = &wrappedSRTPContext{ + key: srtpOutKey, + ssrcs: ssrcsMapToList(localSSRCs), + } + err = srtpOutCtx.initialize() + if err != nil { + for _, sm := range st.medias { + sm.close() + } + return err + } + } + + sm := &serverStreamMedia{ + st: st, + media: medi, + trackID: i, + localSSRCs: localSSRCs, + srtpOutCtx: srtpOutCtx, + } + sm.initialize() + st.medias[medi] = sm } diff --git a/server_stream_format.go b/server_stream_format.go index 194bc9b7..eb14cc62 100644 --- a/server_stream_format.go +++ b/server_stream_format.go @@ -2,7 +2,6 @@ package gortsplib import ( "crypto/rand" - "slices" "sync/atomic" "time" @@ -22,47 +21,16 @@ func randUint32() (uint32, error) { return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3]), nil } -func serverStreamPickLocalSSRC(sf *serverStreamFormat) (uint32, error) { - var takenSSRCs []uint32 //nolint:prealloc - - for _, sm := range sf.sm.st.medias { - for _, sf := range sm.formats { - takenSSRCs = append(takenSSRCs, sf.localSSRC) - } - } - - for _, sf := range sf.sm.formats { - takenSSRCs = append(takenSSRCs, sf.localSSRC) - } - - for { - ssrc, err := randUint32() - if err != nil { - return 0, err - } - - if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) { - return ssrc, nil - } - } -} - type serverStreamFormat struct { - sm *serverStreamMedia - format format.Format + sm *serverStreamMedia + format format.Format + localSSRC uint32 - localSSRC uint32 rtpSender *rtpsender.Sender rtpPacketsSent *uint64 } -func (sf *serverStreamFormat) initialize() error { - var err error - sf.localSSRC, err = serverStreamPickLocalSSRC(sf) - if err != nil { - return err - } - +func (sf *serverStreamFormat) initialize() { sf.rtpPacketsSent = new(uint64) sf.rtpSender = &rtpsender.Sender{ @@ -76,8 +44,6 @@ func (sf *serverStreamFormat) initialize() error { }, } sf.rtpSender.Initialize() - - return nil } func (sf *serverStreamFormat) close() { diff --git a/server_stream_media.go b/server_stream_media.go index ae4bef16..80058985 100644 --- a/server_stream_media.go +++ b/server_stream_media.go @@ -1,7 +1,6 @@ package gortsplib import ( - "crypto/rand" "fmt" "sync/atomic" @@ -10,64 +9,33 @@ import ( ) type serverStreamMedia struct { - st *ServerStream - media *description.Media - trackID int + st *ServerStream + media *description.Media + trackID int + localSSRCs map[uint8]uint32 + srtpOutCtx *wrappedSRTPContext - srtpOutCtx *wrappedSRTPContext formats map[uint8]*serverStreamFormat multicastWriter *serverMulticastWriter bytesSent *uint64 rtcpPacketsSent *uint64 } -func (sm *serverStreamMedia) initialize() error { +func (sm *serverStreamMedia) initialize() { sm.bytesSent = new(uint64) sm.rtcpPacketsSent = new(uint64) sm.formats = make(map[uint8]*serverStreamFormat) - for i, forma := range sm.media.Formats { + for _, forma := range sm.media.Formats { sf := &serverStreamFormat{ - sm: sm, - format: forma, + sm: sm, + format: forma, + localSSRC: sm.localSSRCs[forma.PayloadType()], } - err := sf.initialize() - if err != nil { - for _, forma := range sm.media.Formats[:i] { - sm.formats[forma.PayloadType()].close() - } - return err - } - + sf.initialize() sm.formats[forma.PayloadType()] = sf } - - if sm.st.Server.TLSConfig != nil { - srtpOutKey := make([]byte, srtpKeyLength) - _, err := rand.Read(srtpOutKey) - if err != nil { - return err - } - - ssrcs := make([]uint32, len(sm.formats)) - n := 0 - for _, cf := range sm.formats { - ssrcs[n] = cf.localSSRC - n++ - } - - sm.srtpOutCtx = &wrappedSRTPContext{ - key: srtpOutKey, - ssrcs: ssrcs, - } - err = sm.srtpOutCtx.initialize() - if err != nil { - return err - } - } - - return nil } func (sm *serverStreamMedia) close() {