diff --git a/pkg/bytecounter/bytecounter.go b/pkg/bytecounter/bytecounter.go new file mode 100644 index 00000000..88f2826e --- /dev/null +++ b/pkg/bytecounter/bytecounter.go @@ -0,0 +1,45 @@ +// Package bytecounter contains a io.ReadWriter wrapper that allows to count read and written bytes. +package bytecounter + +import ( + "io" + "sync/atomic" +) + +// ByteCounter is a io.ReadWriter wrapper that allows to count read and written bytes. +type ByteCounter struct { + rw io.ReadWriter + read uint64 + written uint64 +} + +// New allocates a ByteCounter. +func New(rw io.ReadWriter) *ByteCounter { + return &ByteCounter{ + rw: rw, + } +} + +// Read implements io.ReadWriter. +func (bc *ByteCounter) Read(p []byte) (int, error) { + n, err := bc.rw.Read(p) + atomic.AddUint64(&bc.read, uint64(n)) + return n, err +} + +// Write implements io.ReadWriter. +func (bc *ByteCounter) Write(p []byte) (int, error) { + n, err := bc.rw.Write(p) + atomic.AddUint64(&bc.written, uint64(n)) + return n, err +} + +// ReadBytes returns the number of read bytes. +func (bc *ByteCounter) ReadBytes() uint64 { + return atomic.LoadUint64(&bc.read) +} + +// WrittenBytes returns the number of written bytes. +func (bc *ByteCounter) WrittenBytes() uint64 { + return atomic.LoadUint64(&bc.written) +} diff --git a/pkg/bytecounter/bytecounter_test.go b/pkg/bytecounter/bytecounter_test.go new file mode 100644 index 00000000..4f3e9137 --- /dev/null +++ b/pkg/bytecounter/bytecounter_test.go @@ -0,0 +1,20 @@ +package bytecounter + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestByteCounter(t *testing.T) { + bc := New(bytes.NewBuffer(nil)) + + bc.Write([]byte{0x01, 0x02, 0x03, 0x04}) + + buf := make([]byte, 2) + bc.Read(buf) + + require.Equal(t, uint64(4), bc.WrittenBytes()) + require.Equal(t, uint64(2), bc.ReadBytes()) +} diff --git a/pkg/conn/conn.go b/pkg/conn/conn.go index 75bace64..fb538ed8 100644 --- a/pkg/conn/conn.go +++ b/pkg/conn/conn.go @@ -1,4 +1,4 @@ -// Package conn contains a RTSP TCP connection implementation. +// Package conn contains a RTSP connection implementation. package conn import ( @@ -12,7 +12,7 @@ const ( readBufferSize = 4096 ) -// Conn is a RTSP TCP connection. +// Conn is a RTSP connection. type Conn struct { w io.Writer br *bufio.Reader diff --git a/serverconn.go b/serverconn.go index 89fe719a..a563f787 100644 --- a/serverconn.go +++ b/serverconn.go @@ -8,11 +8,13 @@ import ( "net" gourl "net/url" "strings" + "sync/atomic" "time" "github.com/pion/rtcp" "github.com/aler9/gortsplib/pkg/base" + "github.com/aler9/gortsplib/pkg/bytecounter" "github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/url" @@ -39,6 +41,7 @@ type ServerConn struct { ctxCancel func() userData interface{} remoteAddr *net.TCPAddr + bc *bytecounter.ByteCounter conn *conn.Conn session *ServerSession readFunc func(readRequest chan readReq) error @@ -56,16 +59,14 @@ func newServerConn( ) *ServerConn { ctx, ctxCancel := context.WithCancel(s.ctx) - nconn = func() net.Conn { - if s.TLSConfig != nil { - return tls.Server(nconn, s.TLSConfig) - } - return nconn - }() + if s.TLSConfig != nil { + nconn = tls.Server(nconn, s.TLSConfig) + } sc := &ServerConn{ s: s, nconn: nconn, + bc: bytecounter.New(nconn), ctx: ctx, ctxCancel: ctxCancel, remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), @@ -92,6 +93,16 @@ func (sc *ServerConn) NetConn() net.Conn { return sc.nconn } +// ReadBytes returns the number of read bytes. +func (sc *ServerConn) ReadBytes() uint64 { + return sc.bc.ReadBytes() +} + +// WrittenBytes returns the number of written bytes. +func (sc *ServerConn) WrittenBytes() uint64 { + return sc.bc.WrittenBytes() +} + // SetUserData sets some user data associated to the connection. func (sc *ServerConn) SetUserData(v interface{}) { sc.userData = v @@ -120,7 +131,7 @@ func (sc *ServerConn) run() { }) } - sc.conn = conn.NewConn(sc.nconn) + sc.conn = conn.NewConn(sc.bc) readRequest := make(chan readReq) readErr := make(chan error) @@ -315,6 +326,8 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error { isRTP = false } + atomic.AddUint64(&sc.session.readBytes, uint64(len(twhat.Payload))) + // forward frame only if it has been set up if track, ok := sc.session.tcpTracksByChannel[channel]; ok { err := processFunc(track, isRTP, twhat.Payload) diff --git a/serversession.go b/serversession.go index bc7fd1d6..d3939126 100644 --- a/serversession.go +++ b/serversession.go @@ -155,6 +155,8 @@ type ServerSession struct { ctx context.Context ctxCancel func() + readBytes uint64 + writtenBytes uint64 userData interface{} conns map[*ServerConn]struct{} state ServerSessionState @@ -215,6 +217,16 @@ func (ss *ServerSession) Close() error { return nil } +// ReadBytes returns the number of read bytes. +func (ss *ServerSession) ReadBytes() uint64 { + return atomic.LoadUint64(&ss.readBytes) +} + +// WrittenBytes returns the number of written bytes. +func (ss *ServerSession) WrittenBytes() uint64 { + return atomic.LoadUint64(&ss.writtenBytes) +} + // State returns the state of the session. func (ss *ServerSession) State() ServerSessionState { return ss.state @@ -1184,6 +1196,8 @@ func (ss *ServerSession) runWriter() { } data := tmp.(trackTypePayload) + atomic.AddUint64(&ss.writtenBytes, uint64(len(data.payload))) + writeFunc(data.trackID, data.isRTP, data.payload) } } diff --git a/serverudpl.go b/serverudpl.go index 2f1c780b..2489c3c1 100644 --- a/serverudpl.go +++ b/serverudpl.go @@ -201,7 +201,11 @@ func (u *serverUDPListener) runReader() { } func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { - if len(payload) == (maxPacketSize + 1) { + plen := len(payload) + + atomic.AddUint64(&clientData.session.readBytes, uint64(plen)) + + if plen == (maxPacketSize + 1) { onDecodeError(clientData.session, fmt.Errorf("RTP packet is too big to be read with UDP")) return } @@ -240,7 +244,11 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { } func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { - if len(payload) == (maxPacketSize + 1) { + plen := len(payload) + + atomic.AddUint64(&clientData.session.readBytes, uint64(plen)) + + if plen == (maxPacketSize + 1) { onDecodeError(clientData.session, fmt.Errorf("RTCP packet is too big to be read with UDP")) return }