From de6cca48ca22670301b76d8d160bc256eec2bb53 Mon Sep 17 00:00:00 2001 From: yangjiechina <1534796060@qq.com> Date: Wed, 6 Mar 2024 20:42:17 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=81=E8=A3=85hls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hls/hls_stream.go | 191 ++++++++++++++++++++++++++++++++++++++--- hls/m3u8.go | 3 +- main.go | 18 ++-- rtmp/rtmp_session.go | 2 +- rtmp/rtmp_stream.go | 23 +++-- stream/hook.go | 2 +- stream/source.go | 26 +++++- stream/trans_stream.go | 28 ++++-- 8 files changed, 254 insertions(+), 39 deletions(-) diff --git a/hls/hls_stream.go b/hls/hls_stream.go index bd6dd80..4cf1e44 100644 --- a/hls/hls_stream.go +++ b/hls/hls_stream.go @@ -1,37 +1,204 @@ package hls import ( + "fmt" "github.com/yangjiechina/avformat/libmpeg" "github.com/yangjiechina/avformat/utils" "github.com/yangjiechina/live-server/stream" + "os" ) +type tsContext struct { + segmentSeq int + writeBuffer []byte + writeBufferSize int + + duration int + playlistLength int + url string + path string + + file *os.File +} + type Stream struct { stream.TransStreamImpl - muxer libmpeg.TSMuxer + muxer libmpeg.TSMuxer + context *tsContext + + m3u8 M3U8Writer + url string + m3u8Name string + tsFormat string + dir string + m3u8File *os.File } -func NewTransStream(segmentDuration, playlistLength int) stream.ITransStream { - return &Stream{muxer: libmpeg.NewTSMuxer()} +// NewTransStream 创建HLS传输流 +// @url url前缀 +// @m3u8Name m3u8文件名 +// @tsFormat ts文件格式, 例如: test_%d.ts +// @parentDir 保存切片的绝对路径. mu38和ts切片放在同一目录下, 目录地址使用parentDir+urlPrefix +// @segmentDuration 单个切片时长 +// @playlistLength 缓存多少个切片 +func NewTransStream(url, m3u8Name, tsFormat, dir string, segmentDuration, playlistLength int) (stream.ITransStream, error) { + //创建文件夹 + if err := os.MkdirAll(dir, 0666); err != nil { + return nil, err + } + + m3u8Path := fmt.Sprintf("%s/%s", dir, m3u8Name) + file, err := os.OpenFile(m3u8Path, os.O_CREATE|os.O_RDWR, 0666) + if err != nil { + return nil, err + } + + stream_ := &Stream{ + url: url, + m3u8Name: m3u8Name, + tsFormat: tsFormat, + dir: dir, + } + + muxer := libmpeg.NewTSMuxer() + muxer.SetWriteHandler(stream_.onTSWrite) + muxer.SetAllocHandler(stream_.onTSAlloc) + + stream_.context = &tsContext{ + segmentSeq: 0, + writeBuffer: make([]byte, 1024*1024), + writeBufferSize: 0, + duration: segmentDuration, + playlistLength: playlistLength, + } + + stream_.muxer = muxer + stream_.m3u8 = NewM3U8Writer(playlistLength) + stream_.m3u8File = file + return stream_, nil } -func (t *Stream) Input(packet utils.AVPacket) { - if utils.AVMediaTypeVideo == packet.MediaType() { - if packet.KeyFrame() { - t.Tracks[packet.Index()].AnnexBExtraData() - t.muxer.Input() +func (t *Stream) Input(packet utils.AVPacket) error { + if packet.Index() >= t.muxer.TrackCount() { + return fmt.Errorf("track not available") + } + + if utils.AVMediaTypeVideo == packet.MediaType() && packet.KeyFrame() || t.context.file == nil { + if err := t.createSegment(); err != nil { + return err } } + if utils.AVMediaTypeVideo == packet.MediaType() { + return t.muxer.Input(packet.Index(), packet.AnnexBPacketData(), packet.Pts()*90, packet.Dts()*90, packet.KeyFrame()) + } else { + return t.muxer.Input(packet.Index(), packet.Data(), packet.Pts()*90, packet.Dts()*90, packet.KeyFrame()) + } } -func (t *Stream) AddTrack(stream utils.AVStream) { - t.TransStreamImpl.AddTrack(stream) +func (t *Stream) AddTrack(stream utils.AVStream) error { + err := t.TransStreamImpl.AddTrack(stream) + if err != nil { + return err + } - t.muxer.AddTrack(stream.Type(), stream.CodecId()) + if stream.CodecId() == utils.AVCodecIdH264 { + data, err := stream.AnnexBExtraData() + if err != nil { + return err + } + + _, err = t.muxer.AddTrack(stream.Type(), stream.CodecId(), data) + } else { + _, err = t.muxer.AddTrack(stream.Type(), stream.CodecId(), stream.Extra()) + } + return err } func (t *Stream) WriteHeader() error { - t.muxer.WriteHeader() return nil } + +func (t *Stream) onTSWrite(data []byte) { + t.context.writeBufferSize += len(data) +} + +func (t *Stream) onTSAlloc(size int) []byte { + n := len(t.context.writeBuffer) - t.context.writeBufferSize + if n < size { + _, _ = t.context.file.Write(t.context.writeBuffer[:t.context.writeBufferSize]) + t.context.writeBufferSize = 0 + } + + return t.context.writeBuffer[t.context.writeBufferSize : t.context.writeBufferSize+size] +} + +func (t *Stream) flushSegment() error { + //将剩余数据写入缓冲区 + if t.context.writeBufferSize > 0 { + _, _ = t.context.file.Write(t.context.writeBuffer[:t.context.writeBufferSize]) + t.context.writeBufferSize = 0 + } + + if err := t.context.file.Close(); err != nil { + return err + } + + duration := float32(t.muxer.Duration()) / 90000 + t.m3u8.AddSegment(duration, t.context.url, t.context.segmentSeq) + + //更新m3u8 + if _, err := t.m3u8File.Seek(0, 0); err != nil { + return err + } + if err := t.m3u8File.Truncate(0); err != nil { + return err + } + + m3u8Txt := t.m3u8.ToString() + if _, err := t.m3u8File.Write([]byte(m3u8Txt)); err != nil { + return err + } + + return nil +} + +func (t *Stream) createSegment() error { + if t.context.file != nil { + err := t.flushSegment() + t.context.segmentSeq++ + if err != nil { + return err + } + } + + tsName := fmt.Sprintf(t.tsFormat, t.context.segmentSeq) + t.context.path = fmt.Sprintf("%s%s", t.dir, tsName) + t.context.url = fmt.Sprintf("%s%s", t.url, tsName) + file, err := os.OpenFile(t.context.path, os.O_WRONLY|os.O_CREATE, 0666) + if err != nil { + return err + } + t.context.file = file + + t.muxer.Reset() + err = t.muxer.WriteHeader() + return err +} + +func (t *Stream) Close() error { + var err error + + if t.context.file != nil { + err = t.flushSegment() + err = t.context.file.Close() + t.context.file = nil + } + + if t.m3u8File != nil { + err = t.m3u8File.Close() + t.m3u8File = nil + } + + return err +} diff --git a/hls/m3u8.go b/hls/m3u8.go index e4777d0..b70708c 100644 --- a/hls/m3u8.go +++ b/hls/m3u8.go @@ -47,7 +47,7 @@ type M3U8Writer interface { func NewM3U8Writer(len int) M3U8Writer { return &m3u8Writer{ - stringBuffer: bytes.NewBuffer(make([]byte, 1024*10)), + stringBuffer: bytes.NewBuffer(make([]byte, 0, 1024*10)), playlist: stream.NewQueue(len), } } @@ -85,6 +85,7 @@ func (m *m3u8Writer) ToString() string { return "" } + m.stringBuffer.Reset() m.stringBuffer.WriteString("#EXTM3U\r\n") //暂时只实现第三个版本 m.stringBuffer.WriteString("#EXT-X-VERSION:3\r\n") diff --git a/main.go b/main.go index 1da1010..30f95ec 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "github.com/yangjiechina/live-server/hls" "net" "net/http" @@ -12,18 +13,25 @@ import ( "github.com/yangjiechina/live-server/stream" ) -func CreateTransStream(protocol stream.Protocol, streams []utils.AVStream) stream.ITransStream { +func CreateTransStream(source stream.ISource, protocol stream.Protocol, streams []utils.AVStream) stream.ITransStream { if stream.ProtocolRtmp == protocol { return rtmp.NewTransStream(librtmp.ChunkSize) + } else if stream.ProtocolHls == protocol { + id := source.Id() + m3u8Name := id + ".m3u8" + tsFormat := id + "_%d.ts" + + transStream, err := hls.NewTransStream("/live/hls/", m3u8Name, tsFormat, "../tmp/", 2, 10) + if err != nil { + panic(err) + } + + return transStream } return nil } -func requestStream(sourceId string) { - -} - func init() { stream.TransStreamFactory = CreateTransStream } diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index b7e4ca1..3af487d 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -79,7 +79,7 @@ func (s *sessionImpl) Close() { return } - _, ok := s.handle.(*Publisher) + _, ok := s.handle.(*publisher) if ok { if s.isPublisher { s.handle.(*publisher).AddEvent(stream.SourceEventClose, nil) diff --git a/rtmp/rtmp_stream.go b/rtmp/rtmp_stream.go index 6edb1d1..6222071 100644 --- a/rtmp/rtmp_stream.go +++ b/rtmp/rtmp_stream.go @@ -1,6 +1,7 @@ package rtmp import ( + "fmt" "github.com/yangjiechina/avformat/libflv" "github.com/yangjiechina/avformat/librtmp" "github.com/yangjiechina/avformat/utils" @@ -46,7 +47,7 @@ func NewTransStream(chunkSize int) stream.ITransStream { return transStream } -func (t *TransStream) Input(packet utils.AVPacket) { +func (t *TransStream) Input(packet utils.AVPacket) error { utils.Assert(t.TransStreamImpl.Completed) var data []byte @@ -66,7 +67,7 @@ func (t *TransStream) Input(packet utils.AVPacket) { //首帧必须要视频关键帧 if !t.firstVideoPacket { if !packet.KeyFrame() { - return + return fmt.Errorf("the first video frame must be a keyframe") } t.firstVideoPacket = true @@ -83,7 +84,7 @@ func (t *TransStream) Input(packet utils.AVPacket) { //即不开启GOP缓存又不合并发送. 直接使用AVPacket的预留头封装发送 if !stream.AppConfig.GOPCache || t.onlyAudio { //首帧视频帧必须要关键帧 - return + return nil } if videoKey { @@ -168,11 +169,11 @@ func (t *TransStream) Input(packet utils.AVPacket) { t.incompleteSinks = nil } - return + return nil } if t.segmentDuration < stream.AppConfig.MergeWriteLatency { - return + return nil } head, tail := t.memoryPool[0].Data() @@ -188,9 +189,11 @@ func (t *TransStream) Input(packet utils.AVPacket) { if t.segmentOffset > len(head) && t.memoryPool[1] != nil && !t.memoryPool[1].Empty() { t.memoryPool[1].Clear() } + + return nil } -func (t *TransStream) AddSink(sink stream.ISink) { +func (t *TransStream) AddSink(sink stream.ISink) error { utils.Assert(t.headerSize > 0) t.TransStreamImpl.AddSink(sink) @@ -198,7 +201,7 @@ func (t *TransStream) AddSink(sink stream.ISink) { sink.Input(t.header[:t.headerSize]) if !stream.AppConfig.GOPCache || t.onlyAudio { - return + return nil } //发送当前内存池已有的合并写切片 @@ -207,7 +210,7 @@ func (t *TransStream) AddSink(sink stream.ISink) { utils.Assert(len(data) > 0) utils.Assert(len(tail) == 0) sink.Input(data[:t.segmentOffset]) - return + return nil } //发送上一组GOP @@ -216,7 +219,7 @@ func (t *TransStream) AddSink(sink stream.ISink) { utils.Assert(len(data) > 0) utils.Assert(len(tail) == 0) sink.Input(data) - return + return nil } //不足一个合并写切片, 有多少发多少 @@ -226,6 +229,8 @@ func (t *TransStream) AddSink(sink stream.ISink) { sink.Input(data) t.incompleteSinks = append(t.incompleteSinks, sink) } + + return nil } func (t *TransStream) WriteHeader() error { diff --git a/stream/hook.go b/stream/hook.go index 140cc46..79ddbf9 100644 --- a/stream/hook.go +++ b/stream/hook.go @@ -34,7 +34,7 @@ const ( HookEventPlayDone = HookEvent(0x4) HookEventRecord = HookEvent(0x5) HookEventIdleTimeout = HookEvent(0x6) - HookEventRecvTimeout = HookEvent(0x6) + HookEventRecvTimeout = HookEvent(0x7) ) // 每个通知的时间都需要携带的字段 diff --git a/stream/source.go b/stream/source.go index 67b7d34..27db6ed 100644 --- a/stream/source.go +++ b/stream/source.go @@ -154,6 +154,8 @@ type SourceImpl struct { closeEvent chan byte playingEventQueue chan ISink playingDoneEventQueue chan ISink + + testTransStream ITransStream } func (s *SourceImpl) Id() string { @@ -169,6 +171,13 @@ func (s *SourceImpl) Init() { s.closeEvent = make(chan byte) s.playingEventQueue = make(chan ISink, 128) s.playingDoneEventQueue = make(chan ISink, 128) + + if s.transStreams == nil { + s.transStreams = make(map[TransStreamId]ITransStream, 10) + } + //测试传输流 + s.testTransStream = TransStreamFactory(s, ProtocolHls, nil) + s.transStreams[0x100] = s.testTransStream } func (s *SourceImpl) LoopEvent() { @@ -191,6 +200,10 @@ func (s *SourceImpl) LoopEvent() { } } +func (s *SourceImpl) Input(data []byte) { + +} + func (s *SourceImpl) OriginStreams() []utils.AVStream { return s.originStreams.All() } @@ -304,11 +317,11 @@ func (s *SourceImpl) AddSink(sink ISink) bool { transStreamId := GenerateTransStreamId(sink.Protocol(), streams[:size]...) transStream, ok := s.transStreams[transStreamId] if !ok { - //创建一个新的传输流 - transStream = TransStreamFactory(sink.Protocol(), streams[:size]) if s.transStreams == nil { s.transStreams = make(map[TransStreamId]ITransStream, 10) } + //创建一个新的传输流 + transStream = TransStreamFactory(s, sink.Protocol(), streams[:size]) s.transStreams[transStreamId] = transStream for i := 0; i < size; i++ { @@ -433,6 +446,14 @@ func (s *SourceImpl) writeHeader() { for _, sink := range sinks { s.AddSink(sink) } + + if s.testTransStream != nil { + for _, stream_ := range s.originStreams.All() { + s.testTransStream.AddTrack(stream_) + } + + s.testTransStream.WriteHeader() + } } func (s *SourceImpl) OnDeMuxStreamDone() { @@ -445,6 +466,7 @@ func (s *SourceImpl) OnDeMuxPacket(packet utils.AVPacket) { buffer.AddPacket(packet, packet.KeyFrame(), packet.Dts()) } + //分发给各个传输流 for _, stream := range s.transStreams { stream.Input(packet) } diff --git a/stream/trans_stream.go b/stream/trans_stream.go index 1c28e90..930c52d 100644 --- a/stream/trans_stream.go +++ b/stream/trans_stream.go @@ -62,23 +62,25 @@ func GenerateTransStreamId(protocol Protocol, ids ...utils.AVStream) TransStream return TransStreamId(streamId) } -var TransStreamFactory func(protocol Protocol, streams []utils.AVStream) ITransStream +var TransStreamFactory func(source ISource, protocol Protocol, streams []utils.AVStream) ITransStream // ITransStream 讲AVPacket封装成传输流,转发给各个Sink type ITransStream interface { - Input(packet utils.AVPacket) + Input(packet utils.AVPacket) error - AddTrack(stream utils.AVStream) + AddTrack(stream utils.AVStream) error WriteHeader() error - AddSink(sink ISink) + AddSink(sink ISink) error RemoveSink(id SinkId) (ISink, bool) PopAllSink(handler func(sink ISink)) AllSink() []ISink + + Close() error } type TransStreamImpl struct { @@ -87,18 +89,24 @@ type TransStreamImpl struct { Tracks []utils.AVStream transBuffer MemoryPool //每个TransStream也缓存封装后的流 Completed bool + existVideo bool } -func (t *TransStreamImpl) Input(packet utils.AVPacket) { - +func (t *TransStreamImpl) Input(packet utils.AVPacket) error { + return nil } -func (t *TransStreamImpl) AddTrack(stream utils.AVStream) { +func (t *TransStreamImpl) AddTrack(stream utils.AVStream) error { t.Tracks = append(t.Tracks, stream) + if utils.AVMediaTypeVideo == stream.Type() { + t.existVideo = true + } + return nil } -func (t *TransStreamImpl) AddSink(sink ISink) { +func (t *TransStreamImpl) AddSink(sink ISink) error { t.Sinks[sink.Id()] = sink + return nil } func (t *TransStreamImpl) RemoveSink(id SinkId) (ISink, bool) { @@ -122,3 +130,7 @@ func (t *TransStreamImpl) AllSink() []ISink { //TODO implement me panic("implement me") } + +func (t *TransStreamImpl) Close() error { + return nil +}