From b58e9268195e8b6508b70fd3a9d185acb28d59f9 Mon Sep 17 00:00:00 2001 From: langhuihui <178529795@qq.com> Date: Mon, 26 Feb 2024 17:16:49 +0800 Subject: [PATCH] feat: remove sei track --- README.md | 2 - common/frame.go | 38 +-- common/index.go | 139 ---------- common/type.go | 90 +++++++ config/types.go | 1 - http.go | 31 --- io.go | 21 +- publisher-mp4.go | 10 +- publisher-rtpdump.go | 10 +- publisher-ts.go | 16 +- publisher.go | 28 +- stream.go | 100 +++----- track/aac.go | 4 +- track/audio.go | 19 +- track/av1.go | 8 +- track/base.go | 355 +++++--------------------- track/g711.go | 4 +- track/h264.go | 12 +- track/h265.go | 10 +- track/media.go | 321 +++++++++++++++++++++++ track/opus.go | 4 +- track/reader-data.go | 2 +- track/video.go | 51 ++-- {common => util}/dtsestimator.go | 6 +- {common => util}/dtsestimator_test.go | 2 +- {common => util}/ring-writer.go | 32 ++- util/safe_chan.go | 10 + 27 files changed, 656 insertions(+), 670 deletions(-) delete mode 100644 common/index.go create mode 100644 common/type.go create mode 100644 track/media.go rename {common => util}/dtsestimator.go (90%) rename {common => util}/dtsestimator_test.go (97%) rename {common => util}/ring-writer.go (68%) diff --git a/README.md b/README.md index d8ae6de..7efb1e2 100644 --- a/README.md +++ b/README.md @@ -38,7 +38,6 @@ - 获取所有向远端推流信息 `/api/list/push` 返回{RemoteURL:"",StreamPath:"",Type:"",StartTime:""} - 停止推流 `/api/stop/push?url=xxx` 停止向xxx推流 ,成功返回ok - 停止某个订阅者 `/api/stop/subscribe?streamPath=xxx&id=xxx` 停止xxx流的xxx订阅者 ,成功返回ok -- 插入SEI帧 `/api/insertsei?streamPath=xxx&type=5` 向xxx流内插入SEI帧 ,成功返回ok。type为SEI类型,可选,默认是5 # 引擎默认配置 ```yaml global: @@ -60,7 +59,6 @@ global: pubaudio: true # 是否发布音频流 pubvideo: true # 是否发布视频流 kickexist: false # 剔出已经存在的发布者,用于顶替原有发布者 - insertsei: false # 是否开启插入SEI信息功能 publishtimeout: 10s # 发布流默认过期时间,超过该时间发布者没有恢复流将被删除 delayclosetimeout: 0 # 自动关闭触发后延迟的时间(期间内如果有新的订阅则取消触发关闭),0为关闭该功能,保持连接。 waitclosetimeout: 0 # 发布者断开后等待时间,超过该时间发布者没有恢复流将被删除,0为关闭该功能,由订阅者决定是否删除 diff --git a/common/frame.go b/common/frame.go index cb1e49e..2febd7b 100644 --- a/common/frame.go +++ b/common/frame.go @@ -45,23 +45,6 @@ func (r *RTPFrame) Unmarshal(raw []byte) *RTPFrame { return r } -type IDataFrame[T any] interface { - Init() // 初始化 - Reset() // 重置数据,复用内存 - Ready() // 标记为可读取 - ReaderEnter() int32 // 读取者数量+1 - ReaderLeave() int32 // 读取者数量-1 - StartWrite() bool // 开始写入 - SetSequence(uint32) // 设置序号 - GetSequence() uint32 // 获取序号 - ReaderCount() int32 // 读取者数量 - Discard() int32 // 如果写入时还有读取者没有离开则废弃该帧,剥离RingBuffer,防止并发读写 - IsDiscarded() bool // 是否已废弃 - IsWriting() bool // 是否正在写入 - Wait() // 阻塞等待可读取 - Broadcast() // 广播可读取 -} - type DataFrame[T any] struct { DeltaTime uint32 // 相对上一帧时间戳,毫秒 WriteTime time.Time // 写入时间,可用于比较两个帧的先后 @@ -126,7 +109,7 @@ func (df *DataFrame[T]) Ready() { } func (df *DataFrame[T]) Init() { - df.L = EmptyLocker + df.L = util.EmptyLocker } func (df *DataFrame[T]) Reset() { @@ -181,6 +164,25 @@ func (av *AVFrame) Reset() { av.DataFrame.Reset() } +func (av *AVFrame) Assign(source *AVFrame) { + av.IFrame = source.IFrame + av.PTS = source.PTS + av.DTS = source.DTS + av.Timestamp = source.Timestamp + av.BytesIn = source.BytesIn + av.DeltaTime = source.DeltaTime + source.AUList.Range(func(au *util.BLL) bool { + var nau util.BLL + au.Range(func(b util.Buffer) bool { + nau.PushValue(b) + return true + }) + nau.ByteLength = au.ByteLength + av.AUList.PushValue(&nau) + return true + }) +} + type ParamaterSets [][]byte func (v ParamaterSets) GetAnnexB() (r net.Buffers) { diff --git a/common/index.go b/common/index.go deleted file mode 100644 index 2996a6c..0000000 --- a/common/index.go +++ /dev/null @@ -1,139 +0,0 @@ -package common - -import ( - "sync/atomic" - "time" - - "github.com/pion/rtp" - "go.uber.org/zap" - "m7s.live/engine/v4/log" - "m7s.live/engine/v4/util" -) - -type TimelineData[T any] struct { - Timestamp time.Time - Value T -} -type TrackState byte - -const ( - TrackStateOnline TrackState = iota // 上线 - TrackStateOffline // 下线 -) - -// Base 基础Track类 -type Base[T any, F IDataFrame[T]] struct { - RingWriter[T, F] - Name string - log.Zap `json:"-" yaml:"-"` - Stream IStream `json:"-" yaml:"-"` - Attached atomic.Bool `json:"-" yaml:"-"` - State TrackState - ts time.Time - bytes int - frames int - DropCount int `json:"-" yaml:"-"` //丢帧数 - BPS int - FPS int - Drops int // 丢帧率 - RawSize int // 裸数据长度 - RawPart []int // 裸数据片段用于UI上显示 -} - -func (bt *Base[T, F]) ComputeBPS(bytes int) { - bt.bytes += bytes - bt.frames++ - if elapse := time.Since(bt.ts).Seconds(); elapse > 1 { - bt.BPS = int(float64(bt.bytes) / elapse) - bt.FPS = int(float64(bt.frames) / elapse) - bt.Drops = int(float64(bt.DropCount) / elapse) - bt.bytes = 0 - bt.frames = 0 - bt.DropCount = 0 - bt.ts = time.Now() - } -} - -func (bt *Base[T, F]) GetName() string { - return bt.Name -} - -func (bt *Base[T, F]) GetBPS() int { - return bt.BPS -} - -func (bt *Base[T, F]) GetFPS() int { - return bt.FPS -} - -func (bt *Base[T, F]) GetDrops() int { - return bt.Drops -} - -// GetRBSize 获取缓冲区大小 -func (bt *Base[T, F]) GetRBSize() int { - return bt.RingWriter.Size -} - -func (bt *Base[T, F]) SnapForJson() { -} - -func (bt *Base[T, F]) SetStuff(stuff ...any) { - for _, s := range stuff { - switch v := s.(type) { - case IStream: - bt.Stream = v - bt.Zap = v.With(zap.String("track", bt.Name)) - case TrackState: - bt.State = v - case string: - bt.Name = v - } - } -} - -func (bt *Base[T, F]) Dispose() { - bt.Value.Broadcast() -} - -type Track interface { - GetReaderCount() int32 - GetName() string - GetBPS() int - GetFPS() int - GetDrops() int - LastWriteTime() time.Time - SnapForJson() - SetStuff(stuff ...any) - GetRBSize() int - Dispose() -} - -type AVTrack interface { - Track - PreFrame() *AVFrame - CurrentFrame() *AVFrame - Attach() - Detach() - WriteAVCC(ts uint32, frame *util.BLL) error //写入AVCC格式的数据 - WriteRTP(*util.ListItem[RTPFrame]) - WriteRTPPack(*rtp.Packet) - WriteSequenceHead(sh []byte) error - Flush() - SetSpeedLimit(time.Duration) - GetRTPFromPool() *util.ListItem[RTPFrame] - GetFromPool(util.IBytes) *util.ListItem[util.Buffer] -} -type VideoTrack interface { - AVTrack - WriteSliceBytes(slice []byte) - WriteNalu(uint32, uint32, []byte) - WriteAnnexB(uint32, uint32, []byte) - SetLostFlag() -} - -type AudioTrack interface { - AVTrack - WriteADTS(uint32, util.IBytes) - WriteRawBytes(uint32, util.IBytes) -} diff --git a/common/type.go b/common/type.go new file mode 100644 index 0000000..57f5cac --- /dev/null +++ b/common/type.go @@ -0,0 +1,90 @@ +package common + +import ( + "context" + "time" + + "github.com/pion/rtp" + "go.uber.org/zap/zapcore" + "m7s.live/engine/v4/codec" + "m7s.live/engine/v4/config" + "m7s.live/engine/v4/log" + "m7s.live/engine/v4/util" +) + +type TimelineData[T any] struct { + Timestamp time.Time + Value T +} +type TrackState byte + +const ( + TrackStateOnline TrackState = iota // 上线 + TrackStateOffline // 下线 +) + +type IIO interface { + IsClosed() bool + OnEvent(any) + Stop(reason ...zapcore.Field) + SetIO(any) + SetParentCtx(context.Context) + SetLogger(*log.Logger) + IsShutdown() bool + GetStream() IStream + log.Zap +} + +type IPuber interface { + IIO + GetAudioTrack() AudioTrack + GetVideoTrack() VideoTrack + GetConfig() *config.Publish + Publish(streamPath string, pub IPuber) error +} + +type Track interface { + GetPublisher() IPuber + GetReaderCount() int32 + GetName() string + GetBPS() int + GetFPS() int + GetDrops() int + LastWriteTime() time.Time + SnapForJson() + SetStuff(stuff ...any) + GetRBSize() int + Dispose() +} + +type AVTrack interface { + Track + PreFrame() *AVFrame + CurrentFrame() *AVFrame + Attach() + Detach() + WriteAVCC(ts uint32, frame *util.BLL) error //写入AVCC格式的数据 + WriteRTP(*util.ListItem[RTPFrame]) + WriteRTPPack(*rtp.Packet) + WriteSequenceHead(sh []byte) error + Flush() + SetSpeedLimit(time.Duration) + GetRTPFromPool() *util.ListItem[RTPFrame] + GetFromPool(util.IBytes) *util.ListItem[util.Buffer] +} +type VideoTrack interface { + AVTrack + GetCodec() codec.VideoCodecID + WriteSliceBytes(slice []byte) + WriteNalu(uint32, uint32, []byte) + WriteAnnexB(uint32, uint32, []byte) + SetLostFlag() +} + +type AudioTrack interface { + AVTrack + GetCodec() codec.AudioCodecID + WriteADTS(uint32, util.IBytes) + WriteRawBytes(uint32, util.IBytes) + Narrow() +} diff --git a/config/types.go b/config/types.go index 6c6b318..6b072d7 100755 --- a/config/types.go +++ b/config/types.go @@ -33,7 +33,6 @@ type PushConfig interface { type Publish struct { PubAudio bool `default:"true" desc:"是否发布音频"` PubVideo bool `default:"true" desc:"是否发布视频"` - InsertSEI bool `desc:"是否启用SEI插入"` // 是否启用SEI插入 KickExist bool `desc:"是否踢掉已经存在的发布者"` // 是否踢掉已经存在的发布者 PublishTimeout time.Duration `default:"10s" desc:"发布无数据超时"` // 发布无数据超时 WaitCloseTimeout time.Duration `desc:"延迟自动关闭(等待重连)"` // 延迟自动关闭(等待重连) diff --git a/http.go b/http.go index b5dfe63..c2066cf 100644 --- a/http.go +++ b/http.go @@ -2,7 +2,6 @@ package engine import ( "encoding/json" - "io" "net/http" "os" "strconv" @@ -347,33 +346,3 @@ func (conf *GlobalConfig) API_replay_mp4(w http.ResponseWriter, r *http.Request) go pub.ReadMP4Data(f) } } - -func (conf *GlobalConfig) API_insertSEI(w http.ResponseWriter, r *http.Request) { - q := r.URL.Query() - streamPath := q.Get("streamPath") - s := Streams.Get(streamPath) - if s == nil { - util.ReturnError(util.APIErrorNoStream, NO_SUCH_STREAM, w, r) - return - } - t := q.Get("type") - tb, err := strconv.ParseInt(t, 10, 8) - if err != nil { - if t == "" { - tb = 5 - } else { - util.ReturnError(util.APIErrorQueryParse, "type must a number", w, r) - return - } - } - sei, err := io.ReadAll(r.Body) - if err == nil { - if s.Tracks.AddSEI(byte(tb), sei) { - util.ReturnOK(w, r) - } else { - util.ReturnError(util.APIErrorNoSEI, "no sei track", w, r) - } - } else { - util.ReturnError(util.APIErrorNoBody, err.Error(), w, r) - } -} diff --git a/io.go b/io.go index 30423ac..79ab876 100644 --- a/io.go +++ b/io.go @@ -14,6 +14,7 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" + "m7s.live/engine/v4/common" "m7s.live/engine/v4/config" "m7s.live/engine/v4/log" "m7s.live/engine/v4/util" @@ -48,7 +49,11 @@ type IO struct { io.Writer `json:"-" yaml:"-"` io.Closer `json:"-" yaml:"-"` Args url.Values - Spesific IIO `json:"-" yaml:"-"` + Spesific common.IIO `json:"-" yaml:"-"` +} + +func (io *IO) GetStream() common.IStream { + return io.Stream } func (io *IO) IsClosed() bool { @@ -93,18 +98,6 @@ func (io *IO) IsShutdown() bool { return io.Stream.IsShutdown() } -type IIO interface { - receive(string, IIO) error - IsClosed() bool - OnEvent(any) - Stop(reason ...zapcore.Field) - SetIO(any) - SetParentCtx(context.Context) - SetLogger(*log.Logger) - IsShutdown() bool - log.Zap -} - func (i *IO) close(err StopError) bool { if i.IsClosed() { i.Warn("already closed", err...) @@ -158,7 +151,7 @@ func (io *IO) auth(key string, secret string, expire string) bool { } // receive 用于接收发布或者订阅 -func (io *IO) receive(streamPath string, specific IIO) error { +func (io *IO) receive(streamPath string, specific common.IIO) error { streamPath = strings.Trim(streamPath, "/") u, err := url.Parse(streamPath) if err != nil { diff --git a/publisher-mp4.go b/publisher-mp4.go index db3f6b9..f26dee6 100644 --- a/publisher-mp4.go +++ b/publisher-mp4.go @@ -32,15 +32,15 @@ func (p *MP4Publisher) ReadMP4Data(source io.ReadSeeker) error { p.Info("MP4 track", zap.Any("track", t)) switch t.Cid { case mp4.MP4_CODEC_H264: - p.VideoTrack = track.NewH264(p.Stream) + p.VideoTrack = track.NewH264(p) case mp4.MP4_CODEC_H265: - p.VideoTrack = track.NewH265(p.Stream) + p.VideoTrack = track.NewH265(p) case mp4.MP4_CODEC_AAC: - p.AudioTrack = track.NewAAC(p.Stream) + p.AudioTrack = track.NewAAC(p) case mp4.MP4_CODEC_G711A: - p.AudioTrack = track.NewG711(p.Stream, true) + p.AudioTrack = track.NewG711(p, true) case mp4.MP4_CODEC_G711U: - p.AudioTrack = track.NewG711(p.Stream, false) + p.AudioTrack = track.NewG711(p, false) } } for { diff --git a/publisher-rtpdump.go b/publisher-rtpdump.go index 8d15dd9..24ac8b2 100644 --- a/publisher-rtpdump.go +++ b/publisher-rtpdump.go @@ -36,9 +36,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) { if t.VideoTrack == nil { switch t.VCodec { case codec.CodecID_H264: - t.VideoTrack = track.NewH264(t.Publisher.Stream, t.VPayloadType) + t.VideoTrack = track.NewH264(t, t.VPayloadType) case codec.CodecID_H265: - t.VideoTrack = track.NewH265(t.Publisher.Stream, t.VPayloadType) + t.VideoTrack = track.NewH265(t, t.VPayloadType) } if t.VideoTrack != nil { t.VideoTrack.SetSpeedLimit(500 * time.Millisecond) @@ -47,7 +47,7 @@ func (t *RTPDumpPublisher) Feed(file *os.File) { if t.AudioTrack == nil { switch t.ACodec { case codec.CodecID_AAC: - at := track.NewAAC(t.Publisher.Stream, t.APayloadType) + at := track.NewAAC(t, t.APayloadType) t.AudioTrack = at var c mpeg4audio.Config c.ChannelCount = 2 @@ -55,9 +55,9 @@ func (t *RTPDumpPublisher) Feed(file *os.File) { asc, _ := c.Marshal() at.WriteSequenceHead(append([]byte{0xAF, 0x00}, asc...)) case codec.CodecID_PCMA: - t.AudioTrack = track.NewG711(t.Publisher.Stream, true, t.APayloadType) + t.AudioTrack = track.NewG711(t, true, t.APayloadType) case codec.CodecID_PCMU: - t.AudioTrack = track.NewG711(t.Publisher.Stream, false, t.APayloadType) + t.AudioTrack = track.NewG711(t, false, t.APayloadType) } if t.AudioTrack != nil { t.AudioTrack.SetSpeedLimit(500 * time.Millisecond) diff --git a/publisher-ts.go b/publisher-ts.go index bd1c53b..825b4ce 100644 --- a/publisher-ts.go +++ b/publisher-ts.go @@ -31,9 +31,9 @@ func (t *TSPublisher) OnEvent(event any) { switch v := event.(type) { case IPublisher: t.pool = make(util.BytesPool, 17) - if !t.Equal(v) { - t.AudioTrack = v.getAudioTrack() - t.VideoTrack = v.getVideoTrack() + if v.GetPublisher() != &t.Publisher { + t.AudioTrack = v.GetAudioTrack() + t.VideoTrack = v.GetVideoTrack() } case SEKick, SEclose: // close(t.PESChan) @@ -47,23 +47,23 @@ func (t *TSPublisher) OnPmtStream(s mpegts.MpegTsPmtStream) { switch s.StreamType { case mpegts.STREAM_TYPE_H264: if t.VideoTrack == nil { - t.VideoTrack = track.NewH264(t.Publisher.Stream, t.pool) + t.VideoTrack = track.NewH264(t, t.pool) } case mpegts.STREAM_TYPE_H265: if t.VideoTrack == nil { - t.VideoTrack = track.NewH265(t.Publisher.Stream, t.pool) + t.VideoTrack = track.NewH265(t, t.pool) } case mpegts.STREAM_TYPE_AAC: if t.AudioTrack == nil { - t.AudioTrack = track.NewAAC(t.Publisher.Stream, t.pool) + t.AudioTrack = track.NewAAC(t, t.pool) } case mpegts.STREAM_TYPE_G711A: if t.AudioTrack == nil { - t.AudioTrack = track.NewG711(t.Publisher.Stream, true, t.pool) + t.AudioTrack = track.NewG711(t, true, t.pool) } case mpegts.STREAM_TYPE_G711U: if t.AudioTrack == nil { - t.AudioTrack = track.NewG711(t.Publisher.Stream, false, t.pool) + t.AudioTrack = track.NewG711(t, false, t.pool) } default: t.Warn("unsupport stream type:", zap.Uint8("type", s.StreamType)) diff --git a/publisher.go b/publisher.go index a9f68f8..5700dd9 100644 --- a/publisher.go +++ b/publisher.go @@ -10,11 +10,8 @@ import ( ) type IPublisher interface { - IIO + common.IPuber GetPublisher() *Publisher - getAudioTrack() common.AudioTrack - getVideoTrack() common.VideoTrack - Publish(streamPath string, pub IPublisher) error } var _ IPublisher = (*Publisher)(nil) @@ -26,7 +23,7 @@ type Publisher struct { common.VideoTrack `json:"-" yaml:"-"` } -func (p *Publisher) Publish(streamPath string, pub IPublisher) error { +func (p *Publisher) Publish(streamPath string, pub common.IPuber) error { return p.receive(streamPath, pub) } @@ -39,16 +36,15 @@ func (p *Publisher) GetPublisher() *Publisher { // p.Stream.Receive(ACTION_PUBLISHCLOSE) // } -func (p *Publisher) getAudioTrack() common.AudioTrack { +func (p *Publisher) GetAudioTrack() common.AudioTrack { return p.AudioTrack } -func (p *Publisher) getVideoTrack() common.VideoTrack { +func (p *Publisher) GetVideoTrack() common.VideoTrack { return p.VideoTrack } -func (p *Publisher) Equal(p2 IPublisher) bool { - return p == p2.GetPublisher() +func (p *Publisher) GetConfig() *config.Publish { + return p.Config } - // func (p *Publisher) OnEvent(event any) { // p.IO.OnEvent(event) // switch event.(type) { @@ -69,10 +65,10 @@ func (p *Publisher) WriteAVCCVideo(ts uint32, frame *util.BLL, pool util.BytesPo fourCC := frame.GetUintN(1, 4) switch fourCC { case codec.FourCC_H265_32: - p.VideoTrack = track.NewH265(p.Stream, pool) + p.VideoTrack = track.NewH265(p, pool) p.VideoTrack.WriteAVCC(ts, frame) case codec.FourCC_AV1_32: - p.VideoTrack = track.NewAV1(p.Stream, pool) + p.VideoTrack = track.NewAV1(p, pool) p.VideoTrack.WriteAVCC(ts, frame) } } else { @@ -80,9 +76,9 @@ func (p *Publisher) WriteAVCCVideo(ts uint32, frame *util.BLL, pool util.BytesPo ts = 0 switch codecID := codec.VideoCodecID(b0 & 0x0F); codecID { case codec.CodecID_H264: - p.VideoTrack = track.NewH264(p.Stream, pool) + p.VideoTrack = track.NewH264(p, pool) case codec.CodecID_H265: - p.VideoTrack = track.NewH265(p.Stream, pool) + p.VideoTrack = track.NewH265(p, pool) default: p.Stream.Error("video codecID not support", zap.Uint8("codeId", uint8(codecID))) return @@ -108,7 +104,7 @@ func (p *Publisher) WriteAVCCAudio(ts uint32, frame *util.BLL, pool util.BytesPo if frame.GetByte(1) != 0 { return } - a := track.NewAAC(p.Stream, pool) + a := track.NewAAC(p, pool) p.AudioTrack = a a.AVCCHead = []byte{frame.GetByte(0), 1} a.WriteAVCC(0, frame) @@ -118,7 +114,7 @@ func (p *Publisher) WriteAVCCAudio(ts uint32, frame *util.BLL, pool util.BytesPo if codecID == codec.CodecID_PCMU { alaw = false } - a := track.NewG711(p.Stream, alaw, pool) + a := track.NewG711(p, alaw, pool) p.AudioTrack = a a.Audio.SampleRate = uint32(codec.SoundRate[(b0&0x0c)>>2]) if b0&0x02 == 0 { diff --git a/stream.go b/stream.go index 9b3ed58..8eb5500 100644 --- a/stream.go +++ b/stream.go @@ -14,7 +14,6 @@ import ( . "github.com/logrusorgru/aurora/v4" "go.uber.org/zap" "m7s.live/engine/v4/common" - . "m7s.live/engine/v4/common" "m7s.live/engine/v4/config" "m7s.live/engine/v4/log" "m7s.live/engine/v4/track" @@ -128,27 +127,23 @@ type Tracks struct { Data []common.Track MainVideo *track.Video MainAudio *track.Audio - SEI *track.Channel[[]byte] marshalLock sync.Mutex } -func (tracks *Tracks) Range(f func(name string, t Track)) { +func (tracks *Tracks) Range(f func(name string, t common.Track)) { tracks.Map.Range(func(k, v any) bool { - f(k.(string), v.(Track)) + f(k.(string), v.(common.Track)) return true }) } -func (tracks *Tracks) Add(name string, t Track) bool { +func (tracks *Tracks) Add(name string, t common.Track) bool { switch v := t.(type) { case *track.Video: if tracks.MainVideo == nil { tracks.MainVideo = v tracks.SetIDR(v) } - if tracks.SEI != nil { - v.SEIReader = tracks.SEI.CreateReader(100) - } case *track.Audio: if tracks.MainAudio == nil { tracks.MainAudio = v @@ -171,9 +166,9 @@ func (tracks *Tracks) Add(name string, t Track) bool { return !loaded } -func (tracks *Tracks) SetIDR(video Track) { +func (tracks *Tracks) SetIDR(video common.Track) { if video == tracks.MainVideo { - tracks.Range(func(_ string, t Track) { + tracks.Range(func(_ string, t common.Track) { if v, ok := t.(*track.Audio); ok { v.Narrow() } @@ -181,29 +176,11 @@ func (tracks *Tracks) SetIDR(video Track) { } } -func (tracks *Tracks) AddSEI(t byte, data []byte) bool { - if tracks.SEI != nil { - l := len(data) - var buffer util.Buffer - buffer.WriteByte(t) - for l >= 255 { - buffer.WriteByte(255) - l -= 255 - } - buffer.WriteByte(byte(l)) - buffer.Write(data) - buffer.WriteByte(0x80) - tracks.SEI.Write(buffer) - return true - } - return false -} - func (tracks *Tracks) MarshalJSON() ([]byte, error) { - var trackList []Track + var trackList []common.Track tracks.marshalLock.Lock() defer tracks.marshalLock.Unlock() - tracks.Range(func(_ string, t Track) { + tracks.Range(func(_ string, t common.Track) { t.SnapForJson() trackList = append(trackList, t) }) @@ -222,6 +199,7 @@ type Stream struct { StreamTimeoutConfig Path string Publisher IPublisher + publisher *Publisher State StreamState SEHistory []StateEvent // 事件历史 Subscribers Subscribers // 订阅者 @@ -245,7 +223,11 @@ func (s *Stream) GetType() string { if s.Publisher == nil { return "" } - return s.Publisher.GetPublisher().Type + return s.publisher.Type +} + +func (s *Stream) GetPath() string { + return s.Path } func (s *Stream) GetStartTime() time.Time { @@ -257,15 +239,15 @@ func (s *Stream) GetPublisherConfig() *config.Publish { s.Error("GetPublisherConfig: Publisher is nil") return nil } - return s.Publisher.GetPublisher().Config + return s.Publisher.GetConfig() } // Summary 返回流的简要信息 func (s *Stream) Summary() (r StreamSummay) { - if s.Publisher != nil { - r.Type = s.Publisher.GetPublisher().Type + if s.publisher != nil { + r.Type = s.publisher.Type } - s.Tracks.Range(func(name string, t Track) { + s.Tracks.Range(func(name string, t common.Track) { r.BPS += t.GetBPS() r.Tracks = append(r.Tracks, name) }) @@ -279,7 +261,7 @@ func (s *Stream) Summary() (r StreamSummay) { func (s *Stream) SSRC() uint32 { return uint32(uintptr(unsafe.Pointer(s))) } -func (s *Stream) SetIDR(video Track) { +func (s *Stream) SetIDR(video common.Track) { s.Tracks.SetIDR(video) } func findOrCreateStream(streamPath string, waitTimeout time.Duration) (s *Stream, created bool) { @@ -330,9 +312,9 @@ func (r *Stream) action(action StreamAction) (ok bool) { stateEvent = SEwaitPublish{event, r.Publisher} waitTime := time.Duration(0) if r.Publisher != nil { - waitTime = r.Publisher.GetPublisher().Config.WaitCloseTimeout - r.Tracks.Range(func(name string, t Track) { - t.SetStuff(TrackStateOffline) + waitTime = r.Publisher.GetConfig().WaitCloseTimeout + r.Tracks.Range(func(name string, t common.Track) { + t.SetStuff(common.TrackStateOffline) }) } r.Subscribers.OnPublisherLost(event) @@ -373,8 +355,10 @@ func (r *Stream) action(action StreamAction) (ok bool) { r.timeout.Stop() stateEvent = SEclose{event} r.Subscribers.Broadcast(stateEvent) - r.Tracks.Range(func(_ string, t Track) { - t.Dispose() + r.Tracks.Range(func(_ string, t common.Track) { + if t.GetPublisher().GetStream() == r { + t.Dispose() + } }) r.Subscribers.Dispose() r.actionChan.Close() @@ -486,7 +470,7 @@ func (s *Stream) run() { if s.IsPause { timeout = s.PauseTimeout } - s.Tracks.Range(func(name string, t Track) { + s.Tracks.Range(func(name string, t common.Track) { trackCount++ switch t.(type) { case *track.Video, *track.Audio: @@ -504,7 +488,7 @@ func (s *Stream) run() { s.action(ACTION_CLOSE) continue } else if s.Publisher != nil && s.Publisher.IsClosed() { - s.Warn("publish is closed", zap.Error(context.Cause(s.Publisher.GetPublisher())), zap.String("ptr", fmt.Sprintf("%p", s.Publisher.GetPublisher().Context))) + s.Warn("publish is closed", zap.Error(context.Cause(s.publisher)), zap.String("ptr", fmt.Sprintf("%p", s.publisher.Context))) lost = true if len(s.Tracks.Audio)+len(s.Tracks.Video) == 0 { s.action(ACTION_CLOSE) @@ -546,19 +530,17 @@ func (s *Stream) run() { break } puber := v.Value.GetPublisher() - var oldPuber *Publisher - if s.Publisher != nil { - oldPuber = s.Publisher.GetPublisher() - } + oldPuber := s.publisher + s.publisher = puber conf := puber.Config republish := s.Publisher == v.Value // 重复发布 if republish { s.Info("republish") - s.Tracks.Range(func(name string, t Track) { - t.SetStuff(TrackStateOffline) + s.Tracks.Range(func(name string, t common.Track) { + t.SetStuff(common.TrackStateOffline) }) } - needKick := !republish && s.Publisher != nil && conf.KickExist // 需要踢掉老的发布者 + needKick := !republish && oldPuber != nil && conf.KickExist // 需要踢掉老的发布者 if needKick { s.Warn("kick", zap.String("old type", oldPuber.Type)) s.Publisher.OnEvent(SEKick{CreateEvent[struct{}](util.Null)}) @@ -574,12 +556,6 @@ func (s *Stream) run() { puber.AudioTrack = oldPuber.AudioTrack puber.VideoTrack = oldPuber.VideoTrack } - if conf.InsertSEI { - if s.Tracks.SEI == nil { - s.Tracks.SEI = &track.Channel[[]byte]{} - s.Info("sei track added") - } - } v.Resolve() } else { s.Warn("duplicate publish") @@ -618,8 +594,8 @@ func (s *Stream) run() { } if s.Publisher != nil { s.Publisher.OnEvent(v) // 通知Publisher有新的订阅者加入,在回调中可以去获取订阅者数量 - pubConfig := s.Publisher.GetPublisher().Config - s.Tracks.Range(func(name string, t Track) { + pubConfig := s.Publisher.GetConfig() + s.Tracks.Range(func(name string, t common.Track) { waits.Accept(t) }) if !pubConfig.PubAudio { @@ -656,7 +632,7 @@ func (s *Stream) run() { s.Subscribers.Broadcast(t) t.(common.Track).Dispose() } - case *util.Promise[Track]: + case *util.Promise[common.Track]: timeOutInfo = zap.String("action", "Track") if s.IsClosed() { v.Reject(ErrStreamIsClosed) @@ -706,7 +682,7 @@ func (s *Stream) run() { } } -func (s *Stream) AddTrack(t Track) (promise *util.Promise[Track]) { +func (s *Stream) AddTrack(t common.Track) (promise *util.Promise[common.Track]) { promise = util.NewPromise(t) if !s.Receive(promise) { promise.Reject(ErrStreamIsClosed) @@ -714,7 +690,7 @@ func (s *Stream) AddTrack(t Track) (promise *util.Promise[Track]) { return } -func (s *Stream) RemoveTrack(t Track) { +func (s *Stream) RemoveTrack(t common.Track) { s.Receive(TrackRemoved{t}) } @@ -727,7 +703,7 @@ func (s *Stream) Resume() { } type TrackRemoved struct { - Track + common.Track } type SubPulse struct { diff --git a/track/aac.go b/track/aac.go index 0cd085e..3d4b061 100644 --- a/track/aac.go +++ b/track/aac.go @@ -14,7 +14,7 @@ import ( var _ SpesificTrack = (*AAC)(nil) -func NewAAC(stream IStream, stuff ...any) (aac *AAC) { +func NewAAC(puber IPuber, stuff ...any) (aac *AAC) { aac = &AAC{ Mode: 2, } @@ -24,7 +24,7 @@ func NewAAC(stream IStream, stuff ...any) (aac *AAC) { aac.CodecID = codec.CodecID_AAC aac.Channels = 2 aac.SampleSize = 16 - aac.SetStuff("aac", byte(97), aac, stuff, stream) + aac.SetStuff("aac", byte(97), aac, stuff, puber) if aac.BytesPool == nil { aac.BytesPool = make(util.BytesPool, 17) } diff --git a/track/audio.go b/track/audio.go index e4e4fa2..c7dbd8e 100644 --- a/track/audio.go +++ b/track/audio.go @@ -20,20 +20,15 @@ type Audio struct { } func (a *Audio) Attach() { - if a.Attached.CompareAndSwap(false, true) { - if err := a.Stream.AddTrack(a).Await(); err != nil { - a.Error("attach audio track failed", zap.Error(err)) - a.Attached.Store(false) - } else { - a.Info("audio track attached", zap.Uint32("sample rate", a.SampleRate)) - } + if err := a.Publisher.GetStream().AddTrack(a).Await(); err != nil { + a.Error("attach audio track failed", zap.Error(err)) + } else { + a.Info("audio track attached", zap.Uint32("sample rate", a.SampleRate)) } } func (a *Audio) Detach() { - if a.Attached.CompareAndSwap(true, false) { - a.Stream.RemoveTrack(a) - } + a.Publisher.GetStream().RemoveTrack(a) } func (a *Audio) GetName() string { @@ -43,6 +38,10 @@ func (a *Audio) GetName() string { return a.Name } +func (a *Audio) GetCodec() codec.AudioCodecID { + return a.CodecID +} + func (av *Audio) WriteADTS(pts uint32, adts util.IBytes) { } diff --git a/track/av1.go b/track/av1.go index 01df2e2..1be19c4 100644 --- a/track/av1.go +++ b/track/av1.go @@ -23,15 +23,15 @@ type AV1 struct { refFrameType map[byte]byte } -func NewAV1(stream IStream, stuff ...any) (vt *AV1) { +func NewAV1(puber IPuber, stuff ...any) (vt *AV1) { vt = &AV1{} vt.Video.CodecID = codec.CodecID_AV1 - vt.SetStuff("av1", byte(96), uint32(90000), vt, stuff, stream) + vt.SetStuff("av1", byte(96), uint32(90000), vt, stuff, puber) if vt.BytesPool == nil { vt.BytesPool = make(util.BytesPool, 17) } vt.nalulenSize = 0 - vt.dtsEst = NewDTSEstimator() + vt.dtsEst = util.NewDTSEstimator() vt.decoder.Init() vt.encoder.Init() vt.encoder.PayloadType = vt.PayloadType @@ -53,7 +53,7 @@ func (vt *AV1) WriteRTPFrame(rtpItem *util.ListItem[RTPFrame]) { err := recover() if err != nil { vt.Error("WriteRTPFrame panic", zap.Any("err", err)) - vt.Stream.Close() + vt.Publisher.Stop(zap.Any("err", err)) } }() if vt.lastSeq != vt.lastSeq2+1 && vt.lastSeq2 != 0 { diff --git a/track/base.go b/track/base.go index 38d03b3..31e5271 100644 --- a/track/base.go +++ b/track/base.go @@ -2,318 +2,87 @@ package track import ( "time" - "unsafe" - "github.com/pion/rtp" "go.uber.org/zap" - . "m7s.live/engine/v4/common" - "m7s.live/engine/v4/config" + "m7s.live/engine/v4/common" "m7s.live/engine/v4/log" "m7s.live/engine/v4/util" ) -var deltaDTSRange time.Duration = 90 * 10000 // 超过 10 秒 - -type 流速控制 struct { - 起始时间戳 time.Duration - 起始dts time.Duration - 等待上限 time.Duration - 起始时间 time.Time +// Base 基础Track类 +type Base[T any, F util.IDataFrame[T]] struct { + util.RingWriter[T, F] + Name string + log.Zap `json:"-" yaml:"-"` + Publisher common.IPuber `json:"-" yaml:"-"` //所属发布者 + State common.TrackState + ts time.Time + bytes int + frames int + DropCount int `json:"-" yaml:"-"` //丢帧数 + BPS int + FPS int + Drops int // 丢帧率 + RawSize int // 裸数据长度 + RawPart []int // 裸数据片段用于UI上显示 } -func (p *流速控制) 重置(绝对时间戳 time.Duration, dts time.Duration) { - p.起始时间 = time.Now() - p.起始时间戳 = 绝对时间戳 - p.起始dts = dts - // println("重置", p.起始时间.Format("2006-01-02 15:04:05"), p.起始时间戳) -} -func (p *流速控制) 根据起始DTS计算绝对时间戳(dts time.Duration) time.Duration { - if dts < p.起始dts { - dts += (1 << 32) - } - return ((dts-p.起始dts)*time.Millisecond + p.起始时间戳*90) / 90 -} - -func (p *流速控制) 控制流速(绝对时间戳 time.Duration, dts time.Duration) (等待了 time.Duration) { - 数据时间差, 实际时间差 := 绝对时间戳-p.起始时间戳, time.Since(p.起始时间) - // println("数据时间差", 数据时间差, "实际时间差", 实际时间差, "绝对时间戳", 绝对时间戳, "起始时间戳", p.起始时间戳, "起始时间", p.起始时间.Format("2006-01-02 15:04:05")) - // if 实际时间差 > 数据时间差 { - // p.重置(绝对时间戳) - // return - // } - // 如果收到的帧的时间戳超过实际消耗的时间100ms就休息一下,100ms作为一个弹性区间防止频繁调用sleep - if 过快 := (数据时间差 - 实际时间差); 过快 > 100*time.Millisecond { - // fmt.Println("过快毫秒", 过快.Milliseconds()) - // println("过快毫秒", p.name, 过快.Milliseconds()) - if 过快 > p.等待上限 { - 等待了 = p.等待上限 - } else { - 等待了 = 过快 - } - time.Sleep(等待了) - } else if 过快 < -100*time.Millisecond { - // fmt.Println("过慢毫秒", 过快.Milliseconds()) - // p.重置(绝对时间戳, dts) - // println("过慢毫秒", p.name, 过快.Milliseconds()) - } - return -} - -type SpesificTrack interface { - CompleteRTP(*AVFrame) - CompleteAVCC(*AVFrame) - WriteSliceBytes([]byte) - WriteRTPFrame(*util.ListItem[RTPFrame]) - generateTimestamp(uint32) - WriteSequenceHead([]byte) error - writeAVCCFrame(uint32, *util.BLLReader, *util.BLL) error - GetNALU_SEI() *util.ListItem[util.Buffer] - Flush() -} - -type IDRingList struct { - IDRList util.List[*util.Ring[*AVFrame]] - IDRing *util.Ring[*AVFrame] - HistoryRing *util.Ring[*AVFrame] -} - -func (p *IDRingList) AddIDR(IDRing *util.Ring[*AVFrame]) { - p.IDRList.PushValue(IDRing) - p.IDRing = IDRing -} - -func (p *IDRingList) ShiftIDR() { - p.IDRList.Shift() - p.HistoryRing = p.IDRList.Next.Value -} - -// Media 基础媒体Track类 -type Media struct { - Base[any, *AVFrame] - PayloadType byte - IDRingList `json:"-" yaml:"-"` //最近的关键帧位置,首屏渲染 - SSRC uint32 - SampleRate uint32 - BytesPool util.BytesPool `json:"-" yaml:"-"` - RtpPool util.Pool[RTPFrame] `json:"-" yaml:"-"` - SequenceHead []byte `json:"-" yaml:"-"` //H264(SPS、PPS) H265(VPS、SPS、PPS) AAC(config) - SequenceHeadSeq int - RTPDemuxer - SpesificTrack `json:"-" yaml:"-"` - deltaTs time.Duration //用于接续发布后时间戳连续 - - 流速控制 -} - -func (av *Media) GetFromPool(b util.IBytes) (item *util.ListItem[util.Buffer]) { - if b.Reuse() { - item = av.BytesPool.Get(b.Len()) - copy(item.Value, b.Bytes()) - } else { - return av.BytesPool.GetShell(b.Bytes()) - } - return -} - -func (av *Media) GetRTPFromPool() (result *util.ListItem[RTPFrame]) { - result = av.RtpPool.Get() - if result.Value.Packet == nil { - result.Value.Packet = &rtp.Packet{} - result.Value.PayloadType = av.PayloadType - result.Value.SSRC = av.SSRC - result.Value.Version = 2 - result.Value.Raw = make([]byte, 1460) - } - result.Value.Raw = result.Value.Raw[:1460] - result.Value.Payload = result.Value.Raw[:0] - return -} - -// 为json序列化而计算的数据 -func (av *Media) SnapForJson() { - v := av.LastValue - if av.RawPart != nil { - av.RawPart = av.RawPart[:0] - } else { - av.RawPart = make([]int, 0, 10) - } - if av.RawSize = v.AUList.ByteLength; av.RawSize > 0 { - r := v.AUList.NewReader() - for b, err := r.ReadByte(); err == nil && len(av.RawPart) < 10; b, err = r.ReadByte() { - av.RawPart = append(av.RawPart, int(b)) - } +func (bt *Base[T, F]) ComputeBPS(bytes int) { + bt.bytes += bytes + bt.frames++ + if elapse := time.Since(bt.ts).Seconds(); elapse > 1 { + bt.BPS = int(float64(bt.bytes) / elapse) + bt.FPS = int(float64(bt.frames) / elapse) + bt.Drops = int(float64(bt.DropCount) / elapse) + bt.bytes = 0 + bt.frames = 0 + bt.DropCount = 0 + bt.ts = time.Now() } } -func (av *Media) SetSpeedLimit(value time.Duration) { - av.等待上限 = value +func (bt *Base[T, F]) GetName() string { + return bt.Name } -func (av *Media) SetStuff(stuff ...any) { - // 代表发布者已经离线,该Track成为遗留Track,等待下一任发布者接续发布 +func (bt *Base[T, F]) GetBPS() int { + return bt.BPS +} + +func (bt *Base[T, F]) GetFPS() int { + return bt.FPS +} + +func (bt *Base[T, F]) GetDrops() int { + return bt.Drops +} + +// GetRBSize 获取缓冲区大小 +func (bt *Base[T, F]) GetRBSize() int { + return bt.RingWriter.Size +} + +func (bt *Base[T, F]) SnapForJson() { +} + +func (bt *Base[T, F]) SetStuff(stuff ...any) { for _, s := range stuff { switch v := s.(type) { - case IStream: - pubConf := v.GetPublisherConfig() - av.Base.SetStuff(v) - av.Init(256, NewAVFrame) - av.SSRC = uint32(uintptr(unsafe.Pointer(av))) - av.等待上限 = pubConf.SpeedLimit - case uint32: - av.SampleRate = v - case byte: - av.PayloadType = v - case util.BytesPool: - av.BytesPool = v - case SpesificTrack: - av.SpesificTrack = v - case []any: - av.SetStuff(v...) - default: - av.Base.SetStuff(v) + case common.IPuber: + bt.Publisher = v + bt.Zap = v.With(zap.String("track", bt.Name)) + case common.TrackState: + bt.State = v + case string: + bt.Name = v } } } -func (av *Media) LastWriteTime() time.Time { - return av.LastValue.WriteTime +func (bt *Base[T, F]) GetPublisher() common.IPuber { + return bt.Publisher } -func (av *Media) CurrentFrame() *AVFrame { - return av.Value -} -func (av *Media) PreFrame() *AVFrame { - return av.LastValue -} - -func (av *Media) generateTimestamp(ts uint32) { - av.Value.PTS = time.Duration(ts) - av.Value.DTS = time.Duration(ts) -} - -func (av *Media) WriteSequenceHead(sh []byte) { - av.SequenceHead = sh - av.SequenceHeadSeq++ -} -func (av *Media) AppendAuBytes(b ...[]byte) { - var au util.BLL - for _, bb := range b { - au.Push(av.BytesPool.GetShell(bb)) - } - av.Value.AUList.PushValue(&au) -} - -func (av *Media) narrow(gop int) { - if l := av.Size - gop; l > 12 { - if log.Trace { - av.Trace("resize", zap.Int("before", av.Size), zap.Int("after", av.Size-5)) - } - //缩小缓冲环节省内存 - av.Reduce(5) - } -} - -func (av *Media) AddIDR() { - if av.Stream.GetPublisherConfig().BufferTime > 0 { - av.IDRingList.AddIDR(av.Ring) - if av.HistoryRing == nil { - av.HistoryRing = av.IDRing - } - } else { - av.IDRing = av.Ring - } -} - -func (av *Media) Flush() { - curValue, preValue, nextValue := av.Value, av.LastValue, av.Next() - useDts := curValue.Timestamp == 0 - originDTS := curValue.DTS - if av.State == TrackStateOffline { - av.State = TrackStateOnline - if useDts { - av.deltaTs = curValue.DTS - preValue.DTS - } else { - av.deltaTs = curValue.Timestamp - preValue.Timestamp - } - curValue.DTS = preValue.DTS + 900 - curValue.PTS = preValue.PTS + 900 - curValue.Timestamp = preValue.Timestamp + time.Millisecond - av.Info("track back online", zap.Duration("delta", av.deltaTs)) - } else if av.deltaTs != 0 { - if useDts { - curValue.DTS -= av.deltaTs - curValue.PTS -= av.deltaTs - } else { - rtpts := av.deltaTs * 90 / time.Millisecond - curValue.DTS -= rtpts - curValue.PTS -= rtpts - curValue.Timestamp -= av.deltaTs - } - } - if av.起始时间.IsZero() { - curValue.DeltaTime = 0 - if useDts { - curValue.Timestamp = time.Since(av.Stream.GetStartTime()) - } - av.重置(curValue.Timestamp, curValue.DTS) - } else { - if useDts { - deltaDts := curValue.DTS - preValue.DTS - if deltaDts > deltaDTSRange || deltaDts < -deltaDTSRange { - // 时间戳跳变,等同于离线重连 - av.deltaTs = originDTS - preValue.DTS - curValue.DTS = preValue.DTS + 900 - curValue.PTS = preValue.PTS + 900 - av.Warn("track dts reset", zap.Int64("delta1", int64(deltaDts)), zap.Int64("delta2", int64(av.deltaTs))) - } - curValue.Timestamp = av.根据起始DTS计算绝对时间戳(curValue.DTS) - } - - curValue.DeltaTime = uint32(deltaTS(curValue.Timestamp, preValue.Timestamp) / time.Millisecond) - } - if log.Trace { - av.Trace("write", zap.Uint32("seq", curValue.Sequence), zap.Int64("dts0", int64(preValue.DTS)), zap.Int64("dts1", int64(originDTS)), zap.Uint64("dts2", uint64(curValue.DTS)), zap.Uint32("delta", curValue.DeltaTime), zap.Duration("timestamp", curValue.Timestamp), zap.Int("au", curValue.AUList.Length), zap.Int("rtp", curValue.RTP.Length), zap.Int("avcc", curValue.AVCC.ByteLength), zap.Int("raw", curValue.AUList.ByteLength), zap.Int("bps", av.BPS)) - } - bufferTime := av.Stream.GetPublisherConfig().BufferTime - if bufferTime > 0 && av.IDRingList.IDRList.Length > 1 && deltaTS(curValue.Timestamp, av.IDRingList.IDRList.Next.Next.Value.Value.Timestamp) > bufferTime { - av.ShiftIDR() - av.narrow(int(curValue.Sequence - av.HistoryRing.Value.Sequence)) - } - // 下一帧为订阅起始帧,即将覆盖,需要扩环 - if nextValue == av.IDRing || nextValue == av.HistoryRing { - // if av.AVRing.Size < 512 { - if log.Trace { - av.Stream.Trace("resize", zap.Int("before", av.Size), zap.Int("after", av.Size+5), zap.String("name", av.Name)) - } - av.Glow(5) - // } else { - // av.Stream.Error("sub ring overflow", zap.Int("size", av.AVRing.Size), zap.String("name", av.Name)) - // } - } - - if curValue.AUList.Length > 0 { - // 补完RTP - if config.Global.EnableRTP && curValue.RTP.Length == 0 { - av.CompleteRTP(curValue) - } - // 补完AVCC - if config.Global.EnableAVCC && curValue.AVCC.ByteLength == 0 { - av.CompleteAVCC(curValue) - } - } - av.ComputeBPS(curValue.BytesIn) - av.Step() - if av.等待上限 > 0 { - 等待了 := av.控制流速(curValue.Timestamp, curValue.DTS) - if log.Trace && 等待了 > 0 { - av.Trace("speed control", zap.Duration("sleep", 等待了)) - } - } -} - -func deltaTS(curTs time.Duration, preTs time.Duration) time.Duration { - if curTs < preTs { - return curTs + (1<<32)*time.Millisecond - preTs - } - return curTs - preTs +func (bt *Base[T, F]) Dispose() { + bt.Value.Broadcast() } diff --git a/track/g711.go b/track/g711.go index 5ae5bc6..6e43caa 100644 --- a/track/g711.go +++ b/track/g711.go @@ -11,7 +11,7 @@ import ( var _ SpesificTrack = (*G711)(nil) -func NewG711(stream IStream, alaw bool, stuff ...any) (g711 *G711) { +func NewG711(puber IPuber, alaw bool, stuff ...any) (g711 *G711) { g711 = &G711{} if alaw { g711.Name = "pcma" @@ -28,7 +28,7 @@ func NewG711(stream IStream, alaw bool, stuff ...any) (g711 *G711) { g711.SampleSize = 8 g711.Channels = 1 g711.AVCCHead = []byte{(byte(g711.CodecID) << 4) | (1 << 1)} - g711.SetStuff(uint32(8000), g711, stuff, stream) + g711.SetStuff(uint32(8000), g711, stuff, puber) if g711.BytesPool == nil { g711.BytesPool = make(util.BytesPool, 17) } diff --git a/track/h264.go b/track/h264.go index cbc86e1..b8b09a8 100644 --- a/track/h264.go +++ b/track/h264.go @@ -18,16 +18,16 @@ type H264 struct { buf util.Buffer // rtp 包临时缓存,对于不规范的 rtp 包(sps 放到了 fua 中导致)需要缓存 } -func NewH264(stream IStream, stuff ...any) (vt *H264) { +func NewH264(puber IPuber, stuff ...any) (vt *H264) { vt = &H264{} vt.Video.CodecID = codec.CodecID_H264 - vt.SetStuff("h264", byte(96), uint32(90000), vt, stuff, stream) + vt.SetStuff("h264", byte(96), uint32(90000), vt, stuff, puber) if vt.BytesPool == nil { vt.BytesPool = make(util.BytesPool, 17) } vt.ParamaterSets = make(ParamaterSets, 2) vt.nalulenSize = 4 - vt.dtsEst = NewDTSEstimator() + vt.dtsEst = util.NewDTSEstimator() return } @@ -104,8 +104,8 @@ func (vt *H264) WriteSequenceHead(head []byte) (err error) { vt.ParamaterSets[1] = vt.PPS vt.Video.WriteSequenceHead(head) } else { - vt.Stream.Error("H264 ParseSpsPps Error") - vt.Stream.Close() + vt.Error("H264 ParseSpsPps Error") + vt.Publisher.Stop(zap.Error(err)) } return } @@ -115,7 +115,7 @@ func (vt *H264) WriteRTPFrame(rtpItem *util.ListItem[RTPFrame]) { err := recover() if err != nil { vt.Error("WriteRTPFrame panic", zap.Any("err", err)) - vt.Stream.Close() + vt.Publisher.Stop(zap.Any("err", err)) } }() if vt.lastSeq != vt.lastSeq2+1 && vt.lastSeq2 != 0 { diff --git a/track/h265.go b/track/h265.go index 0a5729c..5d1b814 100644 --- a/track/h265.go +++ b/track/h265.go @@ -15,16 +15,16 @@ type H265 struct { VPS []byte `json:"-" yaml:"-"` } -func NewH265(stream IStream, stuff ...any) (vt *H265) { +func NewH265(puber IPuber, stuff ...any) (vt *H265) { vt = &H265{} vt.Video.CodecID = codec.CodecID_H265 - vt.SetStuff("h265", byte(96), uint32(90000), vt, stuff, stream) + vt.SetStuff("h265", byte(96), uint32(90000), vt, stuff, puber) if vt.BytesPool == nil { vt.BytesPool = make(util.BytesPool, 17) } vt.ParamaterSets = make(ParamaterSets, 3) vt.nalulenSize = 4 - vt.dtsEst = NewDTSEstimator() + vt.dtsEst = util.NewDTSEstimator() return } @@ -92,7 +92,7 @@ func (vt *H265) WriteSequenceHead(head []byte) (err error) { vt.Video.WriteSequenceHead(head) } else { vt.Error("H265 ParseVpsSpsPps Error") - vt.Stream.Close() + vt.Publisher.Stop(zap.Error(err)) } return } @@ -102,7 +102,7 @@ func (vt *H265) WriteRTPFrame(rtpItem *util.ListItem[RTPFrame]) { err := recover() if err != nil { vt.Error("WriteRTPFrame panic", zap.Any("err", err)) - vt.Stream.Close() + vt.Publisher.Stop(zap.Any("err", err)) } }() frame := &rtpItem.Value diff --git a/track/media.go b/track/media.go new file mode 100644 index 0000000..0ac532c --- /dev/null +++ b/track/media.go @@ -0,0 +1,321 @@ +package track + +import ( + "time" + "unsafe" + + "github.com/pion/rtp" + "go.uber.org/zap" + . "m7s.live/engine/v4/common" + "m7s.live/engine/v4/config" + "m7s.live/engine/v4/log" + "m7s.live/engine/v4/util" +) + +var deltaDTSRange time.Duration = 90 * 10000 // 超过 10 秒 + +type 流速控制 struct { + 起始时间戳 time.Duration + 起始dts time.Duration + 等待上限 time.Duration + 起始时间 time.Time +} + +func (p *流速控制) 重置(绝对时间戳 time.Duration, dts time.Duration) { + p.起始时间 = time.Now() + p.起始时间戳 = 绝对时间戳 + p.起始dts = dts + // println("重置", p.起始时间.Format("2006-01-02 15:04:05"), p.起始时间戳) +} +func (p *流速控制) 根据起始DTS计算绝对时间戳(dts time.Duration) time.Duration { + if dts < p.起始dts { + dts += (1 << 32) + } + return ((dts-p.起始dts)*time.Millisecond + p.起始时间戳*90) / 90 +} + +func (p *流速控制) 控制流速(绝对时间戳 time.Duration, dts time.Duration) (等待了 time.Duration) { + 数据时间差, 实际时间差 := 绝对时间戳-p.起始时间戳, time.Since(p.起始时间) + // println("数据时间差", 数据时间差, "实际时间差", 实际时间差, "绝对时间戳", 绝对时间戳, "起始时间戳", p.起始时间戳, "起始时间", p.起始时间.Format("2006-01-02 15:04:05")) + // if 实际时间差 > 数据时间差 { + // p.重置(绝对时间戳) + // return + // } + // 如果收到的帧的时间戳超过实际消耗的时间100ms就休息一下,100ms作为一个弹性区间防止频繁调用sleep + if 过快 := (数据时间差 - 实际时间差); 过快 > 100*time.Millisecond { + // fmt.Println("过快毫秒", 过快.Milliseconds()) + // println("过快毫秒", p.name, 过快.Milliseconds()) + if 过快 > p.等待上限 { + 等待了 = p.等待上限 + } else { + 等待了 = 过快 + } + time.Sleep(等待了) + } else if 过快 < -100*time.Millisecond { + // fmt.Println("过慢毫秒", 过快.Milliseconds()) + // p.重置(绝对时间戳, dts) + // println("过慢毫秒", p.name, 过快.Milliseconds()) + } + return +} + +type SpesificTrack interface { + CompleteRTP(*AVFrame) + CompleteAVCC(*AVFrame) + WriteSliceBytes([]byte) + WriteRTPFrame(*util.ListItem[RTPFrame]) + generateTimestamp(uint32) + WriteSequenceHead([]byte) error + writeAVCCFrame(uint32, *util.BLLReader, *util.BLL) error + GetNALU_SEI() *util.ListItem[util.Buffer] + Flush() +} + +type IDRingList struct { + IDRList util.List[*util.Ring[*AVFrame]] + IDRing *util.Ring[*AVFrame] + HistoryRing *util.Ring[*AVFrame] +} + +func (p *IDRingList) AddIDR(IDRing *util.Ring[*AVFrame]) { + p.IDRList.PushValue(IDRing) + p.IDRing = IDRing +} + +func (p *IDRingList) ShiftIDR() { + p.IDRList.Shift() + p.HistoryRing = p.IDRList.Next.Value +} + +// Media 基础媒体Track类 +type Media struct { + Base[any, *AVFrame] + BufferTime time.Duration //发布者配置中的缓冲时间(时光回溯) + PayloadType byte + IDRingList `json:"-" yaml:"-"` //最近的关键帧位置,首屏渲染 + SSRC uint32 + SampleRate uint32 + BytesPool util.BytesPool `json:"-" yaml:"-"` + RtpPool util.Pool[RTPFrame] `json:"-" yaml:"-"` + SequenceHead []byte `json:"-" yaml:"-"` //H264(SPS、PPS) H265(VPS、SPS、PPS) AAC(config) + SequenceHeadSeq int + RTPDemuxer + SpesificTrack `json:"-" yaml:"-"` + deltaTs time.Duration //用于接续发布后时间戳连续 + + 流速控制 +} + +func (av *Media) GetFromPool(b util.IBytes) (item *util.ListItem[util.Buffer]) { + if b.Reuse() { + item = av.BytesPool.Get(b.Len()) + copy(item.Value, b.Bytes()) + } else { + return av.BytesPool.GetShell(b.Bytes()) + } + return +} + +func (av *Media) GetRTPFromPool() (result *util.ListItem[RTPFrame]) { + result = av.RtpPool.Get() + if result.Value.Packet == nil { + result.Value.Packet = &rtp.Packet{} + result.Value.PayloadType = av.PayloadType + result.Value.SSRC = av.SSRC + result.Value.Version = 2 + result.Value.Raw = make([]byte, 1460) + } + result.Value.Raw = result.Value.Raw[:1460] + result.Value.Payload = result.Value.Raw[:0] + return +} + +// 为json序列化而计算的数据 +func (av *Media) SnapForJson() { + v := av.LastValue + if av.RawPart != nil { + av.RawPart = av.RawPart[:0] + } else { + av.RawPart = make([]int, 0, 10) + } + if av.RawSize = v.AUList.ByteLength; av.RawSize > 0 { + r := v.AUList.NewReader() + for b, err := r.ReadByte(); err == nil && len(av.RawPart) < 10; b, err = r.ReadByte() { + av.RawPart = append(av.RawPart, int(b)) + } + } +} + +func (av *Media) SetSpeedLimit(value time.Duration) { + av.等待上限 = value +} + +func (av *Media) SetStuff(stuff ...any) { + // 代表发布者已经离线,该Track成为遗留Track,等待下一任发布者接续发布 + for _, s := range stuff { + switch v := s.(type) { + case IPuber: + pubConf := v.GetConfig() + av.BufferTime = pubConf.BufferTime + av.Base.SetStuff(v) + av.Init(256, NewAVFrame) + av.SSRC = uint32(uintptr(unsafe.Pointer(av))) + av.等待上限 = pubConf.SpeedLimit + case uint32: + av.SampleRate = v + case byte: + av.PayloadType = v + case util.BytesPool: + av.BytesPool = v + case SpesificTrack: + av.SpesificTrack = v + case []any: + av.SetStuff(v...) + default: + av.Base.SetStuff(v) + } + } +} + +func (av *Media) LastWriteTime() time.Time { + return av.LastValue.WriteTime +} + +func (av *Media) CurrentFrame() *AVFrame { + return av.Value +} +func (av *Media) PreFrame() *AVFrame { + return av.LastValue +} + +func (av *Media) generateTimestamp(ts uint32) { + av.Value.PTS = time.Duration(ts) + av.Value.DTS = time.Duration(ts) +} + +func (av *Media) WriteSequenceHead(sh []byte) { + av.SequenceHead = sh + av.SequenceHeadSeq++ +} +func (av *Media) AppendAuBytes(b ...[]byte) { + var au util.BLL + for _, bb := range b { + au.Push(av.BytesPool.GetShell(bb)) + } + av.Value.AUList.PushValue(&au) +} + +func (av *Media) narrow(gop int) { + if l := av.Size - gop; l > 12 { + if log.Trace { + av.Trace("resize", zap.Int("before", av.Size), zap.Int("after", av.Size-5)) + } + //缩小缓冲环节省内存 + av.Reduce(5) + } +} + +func (av *Media) AddIDR() { + if av.BufferTime > 0 { + av.IDRingList.AddIDR(av.Ring) + if av.HistoryRing == nil { + av.HistoryRing = av.IDRing + } + } else { + av.IDRing = av.Ring + } +} + +func (av *Media) Flush() { + curValue, preValue, nextValue := av.Value, av.LastValue, av.Next() + useDts := curValue.Timestamp == 0 + originDTS := curValue.DTS + if av.State == TrackStateOffline { + av.State = TrackStateOnline + if useDts { + av.deltaTs = curValue.DTS - preValue.DTS + } else { + av.deltaTs = curValue.Timestamp - preValue.Timestamp + } + curValue.DTS = preValue.DTS + 900 + curValue.PTS = preValue.PTS + 900 + curValue.Timestamp = preValue.Timestamp + time.Millisecond + av.Info("track back online", zap.Duration("delta", av.deltaTs)) + } else if av.deltaTs != 0 { + if useDts { + curValue.DTS -= av.deltaTs + curValue.PTS -= av.deltaTs + } else { + rtpts := av.deltaTs * 90 / time.Millisecond + curValue.DTS -= rtpts + curValue.PTS -= rtpts + curValue.Timestamp -= av.deltaTs + } + } + if av.起始时间.IsZero() { + curValue.DeltaTime = 0 + if useDts { + curValue.Timestamp = time.Since(av.Publisher.GetStream().GetStartTime()) + } + av.重置(curValue.Timestamp, curValue.DTS) + } else { + if useDts { + deltaDts := curValue.DTS - preValue.DTS + if deltaDts > deltaDTSRange || deltaDts < -deltaDTSRange { + // 时间戳跳变,等同于离线重连 + av.deltaTs = originDTS - preValue.DTS + curValue.DTS = preValue.DTS + 900 + curValue.PTS = preValue.PTS + 900 + av.Warn("track dts reset", zap.Int64("delta1", int64(deltaDts)), zap.Int64("delta2", int64(av.deltaTs))) + } + curValue.Timestamp = av.根据起始DTS计算绝对时间戳(curValue.DTS) + } + + curValue.DeltaTime = uint32(deltaTS(curValue.Timestamp, preValue.Timestamp) / time.Millisecond) + } + if log.Trace { + av.Trace("write", zap.Uint32("seq", curValue.Sequence), zap.Int64("dts0", int64(preValue.DTS)), zap.Int64("dts1", int64(originDTS)), zap.Uint64("dts2", uint64(curValue.DTS)), zap.Uint32("delta", curValue.DeltaTime), zap.Duration("timestamp", curValue.Timestamp), zap.Int("au", curValue.AUList.Length), zap.Int("rtp", curValue.RTP.Length), zap.Int("avcc", curValue.AVCC.ByteLength), zap.Int("raw", curValue.AUList.ByteLength), zap.Int("bps", av.BPS)) + } + bufferTime := av.BufferTime + if bufferTime > 0 && av.IDRingList.IDRList.Length > 1 && deltaTS(curValue.Timestamp, av.IDRingList.IDRList.Next.Next.Value.Value.Timestamp) > bufferTime { + av.ShiftIDR() + av.narrow(int(curValue.Sequence - av.HistoryRing.Value.Sequence)) + } + // 下一帧为订阅起始帧,即将覆盖,需要扩环 + if nextValue == av.IDRing || nextValue == av.HistoryRing { + // if av.AVRing.Size < 512 { + if log.Trace { + av.Trace("resize", zap.Int("before", av.Size), zap.Int("after", av.Size+5), zap.String("name", av.Name)) + } + av.Glow(5) + // } else { + // av.Stream.Error("sub ring overflow", zap.Int("size", av.AVRing.Size), zap.String("name", av.Name)) + // } + } + + if curValue.AUList.Length > 0 { + // 补完RTP + if config.Global.EnableRTP && curValue.RTP.Length == 0 { + av.CompleteRTP(curValue) + } + // 补完AVCC + if config.Global.EnableAVCC && curValue.AVCC.ByteLength == 0 { + av.CompleteAVCC(curValue) + } + } + av.ComputeBPS(curValue.BytesIn) + av.Step() + if av.等待上限 > 0 { + 等待了 := av.控制流速(curValue.Timestamp, curValue.DTS) + if log.Trace && 等待了 > 0 { + av.Trace("speed control", zap.Duration("sleep", 等待了)) + } + } +} + +func deltaTS(curTs time.Duration, preTs time.Duration) time.Duration { + if curTs < preTs { + return curTs + (1<<32)*time.Millisecond - preTs + } + return curTs - preTs +} diff --git a/track/opus.go b/track/opus.go index ddb4dfd..1d5abfd 100644 --- a/track/opus.go +++ b/track/opus.go @@ -9,13 +9,13 @@ import ( var _ SpesificTrack = (*Opus)(nil) -func NewOpus(stream IStream, stuff ...any) (opus *Opus) { +func NewOpus(puber IPuber, stuff ...any) (opus *Opus) { opus = &Opus{} opus.CodecID = codec.CodecID_OPUS opus.SampleSize = 16 opus.Channels = 2 opus.AVCCHead = []byte{(byte(opus.CodecID) << 4) | (1 << 1)} - opus.SetStuff("opus", uint32(48000), byte(111), opus, stuff, stream) + opus.SetStuff("opus", uint32(48000), byte(111), opus, stuff, puber) if opus.BytesPool == nil { opus.BytesPool = make(util.BytesPool, 17) } diff --git a/track/reader-data.go b/track/reader-data.go index 2ce2ed2..a24471d 100644 --- a/track/reader-data.go +++ b/track/reader-data.go @@ -5,7 +5,7 @@ import ( "m7s.live/engine/v4/util" ) -type RingReader[T any, F common.IDataFrame[T]] struct { +type RingReader[T any, F util.IDataFrame[T]] struct { *util.Ring[F] Count int // 读取的帧数 } diff --git a/track/video.go b/track/video.go index 0486ae9..e55031f 100644 --- a/track/video.go +++ b/track/video.go @@ -18,31 +18,26 @@ type Video struct { GOP int //关键帧间隔 nalulenSize int //avcc格式中表示nalu长度的字节数,通常为4 dcChanged bool //解码器配置是否改变了,一般由于变码率导致 - dtsEst *DTSEstimator + dtsEst *util.DTSEstimator lostFlag bool // 是否丢帧 codec.SPSInfo - ParamaterSets `json:"-" yaml:"-"` - SPS []byte `json:"-" yaml:"-"` - PPS []byte `json:"-" yaml:"-"` - SEIReader chan []byte `json:"-" yaml:"-"` + ParamaterSets `json:"-" yaml:"-"` + SPS []byte `json:"-" yaml:"-"` + PPS []byte `json:"-" yaml:"-"` + iframeReceived bool } func (v *Video) Attach() { - if v.Attached.CompareAndSwap(false, true) { - v.Info("attach video track", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) - if err := v.Stream.AddTrack(v).Await(); err != nil { - v.Error("attach video track failed", zap.Error(err)) - v.Attached.Store(false) - } else { - v.Info("video track attached", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) - } + v.Info("attach video track", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) + if err := v.Publisher.GetStream().AddTrack(v).Await(); err != nil { + v.Error("attach video track failed", zap.Error(err)) + } else { + v.Info("video track attached", zap.Uint("width", v.Width), zap.Uint("height", v.Height)) } } func (v *Video) Detach() { - if v.Attached.CompareAndSwap(true, false) { - v.Stream.RemoveTrack(v) - } + v.Publisher.GetStream().RemoveTrack(v) } func (vt *Video) GetName() string { @@ -52,6 +47,9 @@ func (vt *Video) GetName() string { return vt.Name } +func (vt *Video) GetCodec() codec.VideoCodecID { + return vt.CodecID +} // PlayFullAnnexB 订阅annex-b格式的流数据,每一个I帧增加sps、pps头 // func (vt *Video) PlayFullAnnexB(ctx context.Context, onMedia func(net.Buffers) error) error { // for vr := vt.ReadRing(); ctx.Err() == nil; vr.MoveNext() { @@ -202,7 +200,7 @@ func (vt *Video) insertDCRtp() { func (vt *Video) generateTimestamp(ts uint32) { if vt.State == TrackStateOffline { - vt.dtsEst = NewDTSEstimator() + vt.dtsEst = util.NewDTSEstimator() } vt.Value.PTS = time.Duration(ts) vt.Value.DTS = time.Duration(vt.dtsEst.Feed(ts)) @@ -249,26 +247,17 @@ func (vt *Video) CompleteAVCC(rv *AVFrame) { func (vt *Video) Flush() { rv := vt.Value - if vt.SEIReader != nil && len(vt.SEIReader) > 0 { - for seiFrame := range vt.SEIReader { - var au util.BLL - au.Push(vt.SpesificTrack.GetNALU_SEI()) - au.Push(vt.BytesPool.GetShell(seiFrame)) - vt.Info("sei", zap.Int("len", len(seiFrame))) - vt.Value.AUList.UnshiftValue(&au) - if len(vt.SEIReader) == 0 { - break - } - } - } if rv.IFrame { vt.computeGOP() - vt.Stream.SetIDR(vt) + if audioTrack := vt.Publisher.GetAudioTrack(); audioTrack != nil { + audioTrack.Narrow() + } } - if !vt.Attached.Load() { + if !vt.iframeReceived { if vt.IDRing != nil && vt.SequenceHeadSeq > 0 { defer vt.Attach() + vt.iframeReceived = true } else { rv.Reset() return diff --git a/common/dtsestimator.go b/util/dtsestimator.go similarity index 90% rename from common/dtsestimator.go rename to util/dtsestimator.go index c15b7fd..bc02230 100644 --- a/common/dtsestimator.go +++ b/util/dtsestimator.go @@ -1,6 +1,4 @@ -package common - -import "m7s.live/engine/v4/util" +package util // DTSEstimator is a DTS estimator. type DTSEstimator struct { @@ -48,7 +46,7 @@ func (d *DTSEstimator) add(pts uint32) { // Feed provides PTS to the estimator, and returns the estimated DTS. func (d *DTSEstimator) Feed(pts uint32) uint32 { - interval := util.Conditoinal(pts > d.prevPTS, pts-d.prevPTS, d.prevPTS-pts) + interval := Conditoinal(pts > d.prevPTS, pts-d.prevPTS, d.prevPTS-pts) if interval > 10*d.interval { *d = *NewDTSEstimator() } diff --git a/common/dtsestimator_test.go b/util/dtsestimator_test.go similarity index 97% rename from common/dtsestimator_test.go rename to util/dtsestimator_test.go index e236a92..46fc4ff 100644 --- a/common/dtsestimator_test.go +++ b/util/dtsestimator_test.go @@ -1,4 +1,4 @@ -package common +package util import ( "testing" diff --git a/common/ring-writer.go b/util/ring-writer.go similarity index 68% rename from common/ring-writer.go rename to util/ring-writer.go index 1da52f6..b126b46 100644 --- a/common/ring-writer.go +++ b/util/ring-writer.go @@ -1,9 +1,8 @@ -package common +package util import ( "sync/atomic" - "m7s.live/engine/v4/util" ) type emptyLocker struct{} @@ -13,18 +12,35 @@ func (emptyLocker) Unlock() {} var EmptyLocker emptyLocker +type IDataFrame[T any] interface { + Init() // 初始化 + Reset() // 重置数据,复用内存 + Ready() // 标记为可读取 + ReaderEnter() int32 // 读取者数量+1 + ReaderLeave() int32 // 读取者数量-1 + StartWrite() bool // 开始写入 + SetSequence(uint32) // 设置序号 + GetSequence() uint32 // 获取序号 + ReaderCount() int32 // 读取者数量 + Discard() int32 // 如果写入时还有读取者没有离开则废弃该帧,剥离RingBuffer,防止并发读写 + IsDiscarded() bool // 是否已废弃 + IsWriting() bool // 是否正在写入 + Wait() // 阻塞等待可读取 + Broadcast() // 广播可读取 +} + type RingWriter[T any, F IDataFrame[T]] struct { - *util.Ring[F] `json:"-" yaml:"-"` + *Ring[F] `json:"-" yaml:"-"` ReaderCount atomic.Int32 `json:"-" yaml:"-"` - pool *util.Ring[F] + pool *Ring[F] poolSize int Size int LastValue F constructor func() F } -func (rb *RingWriter[T, F]) create(n int) (ring *util.Ring[F]) { - ring = util.NewRing[F](n) +func (rb *RingWriter[T, F]) create(n int) (ring *Ring[F]) { + ring = NewRing[F](n) for p, i := ring, n; i > 0; p, i = p.Next(), i-1 { p.Value = rb.constructor() p.Value.Init() @@ -46,7 +62,7 @@ func (rb *RingWriter[T, F]) Init(n int, constructor func() F) *RingWriter[T, F] // return rb.Value // } -func (rb *RingWriter[T, F]) Glow(size int) (newItem *util.Ring[F]) { +func (rb *RingWriter[T, F]) Glow(size int) (newItem *Ring[F]) { if size < rb.poolSize { newItem = rb.pool.Unlink(size) rb.poolSize -= size @@ -64,7 +80,7 @@ func (rb *RingWriter[T, F]) Glow(size int) (newItem *util.Ring[F]) { return } -func (rb *RingWriter[T, F]) Recycle(r *util.Ring[F]) { +func (rb *RingWriter[T, F]) Recycle(r *Ring[F]) { rb.poolSize++ r.Value.Init() r.Value.Reset() diff --git a/util/safe_chan.go b/util/safe_chan.go index a36ff4b..f635614 100644 --- a/util/safe_chan.go +++ b/util/safe_chan.go @@ -76,6 +76,16 @@ func (p *Promise[S]) Await() (err error) { return } +func (p *Promise[S]) Then(resolved func(S), rejected func(error)) { + go func() { + if err := p.Await(); err == nil { + resolved(p.Value) + } else { + rejected(err) + } + }() +} + func NewPromise[S any](value S) *Promise[S] { ctx0, cancel0 := context.WithTimeout(context.Background(), time.Second*10) ctx, cancel := context.WithCancelCause(ctx0)