feat: 使用引用计数器管理合并写切片的生命周期

This commit is contained in:
ydajiang
2025-04-18 10:58:09 +08:00
parent f3774f2151
commit bedf402ab4
17 changed files with 349 additions and 385 deletions

View File

@@ -3,6 +3,7 @@ package stream
import (
"fmt"
"github.com/lkmio/avformat"
"github.com/lkmio/avformat/collections"
"github.com/lkmio/lkm/log"
"github.com/lkmio/transport"
"net"
@@ -135,7 +136,6 @@ type PublishSource struct {
TransStreams map[TransStreamID]TransStream // 所有输出流
sinks map[SinkID]Sink // 保存所有Sink
slowSinks map[SinkID]Sink // 因推流慢被挂起的sink队列
TransStreamSinks map[TransStreamID]map[SinkID]Sink // 输出流对应的Sink
streamEndInfo *StreamEndInfo // 之前推流源信息
accumulateTimestamps bool // 是否累加时间戳
@@ -199,7 +199,6 @@ func (s *PublishSource) Init(receiveQueueSize int) {
s.TransStreams = make(map[TransStreamID]TransStream, 10)
s.sinks = make(map[SinkID]Sink, 128)
s.slowSinks = make(map[SinkID]Sink, 12)
s.TransStreamSinks = make(map[TransStreamID]map[SinkID]Sink, len(transStreamFactories)+1)
s.statistics = NewBitrateStatistics()
@@ -318,7 +317,7 @@ func (s *PublishSource) DispatchPacket(transStream TransStream, packet *avformat
}
// DispatchBuffer 分发传输流
func (s *PublishSource) DispatchBuffer(transStream TransStream, index int, data [][]byte, timestamp int64, videoKey bool) {
func (s *PublishSource) DispatchBuffer(transStream TransStream, index int, data []*collections.ReferenceCounter[[]byte], timestamp int64, videoKey bool) {
sinks := s.TransStreamSinks[transStream.GetID()]
exist := transStream.IsExistVideo()
@@ -331,76 +330,42 @@ func (s *PublishSource) DispatchBuffer(transStream TransStream, index int, data
}
if extraData, _, _ := transStream.ReadExtraData(timestamp); len(extraData) > 0 {
s.write(transStream, sink, index, extraData, timestamp)
if ok := s.write(sink, index, extraData, timestamp); !ok {
continue
}
}
}
s.write(transStream, sink, index, data, timestamp)
if ok := s.write(sink, index, data, timestamp); !ok {
continue
}
}
}
func (s *PublishSource) pendingSink(sink Sink) {
if s.existVideo {
log.Sugar.Errorf("向sink推流超时,挂起%s-sink: %s source: %s", sink.GetProtocol().String(), sink.GetID(), s.ID)
// 等待下个关键帧恢复推流
s.PauseStreaming(sink)
} else {
log.Sugar.Errorf("向sink推流超时,关闭连接. %s-sink: %s source: %s", sink.GetProtocol().String(), sink.GetID(), s.ID)
go sink.Close()
}
log.Sugar.Errorf("向sink推流超时,关闭连接. %s-sink: %s source: %s", sink.GetProtocol().String(), sink.GetID(), s.ID)
go sink.Close()
}
// 向sink推流
func (s *PublishSource) write(transStream TransStream, sink Sink, index int, data [][]byte, timestamp int64) {
func (s *PublishSource) write(sink Sink, index int, data []*collections.ReferenceCounter[[]byte], timestamp int64) bool {
err := sink.Write(index, data, timestamp)
ok := err == nil
defer func() {
if ok {
sink.IncreaseSentPacketCount()
}
}()
// 跳过非TCP流和待发送包数量小于合并写缓冲区大小的sink
if !transStream.IsTCPStreaming() || sink.PendingSendQueueSize() <= transStream.Capacity() {
return
}
// 尝试扩容合并写缓冲区, 不能扩容, 则挂起Sink
if !transStream.GrowMWBuffer() {
ok = false
s.pendingSink(sink)
if err == nil {
sink.IncreaseSentPacketCount()
return true
}
// 推流超时, 可能是服务器或拉流端带宽不够、拉流端不读取数据等情况造成内核发送缓冲区满, 进而阻塞.
// 直接关闭连接. 当然也可以将sink先挂起, 后续再继续推流.
//_, ok := err.(*transport.ZeroWindowSizeError)
//if ok {
// s.pendingSink(sink)
//}
}
func (s *PublishSource) PauseStreaming(sink Sink) {
s.cleanupSinkStreaming(sink)
s.slowSinks[sink.GetID()] = sink
}
func (s *PublishSource) ResumeStreaming() {
for id, sink := range s.sinks {
if !sink.IsExited() {
continue
}
delete(s.slowSinks, id)
ok := s.doAddSink(sink)
if ok {
go sink.Close()
}
if _, ok := err.(transport.ZeroWindowSizeError); ok {
s.pendingSink(sink)
}
return false
}
// 创建sink需要的输出流
func (s *PublishSource) doAddSink(sink Sink) bool {
func (s *PublishSource) doAddSink(sink Sink, resume bool) bool {
// 暂时不考虑多路视频流意味着只能1路视频流和多路音频流同理originStreams和allStreams里面的Stream互斥. 同时多路音频流的Codec必须一致
audioCodecId, videoCodecId := sink.DesiredAudioCodecId(), sink.DesiredVideoCodecId()
audioTrack := s.originTracks.FindWithType(utils.AVMediaTypeAudio)
@@ -478,7 +443,7 @@ func (s *PublishSource) doAddSink(sink Sink) bool {
}
// 累加拉流计数
if s.recordSink != sink {
if !resume && s.recordSink != sink {
s.sinkCount++
log.Sugar.Infof("sink count: %d source: %s", s.sinkCount, s.ID)
}
@@ -495,7 +460,7 @@ func (s *PublishSource) doAddSink(sink Sink) bool {
// TCP拉流开启异步发包, 一旦出现网络不好的链路, 其余正常链路不受影响.
_, ok := sink.GetConn().(*transport.Conn)
if ok && sink.IsTCPStreaming() {
sink.EnableAsyncWriteMode(64)
sink.EnableAsyncWriteMode(24)
}
// 发送已有的缓存数据
@@ -503,10 +468,10 @@ func (s *PublishSource) doAddSink(sink Sink) bool {
data, timestamp, _ := transStream.ReadKeyFrameBuffer()
if len(data) > 0 {
if extraData, _, _ := transStream.ReadExtraData(timestamp); len(extraData) > 0 {
s.write(transStream, sink, 0, extraData, timestamp)
s.write(sink, 0, extraData, timestamp)
}
s.write(transStream, sink, 0, data, timestamp)
s.write(sink, 0, data, timestamp)
}
// 新建传输流,发送已经缓存的音视频帧
@@ -522,7 +487,7 @@ func (s *PublishSource) AddSink(sink Sink) {
if !s.completed {
AddSinkToWaitingQueue(sink.GetSourceID(), sink)
} else {
if !s.doAddSink(sink) {
if !s.doAddSink(sink, false) {
go sink.Close()
}
}
@@ -585,7 +550,6 @@ func (s *PublishSource) cleanupSinkStreaming(sink Sink) {
func (s *PublishSource) doRemoveSink(sink Sink) bool {
s.cleanupSinkStreaming(sink)
delete(s.sinks, sink.GetID())
delete(s.slowSinks, sink.GetID())
s.sinkCount--
log.Sugar.Infof("sink count: %d source: %s", s.sinkCount, s.ID)
@@ -655,6 +619,13 @@ func (s *PublishSource) DoClose() {
if len(data) > 0 {
s.DispatchBuffer(transStream, -1, data, ts, true)
}
// 如果是tcp传输流, 归还合并写缓冲区
if !transStream.IsTCPStreaming() || transStream.GetMWBuffer() == nil {
continue
} else if buffers := transStream.GetMWBuffer().Close(); buffers != nil {
AddMWBuffersToPending(s.ID, transStream.GetID(), buffers)
}
}
// 将所有sink添加到等待队列
@@ -768,7 +739,7 @@ func (s *PublishSource) writeHeader() {
}
for _, sink := range sinks {
if !s.doAddSink(sink) {
if !s.doAddSink(sink, false) {
go sink.Close()
}
}
@@ -880,11 +851,6 @@ func (s *PublishSource) OnPacket(packet *avformat.AVPacket) {
s.gopBuffer.AddPacket(packet)
}
// 遇到关键帧, 恢复推流
if utils.AVMediaTypeVideo == packet.MediaType && packet.Key && len(s.slowSinks) > 0 {
s.ResumeStreaming()
}
// track解析完毕后才能生成传输流
if s.completed {
s.CorrectTimestamp(packet)