diff --git a/stream.go b/stream.go index 12514df..3d32851 100644 --- a/stream.go +++ b/stream.go @@ -512,11 +512,6 @@ func (s *Stream) run() { case action, ok := <-s.actionChan.C: if !ok { return - } else if s.State == STATE_CLOSED { - if s.actionChan.Close() { //再次尝试关闭 - return - } - continue } timeStart = time.Now() switch v := action.(type) { @@ -527,6 +522,7 @@ func (s *Stream) run() { timeOutInfo = zap.String("action", "Publish") if s.IsClosed() { v.Reject(ErrStreamIsClosed) + break } puber := v.Value.GetPublisher() conf := puber.Config @@ -571,6 +567,7 @@ func (s *Stream) run() { timeOutInfo = zap.String("action", "Subscribe") if s.IsClosed() { v.Reject(ErrStreamIsClosed) + break } suber := v.Value io := suber.GetSubscriber() @@ -628,6 +625,9 @@ func (s *Stream) run() { s.onSuberClose(v) case TrackRemoved: timeOutInfo = zap.String("action", "TrackRemoved") + if s.IsClosed() { + break + } name := v.GetName() if t, ok := s.Tracks.LoadAndDelete(name); ok { s.Info("track -1", zap.String("name", name)) @@ -636,6 +636,10 @@ func (s *Stream) run() { } case *util.Promise[Track]: timeOutInfo = zap.String("action", "Track") + if s.IsClosed() { + v.Reject(ErrStreamIsClosed) + break + } if s.State == STATE_WAITPUBLISH { s.action(ACTION_PUBLISH) } @@ -673,6 +677,9 @@ func (s *Stream) run() { timeOutInfo = zap.String("action", "unknown") s.Error("unknown action", timeOutInfo) } + if s.IsClosed() && s.actionChan.Close() { //再次尝试关闭 + return + } } } } diff --git a/track/av1.go b/track/av1.go index 6ebaea1..75a3b38 100644 --- a/track/av1.go +++ b/track/av1.go @@ -15,6 +15,7 @@ var _ SpesificTrack = (*AV1)(nil) type AV1 struct { Video decoder rtpav1.Decoder + encoder rtpav1.Encoder } func NewAV1(stream IStream, stuff ...any) (vt *AV1) { @@ -24,8 +25,11 @@ func NewAV1(stream IStream, stuff ...any) (vt *AV1) { if vt.BytesPool == nil { vt.BytesPool = make(util.BytesPool, 17) } - vt.nalulenSize = 4 + vt.nalulenSize = 0 vt.dtsEst = NewDTSEstimator() + vt.decoder.Init() + vt.encoder.Init() + vt.encoder.PayloadType = vt.PayloadType return } @@ -33,7 +37,7 @@ func (vt *AV1) writeSequenceHead(head []byte) (err error) { vt.WriteSequenceHead(head) var info codec.AV1CodecConfigurationRecord info.Unmarshal(head[5:]) - vt.ParamaterSets[0] = info.ConfigOBUs + vt.ParamaterSets = [][]byte{info.ConfigOBUs, {info.SeqLevelIdx0, info.SeqProfile, info.SeqTier0}} return } @@ -44,32 +48,32 @@ func (vt *AV1) WriteAVCC(ts uint32, frame *util.BLL) (err error) { } b0 := frame.GetByte(0) if isExtHeader := (b0 >> 4) & 0b1000; isExtHeader != 0 { - firstBuffer := frame.Next.Value + // firstBuffer := frame.Next.Value packetType := b0 & 0b1111 switch packetType { case codec.PacketTypeSequenceStart: header := frame.ToBytes() - header[0] = 0x1d - header[1] = 0x00 - header[2] = 0x00 - header[3] = 0x00 - header[4] = 0x00 + // header[0] = 0x1d + // header[1] = 0x00 + // header[2] = 0x00 + // header[3] = 0x00 + // header[4] = 0x00 err = vt.writeSequenceHead(header) frame.Recycle() return case codec.PacketTypeCodedFrames: - firstBuffer[0] = b0 & 0b0111_1111 & 0xFD - firstBuffer[1] = 0x01 - copy(firstBuffer[2:], firstBuffer[5:]) - frame.Next.Value = firstBuffer[:firstBuffer.Len()-3] - frame.ByteLength -= 3 + // firstBuffer[0] = b0 & 0b0111_1111 & 0xFD + // firstBuffer[1] = 0x01 + // copy(firstBuffer[2:], firstBuffer[5:]) + // frame.Next.Value = firstBuffer[:firstBuffer.Len()-3] + // frame.ByteLength -= 3 return vt.Video.WriteAVCC(ts, frame) case codec.PacketTypeCodedFramesX: - firstBuffer[0] = b0 & 0b0111_1111 & 0xFD - firstBuffer[1] = 0x01 - firstBuffer[2] = 0 - firstBuffer[3] = 0 - firstBuffer[4] = 0 + // firstBuffer[0] = b0 & 0b0111_1111 & 0xFD + // firstBuffer[1] = 0x01 + // firstBuffer[2] = 0 + // firstBuffer[3] = 0 + // firstBuffer[4] = 0 return vt.Video.WriteAVCC(ts, frame) } } else { @@ -108,3 +112,26 @@ func (vt *AV1) WriteRTPFrame(rtpItem *util.ListItem[RTPFrame]) { vt.Flush() } } + +// RTP格式补完 +func (vt *AV1) CompleteRTP(value *AVFrame) { + rtps, err := vt.encoder.Encode(vt.Value.AUList.ToBuffers()) + if err != nil { + vt.Error("AV1 encoder encode error", zap.Error(err)) + return + } + if vt.Value.IFrame { + rtpItem := vt.GetRTPFromPool() + packet := &rtpItem.Value + br := util.LimitBuffer{Buffer: packet.Payload} + packet.Timestamp = uint32(vt.Value.PTS) + packet.Marker = false + br.Write(vt.ParamaterSets[0]) + packet.Payload = br.Bytes() + vt.Value.RTP.Push(rtpItem) + } + + for _, rtp := range rtps { + vt.Value.RTP.PushValue(RTPFrame{Packet: rtp}) + } +} diff --git a/track/video.go b/track/video.go index d335299..9fbb21d 100644 --- a/track/video.go +++ b/track/video.go @@ -1,7 +1,6 @@ package track import ( - "time" "github.com/pion/rtp" @@ -127,41 +126,64 @@ func (vt *Video) WriteAVCC(ts uint32, frame *util.BLL) (e error) { if err != nil { return err } - b = (b >> 4) & 0b0111 - vt.Value.IFrame = b == 1 || b == 4 - r.ReadByte() //sequence frame flag - cts, err := r.ReadBE(3) + isExtHeader := (b >> 4) & 0b1000 + frameType := (b >> 4) & 0b0111 + vt.Value.IFrame = frameType == 1 || frameType == 4 + packetType := b & 0b1111 + var cts uint32 + if isExtHeader != 0 { + fourcCC, _ := r.ReadBE(4) //sequence frame flag + switch packetType { + case codec.PacketTypeSequenceStart: + case codec.PacketTypeCodedFrames: + if fourcCC == codec.FourCC_H265_32 { + cts, err = r.ReadBE(3) + } + case codec.PacketTypeCodedFramesX: + } + } else { + r.ReadByte() //sequence frame flag + cts, err = r.ReadBE(3) + if err != nil { + return err + } + } if err != nil { return err } vt.Value.PTS = time.Duration(ts+cts) * 90 vt.Value.DTS = time.Duration(ts) * 90 // println(":", vt.Value.Sequence) - var nalulen uint32 - for nalulen, e = r.ReadBE(vt.nalulenSize); e == nil; nalulen, e = r.ReadBE(vt.nalulenSize) { - if remain := frame.ByteLength - r.GetOffset(); remain < int(nalulen) { - vt.Error("read nalu length error", zap.Int("nalulen", int(nalulen)), zap.Int("remain", remain)) - frame.Recycle() - vt.Value.Reset() - return - // for bbb.CanRead() { - // nalulen = bbb.ReadUint32() - // if bbb.CanReadN(int(nalulen)) { - // bbb.ReadN(int(nalulen)) - // } else { - // panic("read nalu error1") - // } + if isExtHeader == 0 { + var nalulen uint32 + for nalulen, e = r.ReadBE(vt.nalulenSize); e == nil; nalulen, e = r.ReadBE(vt.nalulenSize) { + if remain := frame.ByteLength - r.GetOffset(); remain < int(nalulen) { + vt.Error("read nalu length error", zap.Int("nalulen", int(nalulen)), zap.Int("remain", remain)) + frame.Recycle() + vt.Value.Reset() + return + // for bbb.CanRead() { + // nalulen = bbb.ReadUint32() + // if bbb.CanReadN(int(nalulen)) { + // bbb.ReadN(int(nalulen)) + // } else { + // panic("read nalu error1") + // } + // } + // panic("read nalu error2") + } + // var au util.BLL + // for _, bb := range r.ReadN(int(nalulen)) { + // au.Push(vt.BytesPool.GetShell(bb)) // } - // panic("read nalu error2") + // println(":", nalulen, au.ByteLength) + // vt.Value.AUList.PushValue(&au) + vt.AppendAuBytes(r.ReadN(int(nalulen))...) } - // var au util.BLL - // for _, bb := range r.ReadN(int(nalulen)) { - // au.Push(vt.BytesPool.GetShell(bb)) - // } - // println(":", nalulen, au.ByteLength) - // vt.Value.AUList.PushValue(&au) - vt.AppendAuBytes(r.ReadN(int(nalulen))...) + } else { + vt.AppendAuBytes(r.ReadN(frame.ByteLength - 5)...) } + vt.Value.WriteAVCC(ts, frame) // { // b := util.Buffer(vt.Value.AVCC.ToBytes()[5:]) diff --git a/util/safe_chan.go b/util/safe_chan.go index e1ff89c..a36ff4b 100644 --- a/util/safe_chan.go +++ b/util/safe_chan.go @@ -1,8 +1,11 @@ package util import ( + "context" + "errors" "math" "sync/atomic" + "time" ) // SafeChan安全的channel,可以防止close后被写入的问题 @@ -46,33 +49,40 @@ func (sc *SafeChan[T]) IsFull() bool { return atomic.LoadInt32(&sc.senders) > 0 } +var errResolved = errors.New("resolved") + type Promise[S any] struct { + context.Context + context.CancelCauseFunc + context.CancelFunc Value S - c chan error - state int32 // 0 pendding 1 fullfilled -1 rejected } func (r *Promise[S]) Resolve() { - if atomic.CompareAndSwapInt32(&r.state, 0, 1) { - r.c <- nil - close(r.c) - } + r.CancelCauseFunc(errResolved) } func (r *Promise[S]) Reject(err error) { - if atomic.CompareAndSwapInt32(&r.state, 0, -1) { - r.c <- err - close(r.c) - } + r.CancelCauseFunc(err) } -func (p *Promise[S]) Await() error { - return <-p.c +func (p *Promise[S]) Await() (err error) { + <-p.Done() + err = context.Cause(p.Context) + if err == errResolved { + err = nil + } + p.CancelFunc() + return } func NewPromise[S any](value S) *Promise[S] { + ctx0, cancel0 := context.WithTimeout(context.Background(), time.Second*10) + ctx, cancel := context.WithCancelCause(ctx0) return &Promise[S]{ - Value: value, - c: make(chan error, 1), + Value: value, + Context: ctx, + CancelCauseFunc: cancel, + CancelFunc: cancel0, } }