diff --git a/gb28181/source.go b/gb28181/source.go index fd85454..c533653 100644 --- a/gb28181/source.go +++ b/gb28181/source.go @@ -101,10 +101,11 @@ func (source *BaseGBSource) Input(data []byte) error { // 非解析缓冲区满的错误, 继续解析 if err != nil { - log.Sugar.Errorf("解析ps流发生err: %s source: %s", err.Error(), source.GetID()) if strings.HasPrefix(err.Error(), "probe") { return err } + + log.Sugar.Errorf("解析ps流发生err: %s source: %s", err.Error(), source.GetID()) } source.probeBuffer.Reset(n) @@ -139,7 +140,7 @@ func (source *BaseGBSource) correctTimestamp(packet *avformat.AVPacket, dts, pts packet.Duration = duration } else { // 时间戳不正确 - log.Sugar.Errorf("推流时间戳不正确, 使用系统时钟. ssrc: %x", source.ssrc) + log.Sugar.Errorf("推流时间戳不正确, 使用系统时钟. source: %s ssrc: %x duration: %d", source.ID, source.ssrc, duration) source.isSystemClock = true } } else { @@ -292,16 +293,16 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int, } } - var bufferBlockCount int + var queueSize int if active || tcp { - bufferBlockCount = stream.ReceiveBufferTCPBlockCount + queueSize = stream.TCPReceiveBufferQueueSize } else { - bufferBlockCount = stream.ReceiveBufferUdpBlockCount + queueSize = stream.UDPReceiveBufferQueueSize } source.SetID(id) source.SetSSRC(ssrc) - source.Init(bufferBlockCount) + source.Init(queueSize) if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state { return nil, 0, fmt.Errorf("error code %d", state) } diff --git a/gb28181/source_udp.go b/gb28181/source_udp.go index 83ac0b5..1c6c294 100644 --- a/gb28181/source_udp.go +++ b/gb28181/source_udp.go @@ -9,8 +9,7 @@ import ( type UDPSource struct { BaseGBSource - jitterBuffer *stream.JitterBuffer[*rtp.Packet] - receiveBuffer *stream.ReceiveBuffer + jitterBuffer *stream.JitterBuffer[*rtp.Packet] } func (u *UDPSource) SetupType() SetupType { @@ -25,7 +24,7 @@ func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) { // InputRtpPacket 将RTP包排序后,交给Source的主协程处理 func (u *UDPSource) InputRtpPacket(pkt *rtp.Packet) error { - block := u.receiveBuffer.GetBlock() + block := stream.UDPReceiveBufferPool.Get().([]byte) copy(block, pkt.Raw) pkt.Raw = block[:len(pkt.Raw)] @@ -47,7 +46,6 @@ func (u *UDPSource) Close() { func NewUDPSource() *UDPSource { return &UDPSource{ - receiveBuffer: stream.NewReceiveBuffer(1500, stream.ReceiveBufferUdpBlockCount+50), - jitterBuffer: stream.NewJitterBuffer[*rtp.Packet](), + jitterBuffer: stream.NewJitterBuffer[*rtp.Packet](), } } diff --git a/gb28181/tcp_server.go b/gb28181/tcp_server.go index 19ef37f..93d843e 100644 --- a/gb28181/tcp_server.go +++ b/gb28181/tcp_server.go @@ -35,13 +35,7 @@ func (T *TCPServer) OnCloseSession(session *TCPSession) { func (T *TCPServer) OnConnected(conn net.Conn) []byte { T.StreamServer.OnConnected(conn) - - //TCP单端口收流, Session已经绑定Source, 使用ReceiveBuffer读取网络包 - if conn.(*transport.Conn).Data.(*TCPSession).source != nil { - return conn.(*transport.Conn).Data.(*TCPSession).receiveBuffer.GetBlock() - } - - return nil + return stream.TCPReceiveBufferPool.Get().([]byte) } func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { @@ -69,12 +63,7 @@ func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { } } - // 绑定Source后, 使用ReceiveBuffer读取网络包, 减少拷贝 - if session.source != nil { - return session.receiveBuffer.GetBlock() - } - - return nil + return stream.TCPReceiveBufferPool.Get().([]byte) } func NewTCPServer(filter Filter) (*TCPServer, error) { diff --git a/gb28181/tcp_session.go b/gb28181/tcp_session.go index f2b0b62..7497851 100644 --- a/gb28181/tcp_session.go +++ b/gb28181/tcp_session.go @@ -10,16 +10,13 @@ import ( // TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包. type TCPSession struct { - conn net.Conn - source GBSource - receiveBuffer *stream.ReceiveBuffer - decoder *transport.LengthFieldFrameDecoder + conn net.Conn + source GBSource + decoder *transport.LengthFieldFrameDecoder } func (t *TCPSession) Init(source GBSource) { t.source = source - // 创建收流缓冲区 - t.receiveBuffer = stream.NewTCPReceiveBuffer() } func (t *TCPSession) Close() { diff --git a/jt1078/jt_server.go b/jt1078/jt_server.go index c900839..9095931 100644 --- a/jt1078/jt_server.go +++ b/jt1078/jt_server.go @@ -31,7 +31,7 @@ func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte { s.StreamServer.OnPacket(conn, data) session := conn.(*transport.Conn).Data.(*Session) session.PublishSource.Input(data) - return session.receiveBuffer.GetBlock() + return stream.TCPReceiveBufferPool.Get().([]byte) } func (s *jtServer) Start(addr net.Addr) error { diff --git a/jt1078/jt_session.go b/jt1078/jt_session.go index 7be5ec7..6fa14e5 100644 --- a/jt1078/jt_session.go +++ b/jt1078/jt_session.go @@ -11,8 +11,7 @@ import ( type Session struct { stream.PublishSource - decoder *transport.DelimiterFrameDecoder - receiveBuffer *stream.ReceiveBuffer + decoder *transport.DelimiterFrameDecoder } func (s *Session) Input(data []byte) error { @@ -73,12 +72,11 @@ func NewSession(conn net.Conn) *Session { TransDemuxer: NewDemuxer(), }, - decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]), - receiveBuffer: stream.NewTCPReceiveBuffer(), + decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]), } session.TransDemuxer.SetHandler(&session) - session.Init(stream.ReceiveBufferTCPBlockCount) + session.Init(stream.TCPReceiveBufferQueueSize) go stream.LoopEvent(&session) return &session } diff --git a/rtmp/rtmp_server.go b/rtmp/rtmp_server.go index ee2e86d..6d602db 100644 --- a/rtmp/rtmp_server.go +++ b/rtmp/rtmp_server.go @@ -66,7 +66,7 @@ func (s *server) OnPacket(conn net.Conn, data []byte) []byte { } if session.isPublisher { - return session.receiveBuffer.GetBlock() + return stream.TCPReceiveBufferPool.Get().([]byte) } return nil diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index aceadca..acf22a0 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -14,8 +14,6 @@ type Session struct { stack *rtmp.ServerStack // rtmp协议栈, 解析message handle interface{} // 持有具体会话句柄(推流端/拉流端), 在@see OnPublish @see OnPlay回调中赋值 isPublisher bool // 是否是推流会话 - - receiveBuffer *stream.ReceiveBuffer // 推流源收流队列 } func (s *Session) generateSourceID(app, stream string) string { @@ -37,7 +35,7 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState { source := NewPublisher(sourceId, s.stack, s.conn) // 初始化放在add source前面, 以防add后再init, 空窗期拉流队列空指针. - source.Init(stream.ReceiveBufferTCPBlockCount) + source.Init(stream.TCPReceiveBufferQueueSize) source.SetUrlValues(values) // 统一处理source推流事件, source是否已经存在, hook回调.... @@ -47,7 +45,6 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState { } else { s.handle = source s.isPublisher = true - s.receiveBuffer = stream.NewTCPReceiveBuffer() go stream.LoopEvent(source) } @@ -105,7 +102,6 @@ func (s *Session) Close() { if s.isPublisher { publisher.Close() - s.receiveBuffer = nil } } else { sink := s.handle.(*Sink) diff --git a/stream/config.go b/stream/config.go index 2c6ab49..634f7cd 100644 --- a/stream/config.go +++ b/stream/config.go @@ -310,7 +310,7 @@ func SetDefaultConfig(config *AppConfig_) { } config.MergeWriteLatency = limitInt(350, 2000, config.MergeWriteLatency) // 最低缓存350毫秒数据才发送 最高缓存2秒数据才发送 - config.ProbeTimeout = limitInt(2000, 5000, config.MergeWriteLatency) // 2-5秒内必须解析完AVStream + config.ProbeTimeout = limitInt(2000, 5000, config.ProbeTimeout) // 2-5秒内必须解析完AVStream config.Log.Level = limitInt(int(zapcore.DebugLevel), int(zapcore.FatalLevel), config.Log.Level) config.Log.MaxSize = limitMin(1, config.Log.MaxSize) diff --git a/stream/hook_source.go b/stream/hook_source.go index 648b2d9..82d68bc 100644 --- a/stream/hook_source.go +++ b/stream/hook_source.go @@ -55,34 +55,14 @@ func HookPublishDoneEvent(source Source) { } } -func HookReceiveTimeoutEvent(source Source) (*http.Response, utils.HookState) { - var response *http.Response - - if AppConfig.Hooks.IsEnableOnReceiveTimeout() { - resp, err := Hook(HookEventReceiveTimeout, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) - if err != nil { - return resp, utils.HookStateFailure - } - - response = resp - } - - return response, utils.HookStateOK +func HookReceiveTimeoutEvent(source Source) (*http.Response, error) { + utils.Assert(AppConfig.Hooks.IsEnableOnReceiveTimeout()) + return Hook(HookEventReceiveTimeout, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) } -func HookIdleTimeoutEvent(source Source) (*http.Response, utils.HookState) { - var response *http.Response - - if AppConfig.Hooks.IsEnableOnIdleTimeout() { - resp, err := Hook(HookEventIdleTimeout, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) - if err != nil { - return resp, utils.HookStateFailure - } - - response = resp - } - - return response, utils.HookStateOK +func HookIdleTimeoutEvent(source Source) (*http.Response, error) { + utils.Assert(AppConfig.Hooks.IsEnableOnIdleTimeout()) + return Hook(HookEventIdleTimeout, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) } func HookRecordEvent(source Source, path string) { diff --git a/stream/receive_buffer.go b/stream/receive_buffer.go index a68d8fa..9465d98 100644 --- a/stream/receive_buffer.go +++ b/stream/receive_buffer.go @@ -1,9 +1,28 @@ package stream -const ( - ReceiveBufferUdpBlockCount = 300 +import "sync" - ReceiveBufferTCPBlockCount = 50 +const ( + UDPReceiveBufferSize = 1500 + TCPReceiveBufferSize = 4096 * 20 + + UDPReceiveBufferQueueSize = 1000 + TCPReceiveBufferQueueSize = 50 +) + +// 后续考虑使用cas队列实现 +var ( + UDPReceiveBufferPool = sync.Pool{ + New: func() any { + return make([]byte, UDPReceiveBufferSize) + }, + } + + TCPReceiveBufferPool = sync.Pool{ + New: func() any { + return make([]byte, TCPReceiveBufferSize) + }, + } ) // ReceiveBuffer 收流缓冲区. 网络收流->解析流->封装流->发送流是同步的,从解析到发送可能耗时,从而影响读取网络流. 使用收流缓冲区,可有效降低出现此问题的概率. @@ -38,11 +57,3 @@ func (r *ReceiveBuffer) BlockCount() int { func NewReceiveBuffer(blockSize, blockCount int) *ReceiveBuffer { return &ReceiveBuffer{blockCapacity: blockSize, blockCount: blockCount, data: make([]byte, blockSize*blockCount), index: 0} } - -func NewUDPReceiveBuffer() *ReceiveBuffer { - return NewReceiveBuffer(1500, ReceiveBufferUdpBlockCount) -} - -func NewTCPReceiveBuffer() *ReceiveBuffer { - return NewReceiveBuffer(4096*20, ReceiveBufferTCPBlockCount) -} diff --git a/stream/source.go b/stream/source.go index 4a5e272..a98fc9f 100644 --- a/stream/source.go +++ b/stream/source.go @@ -60,8 +60,6 @@ type Source interface { // 将Sink添加到等待队列 Close() - DoClose() - // IsCompleted 所有推流track是否解析完毕 IsCompleted() bool @@ -110,6 +108,8 @@ type Source interface { GetTransStreams() map[TransStreamID]TransStream GetStreamEndInfo() *StreamEndInfo + + ProbeTimeout() } type PublishSource struct { @@ -129,11 +129,9 @@ type PublishSource struct { gopBuffer GOPBuffer // GOP缓存, 音频和视频混合使用, 以视频关键帧为界, 缓存第二个视频关键帧时, 释放前一组gop. 如果不存在视频流, 不缓存音频 closed atomic.Bool // source是否已经关闭 - completed bool // 所有推流track是否解析完毕, @see writeHeader 函数中赋值为true + completed atomic.Bool // 所有推流track是否解析完毕, @see writeHeader 函数中赋值为true existVideo bool // 是否存在视频 - probeTimer *time.Timer // track解析超时计时器, 触发时执行@see writeHeader - TransStreams map[TransStreamID]TransStream // 所有输出流 sinks map[SinkID]Sink // 保存所有Sink TransStreamSinks map[TransStreamID]map[SinkID]Sink // 输出流对应的Sink @@ -201,17 +199,8 @@ func (s *PublishSource) Init(receiveQueueSize int) { s.sinks = make(map[SinkID]Sink, 128) s.TransStreamSinks = make(map[TransStreamID]map[SinkID]Sink, len(transStreamFactories)+1) s.statistics = NewBitrateStatistics() - - // 此处设置的探测时长, 只是为了保证在probeTimeout触发前, 一直在探测 - s.TransDemuxer.SetProbeDuration(60000) - - // 启动探测超时计时器 - s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() { - s.PostEvent(func() { - // s.writeHeader() - s.TransDemuxer.ProbeComplete() - }) - }) + // 设置探测时长 + s.TransDemuxer.SetProbeDuration(AppConfig.ProbeTimeout) } func (s *PublishSource) CreateDefaultOutStreams() { @@ -484,7 +473,7 @@ func (s *PublishSource) doAddSink(sink Sink, resume bool) bool { func (s *PublishSource) AddSink(sink Sink) { s.PostEvent(func() { - if !s.completed { + if !s.completed.Load() { AddSinkToWaitingQueue(sink.GetSourceID(), sink) } else { if !s.doAddSink(sink, false) { @@ -586,11 +575,6 @@ func (s *PublishSource) DoClose() { s.TransDemuxer = nil } - // 停止track探测计时器 - if s.probeTimer != nil { - s.probeTimer.Stop() - } - // 关闭录制流 if s.recordSink != nil { s.recordSink.Close() @@ -693,15 +677,12 @@ func (s *PublishSource) Close() { // 解析完所有track后, 创建各种输出流 func (s *PublishSource) writeHeader() { - if s.completed { + if s.completed.Load() { fmt.Printf("添加Stream失败 Source: %s已经WriteHeader", s.ID) return } - s.completed = true - if s.probeTimer != nil { - s.probeTimer.Stop() - } + s.completed.Store(true) if len(s.originTracks.All()) == 0 { log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID) @@ -746,7 +727,7 @@ func (s *PublishSource) writeHeader() { } func (s *PublishSource) IsCompleted() bool { - return s.completed + return s.completed.Load() } // NotTrackAdded 返回该index对应的track是否没有添加 @@ -793,7 +774,7 @@ func (s *PublishSource) OnNewTrack(track avformat.Track) { stream := track.GetStream() - if s.completed { + if s.completed.Load() { log.Sugar.Warnf("添加%s track失败,已经WriteHeader. source: %s", stream.MediaType, s.ID) return } else if !s.NotTrackAdded(stream.Index) { @@ -826,6 +807,8 @@ func (s *PublishSource) OnTrackNotFind() { if AppConfig.Debug { s.streamLogger.OnTrackNotFind() } + + log.Sugar.Errorf("no tracks found. source id: %s", s.ID) } func (s *PublishSource) OnPacket(packet *avformat.AVPacket) { @@ -852,7 +835,7 @@ func (s *PublishSource) OnPacket(packet *avformat.AVPacket) { } // track解析完毕后,才能生成传输流 - if s.completed { + if s.completed.Load() { s.CorrectTimestamp(packet) // 分发给各个传输流 @@ -939,3 +922,7 @@ func (s *PublishSource) GetTransStreams() map[TransStreamID]TransStream { func (s *PublishSource) GetStreamEndInfo() *StreamEndInfo { return s.streamEndInfo } + +func (s *PublishSource) ProbeTimeout() { + s.TransDemuxer.ProbeComplete() +} diff --git a/stream/source_utils.go b/stream/source_utils.go index 1bcf1ce..529649a 100644 --- a/stream/source_utils.go +++ b/stream/source_utils.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/lkmio/avformat/utils" "github.com/lkmio/lkm/log" + "net/http" "net/url" "strings" "time" @@ -218,6 +219,7 @@ func ParseUrl(name string) (string, url.Values) { //} // StartReceiveDataTimer 启动收流超时计时器 +// 收流超时, 客观上认为是流中断, 应该关闭Source. 如果开启了Hook, 并且Hook返回200应答, 则不关闭Source. func StartReceiveDataTimer(source Source) *time.Timer { utils.Assert(AppConfig.ReceiveTimeout > 0) @@ -225,14 +227,17 @@ func StartReceiveDataTimer(source Source) *time.Timer { receiveDataTimer = time.AfterFunc(time.Duration(AppConfig.ReceiveTimeout), func() { dis := time.Now().Sub(source.LastPacketTime()) - // 如果开启Hook通知, 根据响应决定是否关闭Source - // 如果通知失败, 或者非200应答, 释放Source - // 如果没有开启Hook通知, 直接删除 if dis >= time.Duration(AppConfig.ReceiveTimeout) { log.Sugar.Errorf("收流超时 source: %s", source.GetID()) - response, state := HookReceiveTimeoutEvent(source) - if utils.HookStateOK != state || response == nil { + var shouldClose = true + if AppConfig.Hooks.IsEnableOnReceiveTimeout() { + // 此处参考返回值err, 客观希望关闭Source + response, err := HookReceiveTimeoutEvent(source) + shouldClose = !(err == nil && response != nil && http.StatusOK == response.StatusCode) + } + + if shouldClose { source.Close() return } @@ -246,8 +251,10 @@ func StartReceiveDataTimer(source Source) *time.Timer { } // StartIdleTimer 启动拉流空闲计时器 +// 拉流空闲, 不应该关闭Source. 如果开启了Hook, 并且Hook返回非200应答, 则关闭Source. func StartIdleTimer(source Source) *time.Timer { utils.Assert(AppConfig.IdleTimeout > 0) + utils.Assert(AppConfig.Hooks.IsEnableOnIdleTimeout()) var idleTimer *time.Timer idleTimer = time.AfterFunc(time.Duration(AppConfig.IdleTimeout), func() { @@ -256,8 +263,9 @@ func StartIdleTimer(source Source) *time.Timer { if source.SinkCount() < 1 && dis >= time.Duration(AppConfig.IdleTimeout) { log.Sugar.Errorf("拉流空闲超时 source: %s", source.GetID()) - response, state := HookIdleTimeoutEvent(source) - if utils.HookStateOK != state || response == nil { + // 此处不参考返回值err, 客观希望不关闭Source + response, _ := HookIdleTimeoutEvent(source) + if response != nil && http.StatusOK != response.StatusCode { source.Close() return } @@ -274,6 +282,7 @@ func LoopEvent(source Source) { // 将超时计时器放在此处开启, 方便在退出的时候关闭 var receiveTimer *time.Timer var idleTimer *time.Timer + var probeTimer *time.Timer defer func() { log.Sugar.Debugf("主协程执行结束 source: %s", source.GetID()) @@ -282,13 +291,28 @@ func LoopEvent(source Source) { if receiveTimer != nil { receiveTimer.Stop() } + if idleTimer != nil { idleTimer.Stop() } + + if probeTimer != nil { + probeTimer.Stop() + } + + // 未使用的数据, 放回池中 + for len(source.StreamPipe()) > 0 { + data := <-source.StreamPipe() + if size := cap(data); size > UDPReceiveBufferSize { + TCPReceiveBufferPool.Put(data[:size]) + } else { + UDPReceiveBufferPool.Put(data[:size]) + } + } }() // 开启收流超时计时器 - if AppConfig.Hooks.IsEnableOnReceiveTimeout() && AppConfig.ReceiveTimeout > 0 { + if AppConfig.ReceiveTimeout > 0 { receiveTimer = StartReceiveDataTimer(source) } @@ -297,6 +321,24 @@ func LoopEvent(source Source) { idleTimer = StartIdleTimer(source) } + // 开启探测超时计时器 + probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() { + if source.IsCompleted() { + return + } + + var ok bool + source.PostEvent(func() { + source.ProbeTimeout() + ok = len(source.OriginTracks()) > 0 + }) + + if !ok { + source.Close() + return + } + }) + for { select { // 读取推流数据 @@ -307,10 +349,16 @@ func LoopEvent(source Source) { if err := source.Input(data); err != nil { log.Sugar.Errorf("解析推流数据发生err: %s 释放source: %s", err.Error(), source.GetID()) - source.DoClose() + go source.Close() return } + // 使用后, 放回池中 + if size := cap(data); size > UDPReceiveBufferSize { + TCPReceiveBufferPool.Put(data[:size]) + } else { + UDPReceiveBufferPool.Put(data[:size]) + } break // 切换到主协程,执行该函数. 目的是用于无锁化处理推拉流的连接与断开, 推流源断开, 查询推流源信息等事件. 不要做耗时操作, 否则会影响推拉流. case event := <-source.MainContextEvents():