diff --git a/rtmp/rtmp_publisher.go b/rtmp/rtmp_publisher.go index a0aa0c4..81621d3 100644 --- a/rtmp/rtmp_publisher.go +++ b/rtmp/rtmp_publisher.go @@ -2,27 +2,114 @@ package rtmp import ( "github.com/yangjiechina/avformat/libflv" + "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/stream" ) type Publisher struct { stream.SourceImpl - deMuxer libflv.DeMuxer + deMuxer libflv.DeMuxer + audioMemoryPool stream.MemoryPool + videoMemoryPool stream.MemoryPool + audioPacket []byte + videoPacket []byte + + audioUnmark bool + videoUnmark bool } func NewPublisher(sourceId string) *Publisher { publisher := &Publisher{SourceImpl: stream.SourceImpl{Id_: sourceId}} - muxer := &libflv.DeMuxer{} + publisher.deMuxer = libflv.DeMuxer{} //设置回调,从flv解析出来的Stream和AVPacket都将统一回调到stream.SourceImpl - muxer.SetHandler(publisher) + publisher.deMuxer.SetHandler(publisher) + + //创建内存池 + publisher.audioMemoryPool = stream.NewMemoryPool(48000 * (stream.AppConfig.GOPCache + 1)) + if stream.AppConfig.GOPCache > 0 { + //以每秒钟4M码率大小创建内存池 + publisher.videoMemoryPool = stream.NewMemoryPool(4096 * 1000 / 8 * stream.AppConfig.GOPCache) + } else { + publisher.videoMemoryPool = stream.NewMemoryPool(4096 * 1000 / 8) + } return publisher } -// OnVideo 从rtmpchunk解析过来的视频包 +func (p *Publisher) OnDeMuxStream(stream_ utils.AVStream) { + tmp := stream_.Extra() + bytes := make([]byte, len(tmp)) + copy(bytes, tmp) + stream_.SetExtraData(bytes) + + if utils.AVMediaTypeAudio == stream_.Type() { + p.audioMemoryPool.FreeTail(len(p.audioPacket)) + } else if utils.AVMediaTypeVideo == stream_.Type() { + p.videoMemoryPool.FreeTail(len(p.videoPacket)) + } + + p.SourceImpl.OnDeMuxStream(stream_) +} + +func (p *Publisher) OnDeMuxStreamDone() { + +} + +func (p *Publisher) OnDeMuxPacket(index int, packet utils.AVPacket) { + p.SourceImpl.OnDeMuxPacket(index, packet) + + if stream.AppConfig.GOPCache > 0 { + return + } + + if utils.AVMediaTypeAudio == packet.MediaType() { + p.audioMemoryPool.FreeHead(len(packet.Data())) + } else if utils.AVMediaTypeVideo == packet.MediaType() { + p.videoMemoryPool.FreeHead(len(packet.Data())) + } +} + +func (p *Publisher) OnDeMuxDone() { + +} + +// OnVideo 从rtm chunk解析过来的视频包 func (p *Publisher) OnVideo(data []byte, ts uint32) { + if data == nil { + data = p.videoMemoryPool.Fetch() + p.videoUnmark = false + } + + p.videoPacket = data _ = p.deMuxer.InputVideo(data, ts) } func (p *Publisher) OnAudio(data []byte, ts uint32) { + if data == nil { + data = p.audioMemoryPool.Fetch() + p.audioUnmark = false + } + + p.audioPacket = data _ = p.deMuxer.InputAudio(data, ts) } + +// OnPartPacket 从rtmp解析过来的部分音视频包 +func (p *Publisher) OnPartPacket(index int, data []byte, first bool) { + //audio + if index == 0 { + if p.audioUnmark { + p.audioMemoryPool.Mark() + p.audioUnmark = true + } + + p.audioMemoryPool.Write(data) + //video + } else if index == 1 { + if p.videoUnmark { + p.videoMemoryPool.Mark() + p.videoUnmark = true + } + + p.videoMemoryPool.Write(data) + } +} diff --git a/rtmp/rtmp_server.go b/rtmp/rtmp_server.go index fd30636..cdb8457 100644 --- a/rtmp/rtmp_server.go +++ b/rtmp/rtmp_server.go @@ -1,8 +1,8 @@ package rtmp import ( - "github.com/yangjiechina/avformat" "github.com/yangjiechina/avformat/transport" + "github.com/yangjiechina/avformat/utils" "net" ) @@ -17,7 +17,7 @@ type serverImpl struct { } func (s *serverImpl) Start(addr net.Addr) error { - avformat.Assert(s.tcp == nil) + utils.Assert(s.tcp == nil) server := &transport.TCPServer{} server.SetHandler(s) @@ -37,7 +37,7 @@ func (s *serverImpl) Close() { func (s *serverImpl) OnConnected(conn net.Conn) { t := conn.(*transport.Conn) - t.Data = NewSession() + t.Data = NewSession(conn) } func (s *serverImpl) OnPacket(conn net.Conn, data []byte) { diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 32d9d62..9b269bc 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -1,11 +1,10 @@ package rtmp import ( - "github.com/yangjiechina/avformat" "github.com/yangjiechina/avformat/librtmp" + "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/stream" "net" - "net/http" ) // Session 负责除RTMP连接和断开以外的所有生命周期处理 @@ -15,43 +14,49 @@ type Session interface { Close() } -func NewSession() *sessionImpl { +func NewSession(conn net.Conn) Session { impl := &sessionImpl{} stack := librtmp.NewStack(impl) impl.stack = stack + impl.conn = conn return impl } type sessionImpl struct { stream.SessionImpl + //解析rtmp协议栈 stack *librtmp.Stack //publisher/sink handle interface{} + conn net.Conn streamId string } -func (s *sessionImpl) OnPublish(app, stream_ string, response chan avformat.HookState) { +func (s *sessionImpl) OnPublish(app, stream_ string, response chan utils.HookState) { s.streamId = app + "/" + stream_ publisher := NewPublisher(s.streamId) s.stack.SetOnPublishHandler(publisher) + s.stack.SetOnTransDeMuxerHandler(publisher) + //stream.SessionImpl统一处理, Source是否已经存在, Hook回调.... s.SessionImpl.OnPublish(publisher, nil, func() { s.handle = publisher - response <- http.StatusOK - }, func(state avformat.HookState) { + response <- utils.HookStateOK + }, func(state utils.HookState) { response <- state }) } -func (s *sessionImpl) OnPlay(app, stream string, response chan avformat.HookState) { - s.streamId = app + "/" + stream - //sink := &Sink{} - //s.SessionImpl.OnPlay(sink, nil, func() { - // s.handle = sink - // response <- http.StatusOK - //}, func(state avformat.HookState) { - // response <- state - //}) +func (s *sessionImpl) OnPlay(app, stream_ string, response chan utils.HookState) { + s.streamId = app + "/" + stream_ + + sink := NewSink(stream.GenerateSinkId(s.conn), s.conn) + s.SessionImpl.OnPlay(sink, nil, func() { + s.handle = sink + response <- utils.HookStateOK + }, func(state utils.HookState) { + response <- state + }) } func (s *sessionImpl) Input(conn net.Conn, data []byte) error { diff --git a/rtmp/rtmp_sink.go b/rtmp/rtmp_sink.go index 1ed3094..46f318e 100644 --- a/rtmp/rtmp_sink.go +++ b/rtmp/rtmp_sink.go @@ -1,7 +1,11 @@ package rtmp -import "github.com/yangjiechina/live-server/stream" +import ( + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/stream" + "net" +) -type Sink struct { - stream.SinkImpl +func NewSink(id stream.SinkId, conn net.Conn) stream.ISink { + return &stream.SinkImpl{Id_: id, Protocol_: stream.ProtocolRtmp, Conn: conn, DesiredAudioCodecId_: utils.AVCodecIdNONE, DesiredVideoCodecId_: utils.AVCodecIdNONE} } diff --git a/rtmp/rtmp_transtream.go b/rtmp/rtmp_transtream.go index 009a7ca..3ced4f6 100644 --- a/rtmp/rtmp_transtream.go +++ b/rtmp/rtmp_transtream.go @@ -1,4 +1,158 @@ package rtmp +import ( + "github.com/yangjiechina/avformat/libflv" + "github.com/yangjiechina/avformat/librtmp" + "github.com/yangjiechina/avformat/utils" + "github.com/yangjiechina/live-server/stream" +) + type TransStream struct { + stream.TransStreamImpl + chunkSize int + header []byte //音视频头chunk + headerSize int + muxer *libflv.Muxer + + audioChunk librtmp.Chunk + videoChunk librtmp.Chunk + + memoryPool stream.MemoryPool + transBuffer stream.StreamBuffer +} + +func (t *TransStream) Input(packet utils.AVPacket) { + utils.Assert(t.TransStreamImpl.Completed) + var data []byte + var chunk *librtmp.Chunk + var videoPkt bool + + if utils.AVMediaTypeAudio == packet.MediaType() { + data = packet.Data() + chunk = &t.audioChunk + } else if utils.AVMediaTypeVideo == packet.MediaType() { + videoPkt = true + data = packet.AVCCPacketData() + chunk = &t.videoChunk + } + + length := len(data) + //rtmp chunk消息体的数据大小 + payloadSize := 5 + length + payloadSize += payloadSize / t.chunkSize + + //分配内存 + t.memoryPool.Mark() + allocate := t.memoryPool.Allocate(12 + payloadSize) + + //写chunk头 + chunk.Length = payloadSize + chunk.Timestamp = uint32(packet.Dts()) + n := chunk.ToBytes(allocate) + utils.Assert(n == 12) + + //写flv + ct := packet.Pts() - packet.Dts() + if videoPkt { + n += t.muxer.WriteVideoData(allocate, uint32(ct), packet.KeyFrame(), false) + } else { + n += t.muxer.WriteAudioData(allocate, false) + } + + for length > 0 { + min := utils.MinInt(length, t.chunkSize) + copy(allocate[n:], data[:min]) + n += min + + length -= min + data = data[min:] + + //写一个ChunkType3用作分割 + if length > 0 { + if videoPkt { + allocate[n] = (0x3 << 6) | byte(librtmp.ChunkStreamIdVideo) + } else { + allocate[n] = (0x3 << 6) | byte(librtmp.ChunkStreamIdAudio) + } + n++ + } + } + + rtmpData := t.memoryPool.Fetch() + ret := t.transBuffer.AddPacket(rtmpData, packet.KeyFrame() && utils.AVMediaTypeVideo == packet.MediaType(), packet.Dts()) + if ret { + //发送给sink + + for _, sink := range t.Sinks { + sink.Input(rtmpData) + } + } + +} + +func (t *TransStream) AddSink(sink stream.ISink) { + t.TransStreamImpl.AddSink(sink) + + t.transBuffer.Peek(func(packet interface{}) { + sink.Input(packet.([]byte)) + }) +} + +func (t *TransStream) onDiscardPacket(pkt interface{}) { + bytes := pkt.([]byte) + t.memoryPool.FreeHead(len(bytes)) +} + +func (t *TransStream) WriteHeader() error { + utils.Assert(t.Tracks != nil) + utils.Assert(!t.TransStreamImpl.Completed) + + var audioStream utils.AVStream + var videoStream utils.AVStream + var audioCodecId utils.AVCodecID + var videoCodecId utils.AVCodecID + + for _, track := range t.Tracks { + if utils.AVMediaTypeAudio == track.Type() { + audioStream = track + audioCodecId = audioStream.CodecId() + t.audioChunk = librtmp.NewAudioChunk() + } else if utils.AVMediaTypeAudio == track.Type() { + videoStream = track + videoCodecId = videoStream.CodecId() + t.videoChunk = librtmp.NewVideoChunk() + } + } + + utils.Assert(audioStream != nil || videoStream != nil) + + //初始化 + t.header = make([]byte, 1024) + t.muxer = libflv.NewMuxer(audioCodecId, videoCodecId, 0, 0, 0) + t.memoryPool = stream.NewMemoryPool(1024 * 1024 * 2) + t.transBuffer = stream.NewStreamBuffer(2000) + t.transBuffer.SetDiscardHandler(t.onDiscardPacket) + + var n int + if audioStream != nil { + n += t.muxer.WriteAudioData(t.header, true) + extra := audioStream.Extra() + copy(t.header[n:], extra) + n += len(extra) + } + + if videoStream != nil { + n += t.muxer.WriteAudioData(t.header[n:], true) + extra := videoStream.Extra() + copy(t.header[n:], extra) + n += len(extra) + } + + t.headerSize = n + return nil +} + +func NewTransStream(chunkSize int) stream.ITransStream { + transStream := &TransStream{chunkSize: chunkSize} + return transStream } diff --git a/stream/memory_pool.go b/stream/memory_pool.go index cf9ab5d..41b3f3a 100644 --- a/stream/memory_pool.go +++ b/stream/memory_pool.go @@ -1,29 +1,116 @@ package stream -// MemoryPool -// 从解复用阶段,拼凑成完整的AVPacket开始(写),到GOP缓存结束(释放),整个过程都使用池中内存 +import ( + "github.com/yangjiechina/avformat/utils" +) + +// MemoryPool 从解复用阶段,拼凑成完整的AVPacket开始(写),到GOP缓存结束(释放),整个过程都使用池中内存 +// 类似环形缓冲区, 区别在于,写入的内存块是连续的、整块内存. type MemoryPool interface { + // Mark 标记一块写的内存地址 + //使用流程 Mark->Write/Allocate....->Fetch/Reset + Mark() + + Write(data []byte) + Allocate(size int) []byte - Free(size int) + Fetch() []byte + + // Reset 清空此次Write的标记,本次缓存的数据无效 + Reset() + + // FreeHead 从头部释放指定大小内存 + FreeHead(size int) + + // FreeTail 从尾部释放指定大小内存 + FreeTail(size int) } func NewMemoryPool(capacity int) MemoryPool { pool := &memoryPool{ - data: make([]byte, capacity), + data: make([]byte, capacity), + capacity: capacity, } return pool } type memoryPool struct { - data []byte - size int + data []byte + ptrStart uintptr + ptrEnd uintptr + //剩余的可用内存空间不足以为此次write + capacity int + head int + tail int + + //保存开始索引 + mark int +} + +// 根据head和tail计算出可用的内存地址 +func (m *memoryPool) allocate(size int) []byte { + if m.capacity-m.tail < size { + //使用从头释放的内存 + if m.tail-m.mark+size <= m.head { + copy(m.data, m.data[m.mark:m.tail]) + m.capacity = m.mark + m.tail = m.tail - m.mark + m.mark = 0 + } else { + + //扩容 + capacity := (cap(m.data) + m.tail - m.mark + size) * 3 / 2 + bytes := make([]byte, capacity) + //不对之前的内存进行复制, 已经被AVPacket引用, 自行GC + copy(bytes, m.data[m.mark:m.tail]) + m.data = bytes + m.capacity = capacity + m.tail = m.tail - m.mark + m.mark = 0 + m.head = 0 + + } + } + + bytes := m.data[m.tail:] + m.tail += size + return bytes +} + +func (m *memoryPool) Mark() { + m.mark = m.tail +} + +func (m *memoryPool) Write(data []byte) { + allocate := m.allocate(len(data)) + copy(allocate, data) } func (m *memoryPool) Allocate(size int) []byte { - return nil + return m.allocate(size) } -func (m *memoryPool) Free(size int) { +func (m *memoryPool) Fetch() []byte { + return m.data[m.mark:m.tail] +} + +func (m *memoryPool) Reset() { + m.tail = m.mark +} + +func (m *memoryPool) FreeHead(size int) { + m.head += size + if m.head == m.tail { + m.head = 0 + m.tail = 0 + } else if m.head >= m.capacity { + m.head = 0 + } +} + +func (m *memoryPool) FreeTail(size int) { + m.tail -= size + utils.Assert(m.tail >= 0) } diff --git a/stream/memory_pool_test.go b/stream/memory_pool_test.go new file mode 100644 index 0000000..c3025c0 --- /dev/null +++ b/stream/memory_pool_test.go @@ -0,0 +1,25 @@ +package stream + +import ( + "encoding/hex" + "testing" +) + +func TestMemoryPool(t *testing.T) { + bytes := make([]byte, 10) + for i := 0; i < 10; i++ { + bytes[i] = byte(i) + } + + pool := NewMemoryPool(5) + for i := 0; i < 10; i++ { + pool.Mark() + pool.Write(bytes) + fetch := pool.Fetch() + println(hex.Dump(fetch)) + + if i%2 == 0 { + pool.FreeHead(len(fetch)) + } + } +} diff --git a/stream/ring_buffer.go b/stream/ring_buffer.go index 8aa9bed..5c0fe15 100644 --- a/stream/ring_buffer.go +++ b/stream/ring_buffer.go @@ -14,6 +14,8 @@ type RingBuffer interface { Tail() interface{} Size() int + + All() ([]interface{}, []interface{}) } func NewRingBuffer(capacity int) RingBuffer { @@ -75,3 +77,11 @@ func (r *ringBuffer) Tail() interface{} { func (r *ringBuffer) Size() int { return r.size } + +func (r *ringBuffer) All() ([]interface{}, []interface{}) { + if r.head < r.tail { + return r.data[r.head:r.tail], nil + } else { + return r.data[r.head:], r.data[:r.tail] + } +} diff --git a/stream/session.go b/stream/session.go index 936a245..22900f0 100644 --- a/stream/session.go +++ b/stream/session.go @@ -1,19 +1,19 @@ package stream import ( - "github.com/yangjiechina/avformat" + "github.com/yangjiechina/avformat/utils" "net/http" ) // Session 封装推拉流Session 统一管理,统一 hook回调 type Session interface { - OnPublish(source ISource, pra map[string]interface{}, success func(), failure func(state avformat.HookState)) + OnPublish(source ISource, pra map[string]interface{}, success func(), failure func(state utils.HookState)) OnPublishDone() - OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state avformat.HookState)) + OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state utils.HookState)) - OnPlayDone(pra map[string]interface{}, success func(), failure func(state avformat.HookState)) + OnPlayDone(pra map[string]interface{}, success func(), failure func(state utils.HookState)) } type SessionImpl struct { @@ -30,19 +30,19 @@ func (s *SessionImpl) AddInfoParams(data map[string]interface{}) { data["remoteAddr"] = s.remoteAddr } -func (s *SessionImpl) OnPublish(source_ ISource, pra map[string]interface{}, success func(), failure func(state avformat.HookState)) { +func (s *SessionImpl) OnPublish(source_ ISource, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { //streamId 已经被占用 source := SourceManager.Find(s.stream) if source != nil { - failure(avformat.HookStateOccupy) + failure(utils.HookStateOccupy) return } if !AppConfig.Hook.EnableOnPublish() { - if err := SourceManager.Add(source_); err != nil { + if err := SourceManager.Add(source_); err == nil { success() } else { - failure(avformat.HookStateOccupy) + failure(utils.HookStateOccupy) } return @@ -54,18 +54,18 @@ func (s *SessionImpl) OnPublish(source_ ISource, pra map[string]interface{}, suc s.AddInfoParams(pra) err := s.DoPublish(pra, func(response *http.Response) { - if err := SourceManager.Add(source_); err != nil { + if err := SourceManager.Add(source_); err == nil { success() } else { - failure(avformat.HookStateOccupy) + failure(utils.HookStateOccupy) } }, func(response *http.Response, err error) { - failure(avformat.HookStateFailure) + failure(utils.HookStateFailure) }) //hook地址连接失败 if err != nil { - failure(avformat.HookStateFailure) + failure(utils.HookStateFailure) return } } @@ -74,7 +74,7 @@ func (s *SessionImpl) OnPublishDone() { } -func (s *SessionImpl) OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state avformat.HookState)) { +func (s *SessionImpl) OnPlay(sink ISink, pra map[string]interface{}, success func(), failure func(state utils.HookState)) { f := func() { source := SourceManager.Find(s.stream) if source == nil { @@ -99,15 +99,15 @@ func (s *SessionImpl) OnPlay(sink ISink, pra map[string]interface{}, success fun f() success() }, func(response *http.Response, err error) { - failure(avformat.HookStateFailure) + failure(utils.HookStateFailure) }) if err != nil { - failure(avformat.HookStateFailure) + failure(utils.HookStateFailure) return } } -func (s *SessionImpl) OnPlayDone(pra map[string]interface{}, success func(), failure func(state avformat.HookState)) { +func (s *SessionImpl) OnPlayDone(pra map[string]interface{}, success func(), failure func(state utils.HookState)) { } diff --git a/stream/sink.go b/stream/sink.go index cc6bcd6..0af9e54 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -1,15 +1,16 @@ package stream -import "github.com/yangjiechina/avformat/utils" +import ( + "github.com/yangjiechina/avformat/utils" + "net" +) -type SinkId string +type SinkId interface{} type ISink interface { Id() SinkId - Input(data []byte) - - Send(buffer utils.ByteBuffer) + Input(data []byte) error SourceId() string @@ -33,6 +34,26 @@ type ISink interface { Close() } +// GenerateSinkId 根据Conn生成SinkId IPV4使用一个uint64, IPV6使用String +func GenerateSinkId(conn net.Conn) SinkId { + network := conn.RemoteAddr().Network() + if "tcp" == network { + id := uint64(utils.BytesToInt(conn.RemoteAddr().(*net.TCPAddr).IP.To4())) + id <<= 32 + id |= uint64(conn.RemoteAddr().(*net.TCPAddr).Port << 16) + + return id + } else if "udp" == network { + id := uint64(utils.BytesToInt(conn.RemoteAddr().(*net.UDPAddr).IP.To4())) + id <<= 32 + id |= uint64(conn.RemoteAddr().(*net.UDPAddr).Port << 16) + + return id + } + + return conn.RemoteAddr().String() +} + func AddSinkToWaitingQueue(streamId string, sink ISink) { } @@ -46,35 +67,37 @@ func PopWaitingSinks(streamId string) []ISink { } type SinkImpl struct { - id string - protocol Protocol + Id_ SinkId + sourceId string + Protocol_ Protocol enableVideo bool - desiredAudioCodecId utils.AVCodecID - desiredVideoCodecId utils.AVCodecID + DesiredAudioCodecId_ utils.AVCodecID + DesiredVideoCodecId_ utils.AVCodecID + + Conn net.Conn } -func (s *SinkImpl) Id() string { - return s.id +func (s *SinkImpl) Id() SinkId { + return s.Id_ } -func (s *SinkImpl) Input(data []byte) { - //TODO implement me - panic("implement me") -} +func (s *SinkImpl) Input(data []byte) error { + if s.Conn != nil { + _, err := s.Conn.Write(data) -func (s *SinkImpl) Send(buffer utils.ByteBuffer) { - //TODO implement me - panic("implement me") + return err + } + + return nil } func (s *SinkImpl) SourceId() string { - //TODO implement me - panic("implement me") + return s.sourceId } func (s *SinkImpl) Protocol() Protocol { - return s.protocol + return s.Protocol_ } func (s *SinkImpl) State() int { @@ -95,7 +118,14 @@ func (s *SinkImpl) SetEnableVideo(enable bool) { s.enableVideo = enable } -func (s *SinkImpl) Close() { - //TODO implement me - panic("implement me") +func (s *SinkImpl) DesiredAudioCodecId() utils.AVCodecID { + return s.DesiredAudioCodecId_ +} + +func (s *SinkImpl) DesiredVideoCodecId() utils.AVCodecID { + return s.DesiredVideoCodecId_ +} + +func (s *SinkImpl) Close() { + } diff --git a/stream/source.go b/stream/source.go index 0b72dfa..278de4f 100644 --- a/stream/source.go +++ b/stream/source.go @@ -1,7 +1,7 @@ package stream import ( - "github.com/yangjiechina/avformat" + "fmt" "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/transcode" "time" @@ -84,6 +84,7 @@ type SourceImpl struct { videoTranscoders []transcode.ITranscoder //视频解码器 originStreams StreamManager //推流的音视频Streams allStreams StreamManager //推流Streams+转码器获得的Streams + buffers []StreamBuffer completed bool probeTimer *time.Timer @@ -153,12 +154,12 @@ func (s *SourceImpl) AddSink(sink ISink) bool { //创建音频转码器 if !disableAudio && audioCodecId != audioStream.CodecId() { - avformat.Assert(false) + utils.Assert(false) } //创建视频转码器 if !disableVideo && videoCodecId != videoStream.CodecId() { - avformat.Assert(false) + utils.Assert(false) } var streams [5]utils.AVStream @@ -172,32 +173,53 @@ func (s *SourceImpl) AddSink(sink ISink) bool { index++ } - //transStreamId := GenerateTransStreamId(sink.Protocol(), streams[:]...) - TransStreamFactory(sink.Protocol(), streams[:]) + transStreamId := GenerateTransStreamId(sink.Protocol(), streams[:]...) + transStream, ok := s.transStreams[transStreamId] + if ok { + transStream = TransStreamFactory(sink.Protocol(), streams[:]) + s.transStreams[transStreamId] = transStream + + for i := 0; i < index; i++ { + transStream.AddTrack(streams[i]) + } + + _ = transStream.WriteHeader() + } + + transStream.AddSink(sink) return false } func (s *SourceImpl) RemoveSink(tid TransStreamId, sinkId string) bool { - //TODO implement me - panic("implement me") + return true } func (s *SourceImpl) Close() { - //TODO implement me - panic("implement me") + } func (s *SourceImpl) OnDeMuxStream(stream utils.AVStream) { + if s.completed { + fmt.Printf("添加Stream失败 Source: %s已经WriteHeader", s.Id_) + return + } + s.originStreams.Add(stream) s.allStreams.Add(stream) if len(s.originStreams.All()) == 1 { s.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, s.writeHeader) } + + //为每个Stream创建对于的Buffer + if AppConfig.GOPCache > 0 { + buffer := NewStreamBuffer(int64(AppConfig.GOPCache)) + s.buffers = append(s.buffers, buffer) + } } // 从DeMuxer解析完Stream后, 处理等待Sinks func (s *SourceImpl) writeHeader() { - avformat.Assert(!s.completed) + utils.Assert(!s.completed) s.probeTimer.Stop() s.completed = true @@ -212,7 +234,14 @@ func (s *SourceImpl) OnDeMuxStreamDone() { } func (s *SourceImpl) OnDeMuxPacket(index int, packet utils.AVPacket) { + if AppConfig.GOPCache > 0 { + buffer := s.buffers[packet.Index()] + buffer.AddPacket(packet, packet.KeyFrame(), packet.Dts()) + } + for _, stream := range s.transStreams { + stream.Input(packet) + } } func (s *SourceImpl) OnDeMuxDone() { diff --git a/stream/stream_buffer.go b/stream/stream_buffer.go new file mode 100644 index 0000000..c5c1157 --- /dev/null +++ b/stream/stream_buffer.go @@ -0,0 +1,108 @@ +package stream + +// StreamBuffer GOP缓存 +type StreamBuffer interface { + + // AddPacket Return bool 缓存帧是否成功, 如果首帧非关键帧, 缓存失败 + AddPacket(packet interface{}, key bool, ts int64) bool + + // SetDiscardHandler 设置丢弃帧时的回调 + SetDiscardHandler(handler func(packet interface{})) + + Peek(handler func(packet interface{})) + + Duration() int64 +} + +type streamBuffer struct { + buffer RingBuffer + duration int64 + + keyFrameDts int64 //最近一个关键帧的Dts + FarthestKeyFrameDts int64 //最远一个关键帧的Dts + + discardHandler func(packet interface{}) +} + +type element struct { + ts int64 + key bool + pkt interface{} +} + +func NewStreamBuffer(duration int64) StreamBuffer { + return &streamBuffer{duration: duration, buffer: NewRingBuffer(1000)} +} + +func (s *streamBuffer) AddPacket(packet interface{}, key bool, ts int64) bool { + if s.buffer.IsEmpty() { + if !key { + return false + } + + s.FarthestKeyFrameDts = ts + } + + s.buffer.Push(element{ts, key, packet}) + if key { + s.keyFrameDts = ts + } + + //丢弃处理 + //以最近的关键帧时间戳开始,丢弃缓存超过duration长度的帧 + //至少需要保障当前GOP完整 + //暂时不考虑以下情况: + // 1. 音频收流正常,视频长时间没收流,待视频恢复后。 会造成在此期间,多余的音频帧被丢弃,播放时有画面,没声音. + // 2. 视频反之亦然 + if !key { + return true + } + + for farthest := s.keyFrameDts - s.duration; s.buffer.Size() > 1 && s.buffer.Head().(element).ts < farthest; { + ele := s.buffer.Pop().(element) + + //重新设置最早的关键帧时间戳 + if ele.key && ele.ts != s.FarthestKeyFrameDts { + s.FarthestKeyFrameDts = ele.ts + } + + if s.discardHandler != nil { + s.discardHandler(ele.pkt) + } + } + + return true +} + +func (s *streamBuffer) SetDiscardHandler(handler func(packet interface{})) { + s.discardHandler = handler +} + +func (s *streamBuffer) Peek(handler func(packet interface{})) { + head, tail := s.buffer.All() + + if head == nil { + return + } + for _, value := range head { + handler(value.(element).pkt) + } + + if tail == nil { + return + } + for _, value := range tail { + handler(value.(element).pkt) + } +} + +func (s *streamBuffer) Duration() int64 { + head := s.buffer.Head() + tail := s.buffer.Tail() + + if head == nil || tail == nil { + return 0 + } + + return tail.(element).ts - head.(element).ts +} diff --git a/stream/stream_manager.go b/stream/stream_manager.go index 46ae684..673a90b 100644 --- a/stream/stream_manager.go +++ b/stream/stream_manager.go @@ -1,66 +1,9 @@ package stream import ( - "github.com/yangjiechina/avformat" "github.com/yangjiechina/avformat/utils" ) -type Track interface { - Stream() utils.AVStream - - Cache() RingBuffer - - AddPacket(packet utils.AVPacket) -} - -// 封装stream 增加GOP管理 -type track struct { - stream utils.AVStream - cache RingBuffer - duration int - keyFrameDts int64 //最近一个关键帧的Dts -} - -func (t *track) Stream() utils.AVStream { - return t.stream -} - -func (t *track) Cache() RingBuffer { - return t.cache -} - -func (t *track) AddPacket(packet utils.AVPacket) { - if t.cache.IsEmpty() && !packet.KeyFrame() { - return - } - - t.cache.Push(packet) - if packet.KeyFrame() { - t.keyFrameDts = packet.Dts() - } - - //以最近的关键帧时间戳开始,丢弃缓存超过duration长度的帧 - //至少需要保障当前GOP完整 - //head := t.cache.Head().(utils.AVPacket) - //for farthest := t.keyFrameDts - int64(t.duration); t.cache.Size() > 1 && t.cache.Head().(utils.AVPacket).Dts() < farthest; { - // t.cache.Pop() - //} -} - -func NewTrack(stream utils.AVStream, cacheSeconds int) Track { - t := &track{stream: stream, duration: cacheSeconds * 1000} - - if cacheSeconds > 0 { - if utils.AVMediaTypeVideo == stream.Type() { - t.cache = NewRingBuffer(cacheSeconds * 30 * 2) - } else if utils.AVMediaTypeAudio == stream.Type() { - t.cache = NewRingBuffer(cacheSeconds * 50 * 2) - } - } - - return t -} - type StreamManager struct { streams []utils.AVStream completed bool @@ -68,8 +11,8 @@ type StreamManager struct { func (s *StreamManager) Add(stream utils.AVStream) { for _, stream_ := range s.streams { - avformat.Assert(stream_.Type() != stream.Type()) - avformat.Assert(stream_.CodecId() != stream.CodecId()) + utils.Assert(stream_.Type() != stream.Type()) + utils.Assert(stream_.CodecId() != stream.CodecId()) } s.streams = append(s.streams, stream) diff --git a/stream/trans_stream.go b/stream/trans_stream.go index e779d9d..8f86e8b 100644 --- a/stream/trans_stream.go +++ b/stream/trans_stream.go @@ -30,14 +30,14 @@ func init() { // 请确保ids根据值升序排序传参 /*func GenerateTransStreamId(protocol Protocol, ids ...utils.AVCodecID) TransStreamId { len_ := len(ids) - avformat.Assert(len_ > 0 && len_ < 8) + utils.Assert(len_ > 0 && len_ < 8) var streamId uint64 streamId = uint64(protocol) << 56 for i, id := range ids { bId, ok := narrowCodecIds[int(id)] - avformat.Assert(ok) + utils.Assert(ok) streamId |= uint64(bId) << (48 - i*8) } @@ -47,14 +47,14 @@ func init() { func GenerateTransStreamId(protocol Protocol, ids ...utils.AVStream) TransStreamId { len_ := len(ids) - avformat.Assert(len_ > 0 && len_ < 8) + utils.Assert(len_ > 0 && len_ < 8) var streamId uint64 streamId = uint64(protocol) << 56 for i, id := range ids { bId, ok := narrowCodecIds[int(id.CodecId())] - avformat.Assert(ok) + utils.Assert(ok) streamId |= uint64(bId) << (48 - i*8) } @@ -65,9 +65,11 @@ func GenerateTransStreamId(protocol Protocol, ids ...utils.AVStream) TransStream var TransStreamFactory func(protocol Protocol, streams []utils.AVStream) ITransStream type ITransStream interface { + Input(packet utils.AVPacket) + AddTrack(stream utils.AVStream) - WriteHeader() + WriteHeader() error AddSink(sink ISink) @@ -77,21 +79,27 @@ type ITransStream interface { } type TransStreamImpl struct { - sinks map[SinkId]ISink - muxer avformat.Muxer - tracks []utils.AVStream + Sinks map[SinkId]ISink + muxer avformat.Muxer + Tracks []utils.AVStream + transBuffer MemoryPool //每个TransStream也缓存封装后的流 + Completed bool +} + +func (t *TransStreamImpl) Input(packet utils.AVPacket) { + } func (t *TransStreamImpl) AddTrack(stream utils.AVStream) { - t.tracks = append(t.tracks, stream) + t.Tracks = append(t.Tracks, stream) } func (t *TransStreamImpl) AddSink(sink ISink) { - t.sinks[sink.Id()] = sink + t.Sinks[sink.Id()] = sink } func (t *TransStreamImpl) RemoveSink(id SinkId) { - delete(t.sinks, id) + delete(t.Sinks, id) } func (t *TransStreamImpl) AllSink() []ISink {