Add ConnectionIdleTimeout to RTMP server

This commit is contained in:
Ingo Oppermann
2024-05-29 16:16:10 +02:00
parent ad8d214805
commit d6a80c28e5
6 changed files with 111 additions and 32 deletions

View File

@@ -874,12 +874,13 @@ func (a *api) start() error {
a.log.logger.rtmp = a.log.logger.core.WithComponent("RTMP").WithField("address", cfg.RTMP.Address) a.log.logger.rtmp = a.log.logger.core.WithComponent("RTMP").WithField("address", cfg.RTMP.Address)
config := rtmp.Config{ config := rtmp.Config{
Addr: cfg.RTMP.Address, Addr: cfg.RTMP.Address,
TLSAddr: cfg.RTMP.AddressTLS, TLSAddr: cfg.RTMP.AddressTLS,
App: cfg.RTMP.App, App: cfg.RTMP.App,
Token: cfg.RTMP.Token, Token: cfg.RTMP.Token,
Logger: a.log.logger.rtmp, Logger: a.log.logger.rtmp,
Collector: a.sessions.Collector("rtmp"), ConnectionIdleTimeout: 10 * time.Second,
Collector: a.sessions.Collector("rtmp"),
} }
if cfg.RTMP.EnableTLS { if cfg.RTMP.EnableTLS {

2
go.mod
View File

@@ -10,7 +10,7 @@ require (
github.com/atrox/haikunatorgo/v2 v2.0.1 github.com/atrox/haikunatorgo/v2 v2.0.1
github.com/caddyserver/certmagic v0.21.2 github.com/caddyserver/certmagic v0.21.2
github.com/datarhei/gosrt v0.6.0 github.com/datarhei/gosrt v0.6.0
github.com/datarhei/joy4 v0.0.0-20240528121836-da80d79b6435 github.com/datarhei/joy4 v0.0.0-20240529125512-3aeb406414d6
github.com/go-playground/validator/v10 v10.20.0 github.com/go-playground/validator/v10 v10.20.0
github.com/gobwas/glob v0.2.3 github.com/gobwas/glob v0.2.3
github.com/golang-jwt/jwt/v5 v5.2.1 github.com/golang-jwt/jwt/v5 v5.2.1

4
go.sum
View File

@@ -30,8 +30,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lV
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc= github.com/datarhei/gosrt v0.6.0 h1:HrrXAw90V78ok4WMIhX6se1aTHPCn82Sg2hj+PhdmGc=
github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs= github.com/datarhei/gosrt v0.6.0/go.mod h1:fsOWdLSHUHShHjgi/46h6wjtdQrtnSdAQFnlas8ONxs=
github.com/datarhei/joy4 v0.0.0-20240528121836-da80d79b6435 h1:bXcqdyQWtKyb1i82qLMqi4DxbVrZMpk0oVmKtWJjWhg= github.com/datarhei/joy4 v0.0.0-20240529125512-3aeb406414d6 h1:qrAUWrwNUUj8Desdima+jg4xwymQ2b7khI0fm+e4BAw=
github.com/datarhei/joy4 v0.0.0-20240528121836-da80d79b6435/go.mod h1:Jcw/6jZDQQmPx8A7INEkXmuEF7E9jjBbSTfVSLwmiQw= github.com/datarhei/joy4 v0.0.0-20240529125512-3aeb406414d6/go.mod h1:Jcw/6jZDQQmPx8A7INEkXmuEF7E9jjBbSTfVSLwmiQw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=

View File

@@ -195,6 +195,10 @@ type Config struct {
// ListenAndServe, so it's not possible to modify the configuration // ListenAndServe, so it's not possible to modify the configuration
// with methods like tls.Config.SetSessionTicketKeys. // with methods like tls.Config.SetSessionTicketKeys.
TLSConfig *tls.Config TLSConfig *tls.Config
// ConnectionIdleTimeout is the timeout in seconds after which idle
// connection will be closes. Default is no timeout.
ConnectionIdleTimeout time.Duration
} }
// Server represents a RTMP server // Server represents a RTMP server
@@ -252,17 +256,19 @@ func New(config Config) (Server, error) {
} }
s.server = &rtmp.Server{ s.server = &rtmp.Server{
Addr: config.Addr, Addr: config.Addr,
HandlePlay: s.handlePlay, HandlePlay: s.handlePlay,
HandlePublish: s.handlePublish, HandlePublish: s.handlePublish,
ConnectionIdleTimeout: config.ConnectionIdleTimeout,
} }
if len(config.TLSAddr) != 0 { if len(config.TLSAddr) != 0 {
s.tlsServer = &rtmp.Server{ s.tlsServer = &rtmp.Server{
Addr: config.TLSAddr, Addr: config.TLSAddr,
TLSConfig: config.TLSConfig.Clone(), TLSConfig: config.TLSConfig.Clone(),
HandlePlay: s.handlePlay, HandlePlay: s.handlePlay,
HandlePublish: s.handlePublish, HandlePublish: s.handlePublish,
ConnectionIdleTimeout: config.ConnectionIdleTimeout,
} }
} }

View File

@@ -53,6 +53,7 @@ func DialTimeout(uri string, timeout time.Duration) (conn *Conn, err error) {
conn = NewConn(netconn) conn = NewConn(netconn)
conn.URL = u conn.URL = u
return return
} }
@@ -65,9 +66,10 @@ type Server struct {
HandlePlay func(*Conn) HandlePlay func(*Conn)
HandleConn func(*Conn) HandleConn func(*Conn)
MaxProbePacketCount int MaxProbePacketCount int
SkipInvalidMessages bool SkipInvalidMessages bool
DebugChunks func(conn net.Conn) bool DebugChunks func(conn net.Conn) bool
ConnectionIdleTimeout time.Duration
listener net.Listener listener net.Listener
doneChan chan struct{} doneChan chan struct{}
@@ -80,6 +82,7 @@ func (s *Server) HandleNetConn(netconn net.Conn) (err error) {
if s.DebugChunks != nil { if s.DebugChunks != nil {
conn.debugChunks = s.DebugChunks(netconn) conn.debugChunks = s.DebugChunks(netconn)
} }
conn.netconn.SetIdleTimeout(s.ConnectionIdleTimeout)
conn.isserver = true conn.isserver = true
err = s.handleConn(conn) err = s.handleConn(conn)
@@ -101,10 +104,12 @@ func (s *Server) handleConn(conn *Conn) (err error) {
} }
if conn.playing { if conn.playing {
conn.netconn.SetReadIdleTimeout(0)
if s.HandlePlay != nil { if s.HandlePlay != nil {
s.HandlePlay(conn) s.HandlePlay(conn)
} }
} else if conn.publishing { } else if conn.publishing {
conn.netconn.SetWriteIdleTimeout(0)
if s.HandlePublish != nil { if s.HandlePublish != nil {
s.HandlePublish(conn) s.HandlePublish(conn)
} }
@@ -204,14 +209,12 @@ func (s *Server) Serve(listener net.Listener) error {
} }
func (s *Server) Close() { func (s *Server) Close() {
if s.listener == nil { if s.listener != nil {
return s.listener.Close()
s.listener = nil
} }
close(s.doneChan) close(s.doneChan)
s.listener.Close()
s.listener = nil
} }
const ( const (
@@ -239,7 +242,7 @@ type Conn struct {
writebuf []byte writebuf []byte
readbuf []byte readbuf []byte
netconn net.Conn netconn *idleConn
txrxcount *txrxcount txrxcount *txrxcount
writeMaxChunkSize int writeMaxChunkSize int
@@ -278,6 +281,61 @@ type Conn struct {
debugChunks bool debugChunks bool
} }
type idleConn struct {
net.Conn
ReadIdleTimeout time.Duration
WriteIdleTimeout time.Duration
}
func (t *idleConn) Read(p []byte) (int, error) {
if t.ReadIdleTimeout > 0 {
t.Conn.SetReadDeadline(time.Now().Add(t.ReadIdleTimeout))
}
n, err := t.Conn.Read(p)
return n, err
}
func (t *idleConn) Write(p []byte) (int, error) {
if t.WriteIdleTimeout > 0 {
t.Conn.SetWriteDeadline(time.Now().Add(t.WriteIdleTimeout))
}
n, err := t.Conn.Write(p)
return n, err
}
func (t *idleConn) SetReadIdleTimeout(d time.Duration) error {
t.ReadIdleTimeout = d
deadline := time.Time{}
if t.ReadIdleTimeout > 0 {
deadline = time.Now().Add(d)
}
return t.Conn.SetReadDeadline(deadline)
}
func (t *idleConn) SetWriteIdleTimeout(d time.Duration) error {
t.WriteIdleTimeout = d
deadline := time.Time{}
if t.WriteIdleTimeout > 0 {
deadline = time.Now().Add(d)
}
return t.Conn.SetWriteDeadline(deadline)
}
func (t *idleConn) SetIdleTimeout(d time.Duration) error {
err := t.SetReadIdleTimeout(d)
if err != nil {
return err
}
return t.SetWriteIdleTimeout(d)
}
type txrxcount struct { type txrxcount struct {
io.ReadWriter io.ReadWriter
txbytes uint64 txbytes uint64
@@ -299,12 +357,14 @@ func (t *txrxcount) Write(p []byte) (int, error) {
func NewConn(netconn net.Conn) *Conn { func NewConn(netconn net.Conn) *Conn {
conn := &Conn{} conn := &Conn{}
conn.prober = &flv.Prober{} conn.prober = &flv.Prober{}
conn.netconn = netconn conn.netconn = &idleConn{
Conn: netconn,
}
conn.readcsmap = make(map[uint32]*chunkStream) conn.readcsmap = make(map[uint32]*chunkStream)
conn.readMaxChunkSize = 128 conn.readMaxChunkSize = 128
conn.writeMaxChunkSize = 128 conn.writeMaxChunkSize = 128
conn.readAckSize = 1048576 conn.readAckSize = 1048576
conn.txrxcount = &txrxcount{ReadWriter: netconn} conn.txrxcount = &txrxcount{ReadWriter: conn.netconn}
conn.bufr = bufio.NewReaderSize(conn.txrxcount, pio.RecommendBufioSize) conn.bufr = bufio.NewReaderSize(conn.txrxcount, pio.RecommendBufioSize)
conn.bufw = bufio.NewWriterSize(conn.txrxcount, pio.RecommendBufioSize) conn.bufw = bufio.NewWriterSize(conn.txrxcount, pio.RecommendBufioSize)
conn.writebuf = make([]byte, 4096) conn.writebuf = make([]byte, 4096)
@@ -358,7 +418,19 @@ const (
) )
func (conn *Conn) NetConn() net.Conn { func (conn *Conn) NetConn() net.Conn {
return conn.netconn return conn.netconn.Conn
}
func (conn *Conn) SetReadIdleTimeout(d time.Duration) error {
return conn.netconn.SetReadIdleTimeout(d)
}
func (conn *Conn) SetWriteIdleTimeout(d time.Duration) error {
return conn.netconn.SetWriteIdleTimeout(d)
}
func (conn *Conn) SetIdleTimeout(d time.Duration) error {
return conn.netconn.SetIdleTimeout(d)
} }
func (conn *Conn) TxBytes() uint64 { func (conn *Conn) TxBytes() uint64 {
@@ -1421,7 +1493,7 @@ func (conn *Conn) readChunk() (err error) {
cs.msgdatalen = pio.U24BE(h[3:6]) cs.msgdatalen = pio.U24BE(h[3:6])
cs.msgtypeid = h[6] cs.msgtypeid = h[6]
cs.msgsid = pio.U32LE(h[7:11]) cs.msgsid = pio.U32LE(h[7:11])
if timestamp == 0xffffff { if timestamp == FlvTimestampMax {
if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil { if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil {
return return
} }
@@ -1464,7 +1536,7 @@ func (conn *Conn) readChunk() (err error) {
cs.msghdrtype = msghdrtype cs.msghdrtype = msghdrtype
cs.msgdatalen = pio.U24BE(h[3:6]) cs.msgdatalen = pio.U24BE(h[3:6])
cs.msgtypeid = h[6] cs.msgtypeid = h[6]
if timestamp == 0xffffff { if timestamp == FlvTimestampMax {
if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil { if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil {
return return
} }
@@ -1504,7 +1576,7 @@ func (conn *Conn) readChunk() (err error) {
n += len(h) n += len(h)
cs.msghdrtype = msghdrtype cs.msghdrtype = msghdrtype
timestamp = pio.U24BE(h[0:3]) timestamp = pio.U24BE(h[0:3])
if timestamp == 0xffffff { if timestamp == FlvTimestampMax {
if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil { if _, err = io.ReadFull(conn.bufr, b[:4]); err != nil {
return return
} }

2
vendor/modules.txt vendored
View File

@@ -64,7 +64,7 @@ github.com/datarhei/gosrt/crypto
github.com/datarhei/gosrt/net github.com/datarhei/gosrt/net
github.com/datarhei/gosrt/packet github.com/datarhei/gosrt/packet
github.com/datarhei/gosrt/rand github.com/datarhei/gosrt/rand
# github.com/datarhei/joy4 v0.0.0-20240528121836-da80d79b6435 # github.com/datarhei/joy4 v0.0.0-20240529125512-3aeb406414d6
## explicit; go 1.14 ## explicit; go 1.14
github.com/datarhei/joy4/av github.com/datarhei/joy4/av
github.com/datarhei/joy4/av/avutil github.com/datarhei/joy4/av/avutil