server: add methods to get read and written bytes (#147)

ServerConn.ReadBytes()
ServerConn.WrittenBytes()
ServerSession.ReadBytes()
ServerSession.WrittenBytes()
This commit is contained in:
Alessandro Ros
2022-11-09 13:41:48 +01:00
committed by GitHub
parent 9029c3a9a3
commit 95d1562735
6 changed files with 111 additions and 11 deletions

View File

@@ -0,0 +1,45 @@
// Package bytecounter contains a io.ReadWriter wrapper that allows to count read and written bytes.
package bytecounter
import (
"io"
"sync/atomic"
)
// ByteCounter is a io.ReadWriter wrapper that allows to count read and written bytes.
type ByteCounter struct {
rw io.ReadWriter
read uint64
written uint64
}
// New allocates a ByteCounter.
func New(rw io.ReadWriter) *ByteCounter {
return &ByteCounter{
rw: rw,
}
}
// Read implements io.ReadWriter.
func (bc *ByteCounter) Read(p []byte) (int, error) {
n, err := bc.rw.Read(p)
atomic.AddUint64(&bc.read, uint64(n))
return n, err
}
// Write implements io.ReadWriter.
func (bc *ByteCounter) Write(p []byte) (int, error) {
n, err := bc.rw.Write(p)
atomic.AddUint64(&bc.written, uint64(n))
return n, err
}
// ReadBytes returns the number of read bytes.
func (bc *ByteCounter) ReadBytes() uint64 {
return atomic.LoadUint64(&bc.read)
}
// WrittenBytes returns the number of written bytes.
func (bc *ByteCounter) WrittenBytes() uint64 {
return atomic.LoadUint64(&bc.written)
}

View File

@@ -0,0 +1,20 @@
package bytecounter
import (
"bytes"
"testing"
"github.com/stretchr/testify/require"
)
func TestByteCounter(t *testing.T) {
bc := New(bytes.NewBuffer(nil))
bc.Write([]byte{0x01, 0x02, 0x03, 0x04})
buf := make([]byte, 2)
bc.Read(buf)
require.Equal(t, uint64(4), bc.WrittenBytes())
require.Equal(t, uint64(2), bc.ReadBytes())
}

View File

@@ -1,4 +1,4 @@
// Package conn contains a RTSP TCP connection implementation. // Package conn contains a RTSP connection implementation.
package conn package conn
import ( import (
@@ -12,7 +12,7 @@ const (
readBufferSize = 4096 readBufferSize = 4096
) )
// Conn is a RTSP TCP connection. // Conn is a RTSP connection.
type Conn struct { type Conn struct {
w io.Writer w io.Writer
br *bufio.Reader br *bufio.Reader

View File

@@ -8,11 +8,13 @@ import (
"net" "net"
gourl "net/url" gourl "net/url"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/aler9/gortsplib/pkg/base" "github.com/aler9/gortsplib/pkg/base"
"github.com/aler9/gortsplib/pkg/bytecounter"
"github.com/aler9/gortsplib/pkg/conn" "github.com/aler9/gortsplib/pkg/conn"
"github.com/aler9/gortsplib/pkg/liberrors" "github.com/aler9/gortsplib/pkg/liberrors"
"github.com/aler9/gortsplib/pkg/url" "github.com/aler9/gortsplib/pkg/url"
@@ -39,6 +41,7 @@ type ServerConn struct {
ctxCancel func() ctxCancel func()
userData interface{} userData interface{}
remoteAddr *net.TCPAddr remoteAddr *net.TCPAddr
bc *bytecounter.ByteCounter
conn *conn.Conn conn *conn.Conn
session *ServerSession session *ServerSession
readFunc func(readRequest chan readReq) error readFunc func(readRequest chan readReq) error
@@ -56,16 +59,14 @@ func newServerConn(
) *ServerConn { ) *ServerConn {
ctx, ctxCancel := context.WithCancel(s.ctx) ctx, ctxCancel := context.WithCancel(s.ctx)
nconn = func() net.Conn { if s.TLSConfig != nil {
if s.TLSConfig != nil { nconn = tls.Server(nconn, s.TLSConfig)
return tls.Server(nconn, s.TLSConfig) }
}
return nconn
}()
sc := &ServerConn{ sc := &ServerConn{
s: s, s: s,
nconn: nconn, nconn: nconn,
bc: bytecounter.New(nconn),
ctx: ctx, ctx: ctx,
ctxCancel: ctxCancel, ctxCancel: ctxCancel,
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr), remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
@@ -92,6 +93,16 @@ func (sc *ServerConn) NetConn() net.Conn {
return sc.nconn return sc.nconn
} }
// ReadBytes returns the number of read bytes.
func (sc *ServerConn) ReadBytes() uint64 {
return sc.bc.ReadBytes()
}
// WrittenBytes returns the number of written bytes.
func (sc *ServerConn) WrittenBytes() uint64 {
return sc.bc.WrittenBytes()
}
// SetUserData sets some user data associated to the connection. // SetUserData sets some user data associated to the connection.
func (sc *ServerConn) SetUserData(v interface{}) { func (sc *ServerConn) SetUserData(v interface{}) {
sc.userData = v sc.userData = v
@@ -120,7 +131,7 @@ func (sc *ServerConn) run() {
}) })
} }
sc.conn = conn.NewConn(sc.nconn) sc.conn = conn.NewConn(sc.bc)
readRequest := make(chan readReq) readRequest := make(chan readReq)
readErr := make(chan error) readErr := make(chan error)
@@ -315,6 +326,8 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
isRTP = false isRTP = false
} }
atomic.AddUint64(&sc.session.readBytes, uint64(len(twhat.Payload)))
// forward frame only if it has been set up // forward frame only if it has been set up
if track, ok := sc.session.tcpTracksByChannel[channel]; ok { if track, ok := sc.session.tcpTracksByChannel[channel]; ok {
err := processFunc(track, isRTP, twhat.Payload) err := processFunc(track, isRTP, twhat.Payload)

View File

@@ -155,6 +155,8 @@ type ServerSession struct {
ctx context.Context ctx context.Context
ctxCancel func() ctxCancel func()
readBytes uint64
writtenBytes uint64
userData interface{} userData interface{}
conns map[*ServerConn]struct{} conns map[*ServerConn]struct{}
state ServerSessionState state ServerSessionState
@@ -215,6 +217,16 @@ func (ss *ServerSession) Close() error {
return nil return nil
} }
// ReadBytes returns the number of read bytes.
func (ss *ServerSession) ReadBytes() uint64 {
return atomic.LoadUint64(&ss.readBytes)
}
// WrittenBytes returns the number of written bytes.
func (ss *ServerSession) WrittenBytes() uint64 {
return atomic.LoadUint64(&ss.writtenBytes)
}
// State returns the state of the session. // State returns the state of the session.
func (ss *ServerSession) State() ServerSessionState { func (ss *ServerSession) State() ServerSessionState {
return ss.state return ss.state
@@ -1184,6 +1196,8 @@ func (ss *ServerSession) runWriter() {
} }
data := tmp.(trackTypePayload) data := tmp.(trackTypePayload)
atomic.AddUint64(&ss.writtenBytes, uint64(len(data.payload)))
writeFunc(data.trackID, data.isRTP, data.payload) writeFunc(data.trackID, data.isRTP, data.payload)
} }
} }

View File

@@ -201,7 +201,11 @@ func (u *serverUDPListener) runReader() {
} }
func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
if len(payload) == (maxPacketSize + 1) { plen := len(payload)
atomic.AddUint64(&clientData.session.readBytes, uint64(plen))
if plen == (maxPacketSize + 1) {
onDecodeError(clientData.session, fmt.Errorf("RTP packet is too big to be read with UDP")) onDecodeError(clientData.session, fmt.Errorf("RTP packet is too big to be read with UDP"))
return return
} }
@@ -240,7 +244,11 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
} }
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) { func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {
if len(payload) == (maxPacketSize + 1) { plen := len(payload)
atomic.AddUint64(&clientData.session.readBytes, uint64(plen))
if plen == (maxPacketSize + 1) {
onDecodeError(clientData.session, fmt.Errorf("RTCP packet is too big to be read with UDP")) onDecodeError(clientData.session, fmt.Errorf("RTCP packet is too big to be read with UDP"))
return return
} }