From a2df9d83b31c5a537bd266736339baec97c90fd9 Mon Sep 17 00:00:00 2001 From: Alessandro Ros Date: Fri, 13 Dec 2024 20:55:50 +0100 Subject: [PATCH] client: fix BytesSent / BytesReceived computation (#612) (#654) When the TCP transport protocol is in use, BytesSent and BytesReceived were increased twice. --- client_media.go | 10 ---------- client_play_test.go | 5 +++++ client_record_test.go | 7 +++++++ client_udp_listener.go | 4 ++++ server_play_test.go | 18 +++++++++++++++--- server_record_test.go | 14 ++++++++++++-- 6 files changed, 43 insertions(+), 15 deletions(-) diff --git a/client_media.go b/client_media.go index 1390d073..dca30f95 100644 --- a/client_media.go +++ b/client_media.go @@ -153,24 +153,20 @@ func (cm *clientMedia) findFormatWithSSRC(ssrc uint32) *clientFormat { } func (cm *clientMedia) writePacketRTPInQueueUDP(payload []byte) { - atomic.AddUint64(cm.c.BytesSent, uint64(len(payload))) cm.udpRTPListener.write(payload) //nolint:errcheck } func (cm *clientMedia) writePacketRTCPInQueueUDP(payload []byte) { - atomic.AddUint64(cm.c.BytesSent, uint64(len(payload))) cm.udpRTCPListener.write(payload) //nolint:errcheck } func (cm *clientMedia) writePacketRTPInQueueTCP(payload []byte) { - atomic.AddUint64(cm.c.BytesSent, uint64(len(payload))) cm.tcpRTPFrame.Payload = payload cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.conn.WriteInterleavedFrame(cm.tcpRTPFrame, cm.tcpBuffer) //nolint:errcheck } func (cm *clientMedia) writePacketRTCPInQueueTCP(payload []byte) { - atomic.AddUint64(cm.c.BytesSent, uint64(len(payload))) cm.tcpRTCPFrame.Payload = payload cm.c.nconn.SetWriteDeadline(time.Now().Add(cm.c.WriteTimeout)) cm.c.conn.WriteInterleavedFrame(cm.tcpRTCPFrame, cm.tcpBuffer) //nolint:errcheck @@ -264,8 +260,6 @@ func (cm *clientMedia) readRTCPTCPRecord(payload []byte) bool { func (cm *clientMedia) readRTPUDPPlay(payload []byte) bool { plen := len(payload) - atomic.AddUint64(cm.c.BytesReceived, uint64(plen)) - if plen == (udpMaxPayloadSize + 1) { cm.c.OnDecodeError(liberrors.ErrClientRTPPacketTooBigUDP{}) return false @@ -293,8 +287,6 @@ func (cm *clientMedia) readRTCPUDPPlay(payload []byte) bool { now := cm.c.timeNow() plen := len(payload) - atomic.AddUint64(cm.c.BytesReceived, uint64(plen)) - if plen == (udpMaxPayloadSize + 1) { cm.c.OnDecodeError(liberrors.ErrClientRTCPPacketTooBigUDP{}) return false @@ -327,8 +319,6 @@ func (cm *clientMedia) readRTPUDPRecord(_ []byte) bool { func (cm *clientMedia) readRTCPUDPRecord(payload []byte) bool { plen := len(payload) - atomic.AddUint64(cm.c.BytesReceived, uint64(plen)) - if plen == (udpMaxPayloadSize + 1) { cm.c.OnDecodeError(liberrors.ErrClientRTCPPacketTooBigUDP{}) return false diff --git a/client_play_test.go b/client_play_test.go index 722617c9..5d896a73 100644 --- a/client_play_test.go +++ b/client_play_test.go @@ -544,6 +544,11 @@ func TestClientPlay(t *testing.T) { require.NoError(t, err) <-packetRecv + + require.Greater(t, atomic.LoadUint64(c.BytesSent), uint64(620)) + require.Less(t, atomic.LoadUint64(c.BytesSent), uint64(850)) + require.Greater(t, atomic.LoadUint64(c.BytesReceived), uint64(580)) + require.Less(t, atomic.LoadUint64(c.BytesReceived), uint64(650)) }) } } diff --git a/client_record_test.go b/client_record_test.go index 8fb35eca..df3e5276 100644 --- a/client_record_test.go +++ b/client_record_test.go @@ -7,6 +7,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "testing" "time" @@ -334,6 +335,12 @@ func TestClientRecordSerial(t *testing.T) { require.NoError(t, err) <-recvDone + + require.Greater(t, atomic.LoadUint64(c.BytesSent), uint64(730)) + require.Less(t, atomic.LoadUint64(c.BytesSent), uint64(760)) + require.Greater(t, atomic.LoadUint64(c.BytesReceived), uint64(180)) + require.Less(t, atomic.LoadUint64(c.BytesReceived), uint64(210)) + c.Close() <-done diff --git a/client_udp_listener.go b/client_udp_listener.go index 6c0fc575..9285faa0 100644 --- a/client_udp_listener.go +++ b/client_udp_listener.go @@ -173,6 +173,8 @@ func (u *clientUDPListener) run() { now := u.c.timeNow() atomic.StoreInt64(u.lastPacketTime, now.Unix()) + atomic.AddUint64(u.c.BytesReceived, uint64(n)) + if u.readFunc(buf[:n]) { createNewBuffer() } @@ -180,6 +182,8 @@ func (u *clientUDPListener) run() { } func (u *clientUDPListener) write(payload []byte) error { + atomic.AddUint64(u.c.BytesSent, uint64(len(payload))) + // no mutex is needed here since Write() has an internal lock. // https://github.com/golang/go/issues/27203#issuecomment-534386117 u.pc.SetWriteDeadline(time.Now().Add(u.c.WriteTimeout)) diff --git a/server_play_test.go b/server_play_test.go index fb07a70f..2bd5072a 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -588,9 +588,9 @@ func TestServerPlaySetupErrorSameUDPPortsAndIP(t *testing.T) { func TestServerPlay(t *testing.T) { for _, transport := range []string{ "udp", + "multicast", "tcp", "tls", - "multicast", } { t.Run(transport, func(t *testing.T) { var stream *ServerStream @@ -608,13 +608,25 @@ func TestServerPlay(t *testing.T) { onConnOpen: func(_ *ServerHandlerOnConnOpenCtx) { close(nconnOpened) }, - onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Greater(t, ctx.Conn.BytesSent(), uint64(810)) + require.Less(t, ctx.Conn.BytesSent(), uint64(1150)) + require.Greater(t, ctx.Conn.BytesReceived(), uint64(440)) + require.Less(t, ctx.Conn.BytesReceived(), uint64(660)) + close(nconnClosed) }, onSessionOpen: func(_ *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) }, - onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { + if transport != "multicast" { + require.Greater(t, ctx.Session.BytesSent(), uint64(50)) + require.Less(t, ctx.Session.BytesSent(), uint64(60)) + require.Greater(t, ctx.Session.BytesReceived(), uint64(15)) + require.Less(t, ctx.Session.BytesReceived(), uint64(25)) + } + close(sessionClosed) }, onDescribe: func(_ *ServerHandlerOnDescribeCtx) (*base.Response, *ServerStream, error) { diff --git a/server_record_test.go b/server_record_test.go index f325a470..088481bf 100644 --- a/server_record_test.go +++ b/server_record_test.go @@ -544,13 +544,23 @@ func TestServerRecord(t *testing.T) { onConnOpen: func(_ *ServerHandlerOnConnOpenCtx) { close(nconnOpened) }, - onConnClose: func(_ *ServerHandlerOnConnCloseCtx) { + onConnClose: func(ctx *ServerHandlerOnConnCloseCtx) { + require.Greater(t, ctx.Conn.BytesSent(), uint64(510)) + require.Less(t, ctx.Conn.BytesSent(), uint64(560)) + require.Greater(t, ctx.Conn.BytesReceived(), uint64(1000)) + require.Less(t, ctx.Conn.BytesReceived(), uint64(1200)) + close(nconnClosed) }, onSessionOpen: func(_ *ServerHandlerOnSessionOpenCtx) { close(sessionOpened) }, - onSessionClose: func(_ *ServerHandlerOnSessionCloseCtx) { + onSessionClose: func(ctx *ServerHandlerOnSessionCloseCtx) { + require.Greater(t, ctx.Session.BytesSent(), uint64(75)) + require.Less(t, ctx.Session.BytesSent(), uint64(130)) + require.Greater(t, ctx.Session.BytesReceived(), uint64(70)) + require.Less(t, ctx.Session.BytesReceived(), uint64(80)) + close(sessionClosed) }, onAnnounce: func(_ *ServerHandlerOnAnnounceCtx) (*base.Response, error) {