mirror of
https://github.com/aler9/gortsplib
synced 2025-09-27 03:25:52 +08:00
optimize code (#878)
* remove unused code * initialize UDP listeners and SRTP before initializing medias * make rtpSender and rtpReceiver available before PLAY / RECORD * use writerMutex to protect writer only
This commit is contained in:
216
client.go
216
client.go
@@ -41,6 +41,47 @@ const (
|
||||
clientUserAgent = "gortsplib"
|
||||
)
|
||||
|
||||
func generateLocalSSRCs(existing []uint32, formats []format.Format) (map[uint8]uint32, error) {
|
||||
ret := make(map[uint8]uint32)
|
||||
|
||||
for _, forma := range formats {
|
||||
for {
|
||||
ssrc, err := randUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ssrc != 0 && !slices.Contains(existing, ssrc) {
|
||||
existing = append(existing, ssrc)
|
||||
ret[forma.PayloadType()] = ssrc
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func ssrcsMapToList(m map[uint8]uint32) []uint32 {
|
||||
ret := make([]uint32, len(m))
|
||||
n := 0
|
||||
for _, el := range m {
|
||||
ret[n] = el
|
||||
n++
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func clientExtractExistingSSRCs(setuppedMedias map[*description.Media]*clientMedia) []uint32 {
|
||||
var ret []uint32
|
||||
for _, media := range setuppedMedias {
|
||||
for _, forma := range media.formats {
|
||||
ret = append(ret, forma.localSSRC)
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// avoid an int64 overflow and preserve resolution by splitting division into two parts:
|
||||
// first add the integer part, then the decimal part.
|
||||
func multiplyAndDivide(v, m, d time.Duration) time.Duration {
|
||||
@@ -134,16 +175,16 @@ func announceDataPickLocalSSRC(
|
||||
am *clientAnnounceDataMedia,
|
||||
data map[*description.Media]*clientAnnounceDataMedia,
|
||||
) (uint32, error) {
|
||||
var takenSSRCs []uint32 //nolint:prealloc
|
||||
var existing []uint32 //nolint:prealloc
|
||||
|
||||
for _, am := range data {
|
||||
for _, af := range am.formats {
|
||||
takenSSRCs = append(takenSSRCs, af.localSSRC)
|
||||
existing = append(existing, af.localSSRC)
|
||||
}
|
||||
}
|
||||
|
||||
for _, af := range am.formats {
|
||||
takenSSRCs = append(takenSSRCs, af.localSSRC)
|
||||
existing = append(existing, af.localSSRC)
|
||||
}
|
||||
|
||||
for {
|
||||
@@ -152,7 +193,7 @@ func announceDataPickLocalSSRC(
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if ssrc != 0 && !slices.Contains(takenSSRCs, ssrc) {
|
||||
if ssrc != 0 && !slices.Contains(existing, ssrc) {
|
||||
return ssrc, nil
|
||||
}
|
||||
}
|
||||
@@ -214,9 +255,9 @@ func prepareForAnnounce(
|
||||
n++
|
||||
}
|
||||
|
||||
// use a dummy Context.
|
||||
// create a temporary Context.
|
||||
// Context is needed to extract ROC, but since client has not started streaming,
|
||||
// ROC is always zero, therefore a dummy Context can be used.
|
||||
// ROC is always zero, therefore a temporary Context can be used.
|
||||
srtpCtx := &wrappedSRTPContext{
|
||||
key: announceDataMedia.srtpOutKey,
|
||||
ssrcs: ssrcs,
|
||||
@@ -527,8 +568,8 @@ type Client struct {
|
||||
keepAlivePeriod time.Duration
|
||||
keepAliveTimer *time.Timer
|
||||
closeError error
|
||||
writer *asyncProcessor
|
||||
writerMutex sync.RWMutex
|
||||
writer *asyncProcessor
|
||||
reader *clientReader
|
||||
timeDecoder *rtptime.GlobalDecoder2
|
||||
mustClose bool
|
||||
@@ -1618,20 +1659,41 @@ func (c *Client) doSetup(
|
||||
}
|
||||
}
|
||||
|
||||
cm := &clientMedia{
|
||||
c: c,
|
||||
media: medi,
|
||||
secure: isSecure(th.Profile),
|
||||
}
|
||||
err = cm.initialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
var localSSRCs map[uint8]uint32
|
||||
|
||||
if c.state == clientStatePreRecord {
|
||||
localSSRCs = make(map[uint8]uint32)
|
||||
for forma, data := range c.announceData[medi].formats {
|
||||
localSSRCs[forma] = data.localSSRC
|
||||
}
|
||||
} else {
|
||||
localSSRCs, err = generateLocalSSRCs(
|
||||
clientExtractExistingSSRCs(c.setuppedMedias),
|
||||
medi.Formats,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var udpRTPListener *clientUDPListener
|
||||
var udpRTCPListener *clientUDPListener
|
||||
var tcpChannel int
|
||||
var srtpInCtx *wrappedSRTPContext
|
||||
var srtpOutCtx *wrappedSRTPContext
|
||||
|
||||
defer func() {
|
||||
if udpRTPListener != nil {
|
||||
udpRTPListener.close()
|
||||
}
|
||||
if udpRTCPListener != nil {
|
||||
udpRTCPListener.close()
|
||||
}
|
||||
}()
|
||||
|
||||
switch transport {
|
||||
case TransportUDP, TransportUDPMulticast:
|
||||
if c.Scheme == "rtsps" && !isSecure(th.Profile) {
|
||||
cm.close()
|
||||
return nil, fmt.Errorf("unable to setup secure UDP")
|
||||
}
|
||||
|
||||
@@ -1640,29 +1702,27 @@ func (c *Client) doSetup(
|
||||
if transport == TransportUDP {
|
||||
if (rtpPort == 0 && rtcpPort != 0) ||
|
||||
(rtpPort != 0 && rtcpPort == 0) {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientUDPPortsZero{}
|
||||
}
|
||||
|
||||
if rtpPort != 0 && rtcpPort != (rtpPort+1) {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientUDPPortsNotConsecutive{}
|
||||
}
|
||||
|
||||
err = cm.createUDPListeners(
|
||||
udpRTPListener, udpRTCPListener, err = createUDPListenerPair(
|
||||
c,
|
||||
false,
|
||||
nil,
|
||||
net.JoinHostPort("", strconv.FormatInt(int64(rtpPort), 10)),
|
||||
net.JoinHostPort("", strconv.FormatInt(int64(rtcpPort), 10)),
|
||||
)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
v1 := headers.TransportDeliveryUnicast
|
||||
th.Delivery = &v1
|
||||
th.ClientPorts = &[2]int{cm.udpRTPListener.port(), cm.udpRTCPListener.port()}
|
||||
th.ClientPorts = &[2]int{udpRTPListener.port(), udpRTCPListener.port()}
|
||||
} else {
|
||||
v1 := headers.TransportDeliveryMulticast
|
||||
th.Delivery = &v1
|
||||
@@ -1678,7 +1738,6 @@ func (c *Client) doSetup(
|
||||
|
||||
mediaURL, err := medi.URL(baseURL)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1688,7 +1747,6 @@ func (c *Client) doSetup(
|
||||
|
||||
if medi.IsBackChannel {
|
||||
if !c.RequestBackChannels {
|
||||
cm.close()
|
||||
return nil, fmt.Errorf("we are setupping a back channel but we did not request back channels")
|
||||
}
|
||||
|
||||
@@ -1696,17 +1754,30 @@ func (c *Client) doSetup(
|
||||
}
|
||||
|
||||
if isSecure(th.Profile) {
|
||||
ssrcs := make([]uint32, len(cm.formats))
|
||||
n := 0
|
||||
for _, cf := range cm.formats {
|
||||
ssrcs[n] = cf.localSSRC
|
||||
n++
|
||||
var srtpOutKey []byte
|
||||
|
||||
if c.state == clientStatePreRecord {
|
||||
srtpOutKey = c.announceData[medi].srtpOutKey
|
||||
} else {
|
||||
srtpOutKey = make([]byte, srtpKeyLength)
|
||||
_, err = rand.Read(srtpOutKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
srtpOutCtx = &wrappedSRTPContext{
|
||||
key: srtpOutKey,
|
||||
ssrcs: ssrcsMapToList(localSSRCs),
|
||||
}
|
||||
err = srtpOutCtx.initialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var mikeyMsg *mikey.Message
|
||||
mikeyMsg, err = mikeyGenerate(cm.srtpOutCtx)
|
||||
mikeyMsg, err = mikeyGenerate(srtpOutCtx)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1716,7 +1787,6 @@ func (c *Client) doSetup(
|
||||
MikeyMessage: mikeyMsg,
|
||||
}.Marshal()
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1729,13 +1799,10 @@ func (c *Client) doSetup(
|
||||
Header: header,
|
||||
}, false)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode != base.StatusOK {
|
||||
cm.close()
|
||||
|
||||
// switch transport automatically
|
||||
if res.StatusCode == base.StatusUnsupportedTransport &&
|
||||
c.setuppedTransport == nil && c.Transport == nil {
|
||||
@@ -1753,15 +1820,12 @@ func (c *Client) doSetup(
|
||||
var thRes headers.Transport
|
||||
err = thRes.Unmarshal(res.Header["Transport"])
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderInvalid{Err: err}
|
||||
}
|
||||
|
||||
switch transport {
|
||||
case TransportUDP, TransportUDPMulticast:
|
||||
if thRes.Protocol == headers.TransportProtocolTCP {
|
||||
cm.close()
|
||||
|
||||
// switch transport automatically
|
||||
if c.setuppedTransport == nil && c.Transport == nil {
|
||||
c.OnTransportSwitch(liberrors.ErrClientSwitchToTCP2{})
|
||||
@@ -1790,14 +1854,12 @@ func (c *Client) doSetup(
|
||||
switch transport {
|
||||
case TransportUDP:
|
||||
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
|
||||
}
|
||||
|
||||
serverPortsValid := thRes.ServerPorts != nil && !isAnyPort(thRes.ServerPorts[0]) && !isAnyPort(thRes.ServerPorts[1])
|
||||
|
||||
if (c.state == clientStatePreRecord || !c.AnyPortEnable) && !serverPortsValid {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientServerPortsNotProvided{}
|
||||
}
|
||||
|
||||
@@ -1810,41 +1872,38 @@ func (c *Client) doSetup(
|
||||
|
||||
if serverPortsValid {
|
||||
if !c.AnyPortEnable {
|
||||
cm.udpRTPListener.readPort = thRes.ServerPorts[0]
|
||||
udpRTPListener.readPort = thRes.ServerPorts[0]
|
||||
}
|
||||
cm.udpRTPListener.writeAddr = &net.UDPAddr{
|
||||
udpRTPListener.writeAddr = &net.UDPAddr{
|
||||
IP: remoteIP,
|
||||
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
|
||||
Port: thRes.ServerPorts[0],
|
||||
}
|
||||
}
|
||||
cm.udpRTPListener.readIP = remoteIP
|
||||
udpRTPListener.readIP = remoteIP
|
||||
|
||||
if serverPortsValid {
|
||||
if !c.AnyPortEnable {
|
||||
cm.udpRTCPListener.readPort = thRes.ServerPorts[1]
|
||||
udpRTCPListener.readPort = thRes.ServerPorts[1]
|
||||
}
|
||||
cm.udpRTCPListener.writeAddr = &net.UDPAddr{
|
||||
udpRTCPListener.writeAddr = &net.UDPAddr{
|
||||
IP: remoteIP,
|
||||
Zone: c.nconn.RemoteAddr().(*net.TCPAddr).Zone,
|
||||
Port: thRes.ServerPorts[1],
|
||||
}
|
||||
}
|
||||
cm.udpRTCPListener.readIP = remoteIP
|
||||
udpRTCPListener.readIP = remoteIP
|
||||
|
||||
case TransportUDPMulticast:
|
||||
if thRes.Delivery == nil || *thRes.Delivery != headers.TransportDeliveryMulticast {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
|
||||
}
|
||||
|
||||
if thRes.Ports == nil {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderNoPorts{}
|
||||
}
|
||||
|
||||
if thRes.Destination == nil {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderNoDestination{}
|
||||
}
|
||||
|
||||
@@ -1858,72 +1917,65 @@ func (c *Client) doSetup(
|
||||
var intf *net.Interface
|
||||
intf, err = interfaceOfConn(c.nconn)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = cm.createUDPListeners(
|
||||
udpRTPListener, udpRTCPListener, err = createUDPListenerPair(
|
||||
c,
|
||||
true,
|
||||
intf,
|
||||
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[0]), 10)),
|
||||
net.JoinHostPort(thRes.Destination.String(), strconv.FormatInt(int64(thRes.Ports[1]), 10)),
|
||||
)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cm.udpRTPListener.readIP = remoteIP
|
||||
cm.udpRTPListener.readPort = thRes.Ports[0]
|
||||
cm.udpRTPListener.writeAddr = &net.UDPAddr{
|
||||
udpRTPListener.readIP = remoteIP
|
||||
udpRTPListener.readPort = thRes.Ports[0]
|
||||
udpRTPListener.writeAddr = &net.UDPAddr{
|
||||
IP: remoteIP,
|
||||
Port: thRes.Ports[0],
|
||||
}
|
||||
|
||||
cm.udpRTCPListener.readIP = remoteIP
|
||||
cm.udpRTCPListener.readPort = thRes.Ports[1]
|
||||
cm.udpRTCPListener.writeAddr = &net.UDPAddr{
|
||||
udpRTCPListener.readIP = remoteIP
|
||||
udpRTCPListener.readPort = thRes.Ports[1]
|
||||
udpRTCPListener.writeAddr = &net.UDPAddr{
|
||||
IP: remoteIP,
|
||||
Port: thRes.Ports[1],
|
||||
}
|
||||
|
||||
case TransportTCP:
|
||||
if thRes.Protocol != headers.TransportProtocolTCP {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientServerRequestedUDP{}
|
||||
}
|
||||
|
||||
if thRes.Delivery != nil && *thRes.Delivery != headers.TransportDeliveryUnicast {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderInvalidDelivery{}
|
||||
}
|
||||
|
||||
if thRes.InterleavedIDs == nil {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderNoInterleavedIDs{}
|
||||
}
|
||||
|
||||
if (thRes.InterleavedIDs[0] + 1) != thRes.InterleavedIDs[1] {
|
||||
cm.close()
|
||||
return nil, liberrors.ErrClientTransportHeaderInvalidInterleavedIDs{}
|
||||
}
|
||||
|
||||
if c.isChannelPairInUse(thRes.InterleavedIDs[0]) {
|
||||
cm.close()
|
||||
return &base.Response{
|
||||
StatusCode: base.StatusBadRequest,
|
||||
}, liberrors.ErrClientTransportHeaderInterleavedIDsInUse{}
|
||||
}
|
||||
|
||||
cm.tcpChannel = thRes.InterleavedIDs[0]
|
||||
tcpChannel = thRes.InterleavedIDs[0]
|
||||
}
|
||||
|
||||
if thRes.Profile != th.Profile {
|
||||
cm.close()
|
||||
return nil, fmt.Errorf("returned profile does not match requested profile")
|
||||
}
|
||||
|
||||
if cm.secure {
|
||||
if isSecure(th.Profile) {
|
||||
var mikeyMsg *mikey.Message
|
||||
|
||||
// extract key-mgmt from (in order of priority):
|
||||
@@ -1935,7 +1987,6 @@ func (c *Client) doSetup(
|
||||
var keyMgmt headers.KeyMgmt
|
||||
err = keyMgmt.Unmarshal(res.Header["KeyMgmt"])
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
mikeyMsg = keyMgmt.MikeyMessage
|
||||
@@ -1950,13 +2001,28 @@ func (c *Client) doSetup(
|
||||
return nil, fmt.Errorf("server did not provide key-mgmt data in any supported way")
|
||||
}
|
||||
|
||||
cm.srtpInCtx, err = mikeyToContext(mikeyMsg)
|
||||
srtpInCtx, err = mikeyToContext(mikeyMsg)
|
||||
if err != nil {
|
||||
cm.close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cm := &clientMedia{
|
||||
c: c,
|
||||
media: medi,
|
||||
secure: isSecure(th.Profile),
|
||||
udpRTPListener: udpRTPListener,
|
||||
udpRTCPListener: udpRTCPListener,
|
||||
tcpChannel: tcpChannel,
|
||||
localSSRCs: localSSRCs,
|
||||
srtpInCtx: srtpInCtx,
|
||||
srtpOutCtx: srtpOutCtx,
|
||||
}
|
||||
cm.initialize()
|
||||
|
||||
udpRTPListener = nil
|
||||
udpRTCPListener = nil
|
||||
|
||||
if c.setuppedMedias == nil {
|
||||
c.setuppedMedias = make(map[*description.Media]*clientMedia)
|
||||
}
|
||||
@@ -2309,13 +2375,6 @@ func (c *Client) WritePacketRTPWithNTP(medi *description.Media, pkt *rtp.Packet,
|
||||
default:
|
||||
}
|
||||
|
||||
c.writerMutex.RLock()
|
||||
defer c.writerMutex.RUnlock()
|
||||
|
||||
if c.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cm := c.setuppedMedias[medi]
|
||||
cf := cm.formats[pkt.PayloadType]
|
||||
return cf.writePacketRTP(pkt, ntp)
|
||||
@@ -2329,13 +2388,6 @@ func (c *Client) WritePacketRTCP(medi *description.Media, pkt rtcp.Packet) error
|
||||
default:
|
||||
}
|
||||
|
||||
c.writerMutex.RLock()
|
||||
defer c.writerMutex.RUnlock()
|
||||
|
||||
if c.writer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cm := c.setuppedMedias[medi]
|
||||
return cm.writePacketRTCP(pkt)
|
||||
}
|
||||
|
Reference in New Issue
Block a user