mirror of
https://github.com/aler9/gortsplib
synced 2025-10-05 23:26:54 +08:00
server: add methods to get read and written bytes (#147)
ServerConn.ReadBytes() ServerConn.WrittenBytes() ServerSession.ReadBytes() ServerSession.WrittenBytes()
This commit is contained in:
45
pkg/bytecounter/bytecounter.go
Normal file
45
pkg/bytecounter/bytecounter.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
// Package bytecounter contains a io.ReadWriter wrapper that allows to count read and written bytes.
|
||||||
|
package bytecounter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ByteCounter is a io.ReadWriter wrapper that allows to count read and written bytes.
|
||||||
|
type ByteCounter struct {
|
||||||
|
rw io.ReadWriter
|
||||||
|
read uint64
|
||||||
|
written uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// New allocates a ByteCounter.
|
||||||
|
func New(rw io.ReadWriter) *ByteCounter {
|
||||||
|
return &ByteCounter{
|
||||||
|
rw: rw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read implements io.ReadWriter.
|
||||||
|
func (bc *ByteCounter) Read(p []byte) (int, error) {
|
||||||
|
n, err := bc.rw.Read(p)
|
||||||
|
atomic.AddUint64(&bc.read, uint64(n))
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write implements io.ReadWriter.
|
||||||
|
func (bc *ByteCounter) Write(p []byte) (int, error) {
|
||||||
|
n, err := bc.rw.Write(p)
|
||||||
|
atomic.AddUint64(&bc.written, uint64(n))
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadBytes returns the number of read bytes.
|
||||||
|
func (bc *ByteCounter) ReadBytes() uint64 {
|
||||||
|
return atomic.LoadUint64(&bc.read)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrittenBytes returns the number of written bytes.
|
||||||
|
func (bc *ByteCounter) WrittenBytes() uint64 {
|
||||||
|
return atomic.LoadUint64(&bc.written)
|
||||||
|
}
|
20
pkg/bytecounter/bytecounter_test.go
Normal file
20
pkg/bytecounter/bytecounter_test.go
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
package bytecounter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestByteCounter(t *testing.T) {
|
||||||
|
bc := New(bytes.NewBuffer(nil))
|
||||||
|
|
||||||
|
bc.Write([]byte{0x01, 0x02, 0x03, 0x04})
|
||||||
|
|
||||||
|
buf := make([]byte, 2)
|
||||||
|
bc.Read(buf)
|
||||||
|
|
||||||
|
require.Equal(t, uint64(4), bc.WrittenBytes())
|
||||||
|
require.Equal(t, uint64(2), bc.ReadBytes())
|
||||||
|
}
|
@@ -1,4 +1,4 @@
|
|||||||
// Package conn contains a RTSP TCP connection implementation.
|
// Package conn contains a RTSP connection implementation.
|
||||||
package conn
|
package conn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -12,7 +12,7 @@ const (
|
|||||||
readBufferSize = 4096
|
readBufferSize = 4096
|
||||||
)
|
)
|
||||||
|
|
||||||
// Conn is a RTSP TCP connection.
|
// Conn is a RTSP connection.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
w io.Writer
|
w io.Writer
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
|
@@ -8,11 +8,13 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
gourl "net/url"
|
gourl "net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pion/rtcp"
|
"github.com/pion/rtcp"
|
||||||
|
|
||||||
"github.com/aler9/gortsplib/pkg/base"
|
"github.com/aler9/gortsplib/pkg/base"
|
||||||
|
"github.com/aler9/gortsplib/pkg/bytecounter"
|
||||||
"github.com/aler9/gortsplib/pkg/conn"
|
"github.com/aler9/gortsplib/pkg/conn"
|
||||||
"github.com/aler9/gortsplib/pkg/liberrors"
|
"github.com/aler9/gortsplib/pkg/liberrors"
|
||||||
"github.com/aler9/gortsplib/pkg/url"
|
"github.com/aler9/gortsplib/pkg/url"
|
||||||
@@ -39,6 +41,7 @@ type ServerConn struct {
|
|||||||
ctxCancel func()
|
ctxCancel func()
|
||||||
userData interface{}
|
userData interface{}
|
||||||
remoteAddr *net.TCPAddr
|
remoteAddr *net.TCPAddr
|
||||||
|
bc *bytecounter.ByteCounter
|
||||||
conn *conn.Conn
|
conn *conn.Conn
|
||||||
session *ServerSession
|
session *ServerSession
|
||||||
readFunc func(readRequest chan readReq) error
|
readFunc func(readRequest chan readReq) error
|
||||||
@@ -56,16 +59,14 @@ func newServerConn(
|
|||||||
) *ServerConn {
|
) *ServerConn {
|
||||||
ctx, ctxCancel := context.WithCancel(s.ctx)
|
ctx, ctxCancel := context.WithCancel(s.ctx)
|
||||||
|
|
||||||
nconn = func() net.Conn {
|
if s.TLSConfig != nil {
|
||||||
if s.TLSConfig != nil {
|
nconn = tls.Server(nconn, s.TLSConfig)
|
||||||
return tls.Server(nconn, s.TLSConfig)
|
}
|
||||||
}
|
|
||||||
return nconn
|
|
||||||
}()
|
|
||||||
|
|
||||||
sc := &ServerConn{
|
sc := &ServerConn{
|
||||||
s: s,
|
s: s,
|
||||||
nconn: nconn,
|
nconn: nconn,
|
||||||
|
bc: bytecounter.New(nconn),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: ctxCancel,
|
ctxCancel: ctxCancel,
|
||||||
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
|
remoteAddr: nconn.RemoteAddr().(*net.TCPAddr),
|
||||||
@@ -92,6 +93,16 @@ func (sc *ServerConn) NetConn() net.Conn {
|
|||||||
return sc.nconn
|
return sc.nconn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadBytes returns the number of read bytes.
|
||||||
|
func (sc *ServerConn) ReadBytes() uint64 {
|
||||||
|
return sc.bc.ReadBytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrittenBytes returns the number of written bytes.
|
||||||
|
func (sc *ServerConn) WrittenBytes() uint64 {
|
||||||
|
return sc.bc.WrittenBytes()
|
||||||
|
}
|
||||||
|
|
||||||
// SetUserData sets some user data associated to the connection.
|
// SetUserData sets some user data associated to the connection.
|
||||||
func (sc *ServerConn) SetUserData(v interface{}) {
|
func (sc *ServerConn) SetUserData(v interface{}) {
|
||||||
sc.userData = v
|
sc.userData = v
|
||||||
@@ -120,7 +131,7 @@ func (sc *ServerConn) run() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
sc.conn = conn.NewConn(sc.nconn)
|
sc.conn = conn.NewConn(sc.bc)
|
||||||
|
|
||||||
readRequest := make(chan readReq)
|
readRequest := make(chan readReq)
|
||||||
readErr := make(chan error)
|
readErr := make(chan error)
|
||||||
@@ -315,6 +326,8 @@ func (sc *ServerConn) readFuncTCP(readRequest chan readReq) error {
|
|||||||
isRTP = false
|
isRTP = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
atomic.AddUint64(&sc.session.readBytes, uint64(len(twhat.Payload)))
|
||||||
|
|
||||||
// forward frame only if it has been set up
|
// forward frame only if it has been set up
|
||||||
if track, ok := sc.session.tcpTracksByChannel[channel]; ok {
|
if track, ok := sc.session.tcpTracksByChannel[channel]; ok {
|
||||||
err := processFunc(track, isRTP, twhat.Payload)
|
err := processFunc(track, isRTP, twhat.Payload)
|
||||||
|
@@ -155,6 +155,8 @@ type ServerSession struct {
|
|||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
ctxCancel func()
|
ctxCancel func()
|
||||||
|
readBytes uint64
|
||||||
|
writtenBytes uint64
|
||||||
userData interface{}
|
userData interface{}
|
||||||
conns map[*ServerConn]struct{}
|
conns map[*ServerConn]struct{}
|
||||||
state ServerSessionState
|
state ServerSessionState
|
||||||
@@ -215,6 +217,16 @@ func (ss *ServerSession) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadBytes returns the number of read bytes.
|
||||||
|
func (ss *ServerSession) ReadBytes() uint64 {
|
||||||
|
return atomic.LoadUint64(&ss.readBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrittenBytes returns the number of written bytes.
|
||||||
|
func (ss *ServerSession) WrittenBytes() uint64 {
|
||||||
|
return atomic.LoadUint64(&ss.writtenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
// State returns the state of the session.
|
// State returns the state of the session.
|
||||||
func (ss *ServerSession) State() ServerSessionState {
|
func (ss *ServerSession) State() ServerSessionState {
|
||||||
return ss.state
|
return ss.state
|
||||||
@@ -1184,6 +1196,8 @@ func (ss *ServerSession) runWriter() {
|
|||||||
}
|
}
|
||||||
data := tmp.(trackTypePayload)
|
data := tmp.(trackTypePayload)
|
||||||
|
|
||||||
|
atomic.AddUint64(&ss.writtenBytes, uint64(len(data.payload)))
|
||||||
|
|
||||||
writeFunc(data.trackID, data.isRTP, data.payload)
|
writeFunc(data.trackID, data.isRTP, data.payload)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -201,7 +201,11 @@ func (u *serverUDPListener) runReader() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
|
func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
|
||||||
if len(payload) == (maxPacketSize + 1) {
|
plen := len(payload)
|
||||||
|
|
||||||
|
atomic.AddUint64(&clientData.session.readBytes, uint64(plen))
|
||||||
|
|
||||||
|
if plen == (maxPacketSize + 1) {
|
||||||
onDecodeError(clientData.session, fmt.Errorf("RTP packet is too big to be read with UDP"))
|
onDecodeError(clientData.session, fmt.Errorf("RTP packet is too big to be read with UDP"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -240,7 +244,11 @@ func (u *serverUDPListener) processRTP(clientData *clientData, payload []byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {
|
func (u *serverUDPListener) processRTCP(clientData *clientData, payload []byte) {
|
||||||
if len(payload) == (maxPacketSize + 1) {
|
plen := len(payload)
|
||||||
|
|
||||||
|
atomic.AddUint64(&clientData.session.readBytes, uint64(plen))
|
||||||
|
|
||||||
|
if plen == (maxPacketSize + 1) {
|
||||||
onDecodeError(clientData.session, fmt.Errorf("RTCP packet is too big to be read with UDP"))
|
onDecodeError(clientData.session, fmt.Errorf("RTCP packet is too big to be read with UDP"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user