diff --git a/plugin/hls/download.go b/plugin/hls/download.go new file mode 100644 index 0000000..3adfc3d --- /dev/null +++ b/plugin/hls/download.go @@ -0,0 +1,682 @@ +package plugin_hls + +import ( + "bufio" + "fmt" + "io" + "net/http" + "os" + "strconv" + "strings" + "time" + + m7s "m7s.live/v5" + "m7s.live/v5/pkg" + "m7s.live/v5/pkg/util" + hls "m7s.live/v5/plugin/hls/pkg" + mpegts "m7s.live/v5/plugin/hls/pkg/ts" + mp4 "m7s.live/v5/plugin/mp4/pkg" + "m7s.live/v5/plugin/mp4/pkg/box" +) + +// requestParams 包含请求解析后的参数 +type requestParams struct { + streamPath string + startTime time.Time + endTime time.Time + timeRange time.Duration +} + +// fileInfo 包含文件信息 +type fileInfo struct { + filePath string + startTime time.Time + endTime time.Time + startOffsetTime time.Duration + recordType string // "ts", "mp4", "fmp4" +} + +// parseRequestParams 解析请求参数 +func (plugin *HLSPlugin) parseRequestParams(r *http.Request) (*requestParams, error) { + // 从URL路径中提取流路径,去除前缀 "/download/" 和后缀 ".ts" + streamPath := strings.TrimSuffix(strings.TrimPrefix(r.URL.Path, "/download/"), ".ts") + + // 解析URL查询参数中的时间范围(start和end参数) + startTime, endTime, err := util.TimeRangeQueryParse(r.URL.Query()) + if err != nil { + return nil, err + } + + return &requestParams{ + streamPath: streamPath, + startTime: startTime, + endTime: endTime, + timeRange: endTime.Sub(startTime), + }, nil +} + +// queryRecordStreams 从数据库查询录像记录 +func (plugin *HLSPlugin) queryRecordStreams(params *requestParams) ([]m7s.RecordStream, error) { + // 检查数据库是否可用 + if plugin.DB == nil { + return nil, fmt.Errorf("database not available") + } + + var recordStreams []m7s.RecordStream + + // 首先查询HLS记录 (ts) + query := plugin.DB.Model(&m7s.RecordStream{}).Where("stream_path = ? AND type = ?", params.streamPath, "hls") + + // 添加时间范围查询条件 + if !params.startTime.IsZero() && !params.endTime.IsZero() { + query = query.Where("(start_time <= ? AND end_time >= ?) OR (start_time >= ? AND start_time <= ?)", + params.endTime, params.startTime, params.startTime, params.endTime) + } + + err := query.Order("start_time ASC").Find(&recordStreams).Error + if err != nil { + return nil, err + } + + // 如果没有找到HLS记录,尝试查询MP4记录 + if len(recordStreams) == 0 { + query = plugin.DB.Model(&m7s.RecordStream{}).Where("stream_path = ? AND type IN (?)", params.streamPath, []string{"mp4", "fmp4"}) + + if !params.startTime.IsZero() && !params.endTime.IsZero() { + query = query.Where("(start_time <= ? AND end_time >= ?) OR (start_time >= ? 
AND start_time <= ?)", + params.endTime, params.startTime, params.startTime, params.endTime) + } + + err = query.Order("start_time ASC").Find(&recordStreams).Error + if err != nil { + return nil, err + } + } + + return recordStreams, nil +} + +// buildFileInfoList 构建文件信息列表 +func (plugin *HLSPlugin) buildFileInfoList(recordStreams []m7s.RecordStream, startTime, endTime time.Time) ([]*fileInfo, bool) { + var fileInfoList []*fileInfo + var found bool + + for _, record := range recordStreams { + // 检查文件是否存在 + if !util.Exist(record.FilePath) { + plugin.Warn("Record file not found", "filePath", record.FilePath) + continue + } + + var startOffsetTime time.Duration + recordStartTime := record.StartTime + recordEndTime := record.EndTime + + // 计算文件内的偏移时间 + if startTime.After(recordStartTime) { + startOffsetTime = startTime.Sub(recordStartTime) + } + + // 检查是否在时间范围内 + if recordEndTime.Before(startTime) || recordStartTime.After(endTime) { + continue + } + + fileInfoList = append(fileInfoList, &fileInfo{ + filePath: record.FilePath, + startTime: recordStartTime, + endTime: recordEndTime, + startOffsetTime: startOffsetTime, + recordType: record.Type, + }) + + found = true + } + + return fileInfoList, found +} + +// hasOnlyMp4Records 检查是否只有MP4记录 +func (plugin *HLSPlugin) hasOnlyMp4Records(fileInfoList []*fileInfo) bool { + if len(fileInfoList) == 0 { + return false + } + + for _, info := range fileInfoList { + if info.recordType == "hls" { + return false + } + } + return true +} + +// filterTsFiles 过滤HLS TS文件 +func (plugin *HLSPlugin) filterTsFiles(fileInfoList []*fileInfo) []*fileInfo { + var filteredList []*fileInfo + + for _, info := range fileInfoList { + if info.recordType == "hls" { + filteredList = append(filteredList, info) + } + } + + plugin.Debug("TS files filtered", "original", len(fileInfoList), "filtered", len(filteredList)) + return filteredList +} + +// filterMp4Files 过滤MP4文件 +func (plugin *HLSPlugin) filterMp4Files(fileInfoList []*fileInfo) []*fileInfo { + var filteredList []*fileInfo + + for _, info := range fileInfoList { + if info.recordType == "mp4" || info.recordType == "fmp4" { + filteredList = append(filteredList, info) + } + } + + plugin.Debug("MP4 files filtered", "original", len(fileInfoList), "filtered", len(filteredList)) + return filteredList +} + +// processMp4ToTs 将MP4记录转换为TS输出 +func (plugin *HLSPlugin) processMp4ToTs(w http.ResponseWriter, r *http.Request, fileInfoList []*fileInfo, params *requestParams) { + plugin.Info("Converting MP4 records to TS", "count", len(fileInfoList)) + + // 设置HTTP响应头 + w.Header().Set("Content-Type", "video/mp2t") + w.Header().Set("Content-Disposition", "attachment") + + // 创建一个TS写入器,在循环外面,所有MP4文件共享同一个TsInMemory + tsWriter := &simpleTsWriter{ + TsInMemory: &hls.TsInMemory{}, + plugin: plugin, + } + + // 对于MP4到TS的转换,我们采用简化的方法 + // 直接将每个MP4文件转换输出 + for _, info := range fileInfoList { + if r.Context().Err() != nil { + return + } + + plugin.Debug("Converting MP4 file to TS", "path", info.filePath) + + // 创建MP4解复用器 + demuxer := &mp4.DemuxerRange{ + StartTime: params.startTime, + EndTime: params.endTime, + Streams: []m7s.RecordStream{{ + FilePath: info.filePath, + StartTime: info.startTime, + EndTime: info.endTime, + Type: info.recordType, + }}, + } + + // 设置回调函数 + demuxer.OnVideoExtraData = tsWriter.onVideoExtraData + demuxer.OnAudioExtraData = tsWriter.onAudioExtraData + demuxer.OnVideoSample = tsWriter.onVideoSample + demuxer.OnAudioSample = tsWriter.onAudioSample + + // 执行解复用和转换 + err := demuxer.Demux(r.Context()) + if err != nil { + 
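+				// Conversion failed partway through this file. Only surface an HTTP error while no TS
+				// content has been produced yet (hasWritten is set once the PMT is written); otherwise
+				// the partially built in-memory TS is simply discarded.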
plugin.Error("MP4 to TS conversion failed", "err", err, "file", info.filePath) + if !tsWriter.hasWritten { + http.Error(w, "Conversion failed", http.StatusInternalServerError) + } + return + } + } + + // 将所有累积的 TsInMemory 内容写入到响应 + _, err := tsWriter.WriteTo(w) + if err != nil { + plugin.Error("Failed to write TS data to response", "error", err) + return + } + + plugin.Info("MP4 to TS conversion completed") +} + +// simpleTsWriter 简化的TS写入器 +type simpleTsWriter struct { + *hls.TsInMemory + plugin *HLSPlugin + hasWritten bool + spsData []byte + ppsData []byte + videoCodec box.MP4_CODEC_TYPE + audioCodec box.MP4_CODEC_TYPE +} + +func (w *simpleTsWriter) WritePMT() { + // 初始化 TsInMemory 的 PMT + var videoCodec, audioCodec [4]byte + switch w.videoCodec { + case box.MP4_CODEC_H264: + copy(videoCodec[:], []byte("H264")) + case box.MP4_CODEC_H265: + copy(videoCodec[:], []byte("H265")) + } + switch w.audioCodec { + case box.MP4_CODEC_AAC: + copy(audioCodec[:], []byte("MP4A")) + + } + w.WritePMTPacket(audioCodec, videoCodec) + w.hasWritten = true +} + +// onVideoExtraData 处理视频序列头 +func (w *simpleTsWriter) onVideoExtraData(codecType box.MP4_CODEC_TYPE, data []byte) error { + w.videoCodec = codecType + // 解析并存储SPS/PPS数据 + if codecType == box.MP4_CODEC_H264 && len(data) > 0 { + if w.plugin != nil { + w.plugin.Debug("Processing H264 extra data", "size", len(data)) + } + + // 解析AVCC格式的extra data + if len(data) >= 8 { + // AVCC格式: configurationVersion(1) + AVCProfileIndication(1) + profile_compatibility(1) + AVCLevelIndication(1) + + // lengthSizeMinusOne(1) + numOfSequenceParameterSets(1) + ... + + offset := 5 // 跳过前5个字节 + if offset < len(data) { + // 读取SPS数量 + numSPS := data[offset] & 0x1f + offset++ + + // 解析SPS + for i := 0; i < int(numSPS) && offset < len(data)-1; i++ { + if offset+1 >= len(data) { + break + } + spsLength := int(data[offset])<<8 | int(data[offset+1]) + offset += 2 + + if offset+spsLength <= len(data) { + // 添加起始码并存储SPS + w.spsData = make([]byte, 4+spsLength) + copy(w.spsData[0:4], []byte{0x00, 0x00, 0x00, 0x01}) + copy(w.spsData[4:], data[offset:offset+spsLength]) + offset += spsLength + + if w.plugin != nil { + w.plugin.Debug("Extracted SPS", "length", spsLength) + } + break // 只取第一个SPS + } + } + + // 读取PPS数量 + if offset < len(data) { + numPPS := data[offset] + offset++ + + // 解析PPS + for i := 0; i < int(numPPS) && offset < len(data)-1; i++ { + if offset+1 >= len(data) { + break + } + ppsLength := int(data[offset])<<8 | int(data[offset+1]) + offset += 2 + + if offset+ppsLength <= len(data) { + // 添加起始码并存储PPS + w.ppsData = make([]byte, 4+ppsLength) + copy(w.ppsData[0:4], []byte{0x00, 0x00, 0x00, 0x01}) + copy(w.ppsData[4:], data[offset:offset+ppsLength]) + + if w.plugin != nil { + w.plugin.Debug("Extracted PPS", "length", ppsLength) + } + break // 只取第一个PPS + } + } + } + } + } + } + + return nil +} + +// onAudioExtraData 处理音频序列头 +func (w *simpleTsWriter) onAudioExtraData(codecType box.MP4_CODEC_TYPE, data []byte) error { + w.audioCodec = codecType + w.plugin.Debug("Processing audio extra data", "codec", codecType, "size", len(data)) + return nil +} + +// onVideoSample 处理视频样本 +func (w *simpleTsWriter) onVideoSample(codecType box.MP4_CODEC_TYPE, sample box.Sample) error { + if !w.hasWritten { + w.WritePMT() + } + + w.plugin.Debug("Processing video sample", "size", len(sample.Data), "keyFrame", sample.KeyFrame, "timestamp", sample.Timestamp) + + // 转换AVCC格式到Annex-B格式 + annexBData, err := w.convertAVCCToAnnexB(sample.Data, sample.KeyFrame) + if err != nil { + w.plugin.Error("Failed to 
convert AVCC to Annex-B", "error", err) + return err + } + + if len(annexBData) == 0 { + w.plugin.Warn("Empty Annex-B data after conversion") + return nil + } + + // 创建视频帧结构 + videoFrame := mpegts.MpegtsPESFrame{ + Pid: mpegts.PID_VIDEO, + IsKeyFrame: sample.KeyFrame, + } + + // 创建 AnnexB 帧 + annexBFrame := &pkg.AnnexB{ + PTS: (time.Duration(sample.Timestamp) + time.Duration(sample.CTS)) * 90, + DTS: time.Duration(sample.Timestamp) * 90, // 对于MP4转换,假设PTS=DTS + } + + // 根据编解码器类型设置 Hevc 标志 + if codecType == box.MP4_CODEC_H265 { + annexBFrame.Hevc = true + } + + annexBFrame.AppendOne(annexBData) + + // 使用 WriteVideoFrame 写入TS包 + err = w.WriteVideoFrame(annexBFrame, &videoFrame) + if err != nil { + w.plugin.Error("Failed to write video frame", "error", err) + return err + } + + return nil +} + +// convertAVCCToAnnexB 将AVCC格式转换为Annex-B格式 +func (w *simpleTsWriter) convertAVCCToAnnexB(avccData []byte, isKeyFrame bool) ([]byte, error) { + if len(avccData) == 0 { + return nil, fmt.Errorf("empty AVCC data") + } + + var annexBBuffer []byte + + // 如果是关键帧,先添加SPS和PPS + if isKeyFrame { + if len(w.spsData) > 0 { + annexBBuffer = append(annexBBuffer, w.spsData...) + w.plugin.Debug("Added SPS to key frame", "spsSize", len(w.spsData)) + } + if len(w.ppsData) > 0 { + annexBBuffer = append(annexBBuffer, w.ppsData...) + w.plugin.Debug("Added PPS to key frame", "ppsSize", len(w.ppsData)) + } + } + + // 解析AVCC格式的NAL单元 + offset := 0 + nalCount := 0 + + for offset < len(avccData) { + // AVCC格式:4字节长度 + NAL数据 + if offset+4 > len(avccData) { + break + } + + // 读取NAL单元长度(大端序) + nalLength := int(avccData[offset])<<24 | + int(avccData[offset+1])<<16 | + int(avccData[offset+2])<<8 | + int(avccData[offset+3]) + offset += 4 + + if nalLength <= 0 || offset+nalLength > len(avccData) { + w.plugin.Warn("Invalid NAL length", "length", nalLength, "remaining", len(avccData)-offset) + break + } + + nalData := avccData[offset : offset+nalLength] + offset += nalLength + nalCount++ + + if len(nalData) > 0 { + nalType := nalData[0] & 0x1f + w.plugin.Debug("Converting NAL unit", "type", nalType, "length", nalLength) + + // 添加起始码前缀 + annexBBuffer = append(annexBBuffer, []byte{0x00, 0x00, 0x00, 0x01}...) + annexBBuffer = append(annexBBuffer, nalData...) 
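+			// The payload is now start-code prefixed; the loop continues with the next
+			// length-prefixed NAL unit in the AVCC buffer.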
+ } + } + + if nalCount == 0 { + return nil, fmt.Errorf("no NAL units found in AVCC data") + } + + w.plugin.Debug("AVCC to Annex-B conversion completed", + "inputSize", len(avccData), + "outputSize", len(annexBBuffer), + "nalUnits", nalCount) + + return annexBBuffer, nil +} + +// onAudioSample 处理音频样本 +func (w *simpleTsWriter) onAudioSample(codecType box.MP4_CODEC_TYPE, sample box.Sample) error { + if !w.hasWritten { + w.WritePMT() + } + + w.plugin.Debug("Processing audio sample", "codec", codecType, "size", len(sample.Data), "timestamp", sample.Timestamp) + + // 创建音频帧结构 + audioFrame := mpegts.MpegtsPESFrame{ + Pid: mpegts.PID_AUDIO, + } + + // 根据编解码器类型处理音频数据 + switch codecType { + case box.MP4_CODEC_AAC: // AAC + // 创建 ADTS 帧 + adtsFrame := &pkg.ADTS{ + DTS: time.Duration(sample.Timestamp) * 90, + } + + // 将音频数据添加到帧中 + copy(adtsFrame.NextN(len(sample.Data)), sample.Data) + + // 使用 WriteAudioFrame 写入TS包 + err := w.WriteAudioFrame(adtsFrame, &audioFrame) + if err != nil { + w.plugin.Error("Failed to write audio frame", "error", err) + return err + } + default: + // 对于非AAC音频,暂时使用原来的PES包方式 + pesPacket := mpegts.MpegTsPESPacket{ + Header: mpegts.MpegTsPESHeader{ + PacketStartCodePrefix: 0x000001, + StreamID: mpegts.STREAM_ID_AUDIO, + }, + } + // 设置可选字段 + pesPacket.Header.ConstTen = 0x80 + pesPacket.Header.PtsDtsFlags = 0x80 // 只有PTS + pesPacket.Header.PesHeaderDataLength = 5 + pesPacket.Header.Pts = uint64(sample.Timestamp) + + pesPacket.Buffers = append(pesPacket.Buffers, sample.Data) + + // 写入TS包 + err := w.WritePESPacket(&audioFrame, pesPacket) + if err != nil { + w.plugin.Error("Failed to write audio PES packet", "error", err) + return err + } + } + + return nil +} + +// processTsFiles 处理原生TS文件拼接 +func (plugin *HLSPlugin) processTsFiles(w http.ResponseWriter, r *http.Request, fileInfoList []*fileInfo, params *requestParams) { + plugin.Info("Processing TS files", "count", len(fileInfoList)) + + // 设置HTTP响应头 + w.Header().Set("Content-Type", "video/mp2t") + w.Header().Set("Content-Disposition", "attachment") + + var writer io.Writer = w + var totalSize uint64 + + // 第一次遍历:计算总大小 + for _, info := range fileInfoList { + if r.Context().Err() != nil { + return + } + + fileInfo, err := os.Stat(info.filePath) + if err != nil { + plugin.Error("Failed to stat file", "path", info.filePath, "err", err) + continue + } + totalSize += uint64(fileInfo.Size()) + } + + // 设置内容长度 + w.Header().Set("Content-Length", strconv.FormatUint(totalSize, 10)) + w.WriteHeader(http.StatusOK) + + // 第二次遍历:写入数据 + for i, info := range fileInfoList { + if r.Context().Err() != nil { + return + } + + plugin.Debug("Processing TS file", "path", info.filePath) + file, err := os.Open(info.filePath) + if err != nil { + plugin.Error("Failed to open file", "path", info.filePath, "err", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + reader := bufio.NewReader(file) + + if i == 0 { + // 第一个文件,直接拷贝 + _, err = io.Copy(writer, reader) + } else { + // 后续文件,跳过PAT/PMT包,只拷贝媒体数据 + err = plugin.copyTsFileSkipHeaders(writer, reader) + } + + file.Close() + + if err != nil { + plugin.Error("Failed to copy file", "path", info.filePath, "err", err) + return + } + } + + plugin.Info("TS download completed") +} + +// copyTsFileSkipHeaders 拷贝TS文件,跳过PAT/PMT包 +func (plugin *HLSPlugin) copyTsFileSkipHeaders(writer io.Writer, reader *bufio.Reader) error { + buffer := make([]byte, mpegts.TS_PACKET_SIZE) + + for { + n, err := io.ReadFull(reader, buffer) + if err != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + 
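+				// Clean end of file (or a truncated trailing packet): stop copying this
+				// segment without treating it as an error.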
break + } + return err + } + + if n != mpegts.TS_PACKET_SIZE { + continue + } + + // 检查同步字节 + if buffer[0] != 0x47 { + continue + } + + // 提取PID + pid := uint16(buffer[1]&0x1f)<<8 | uint16(buffer[2]) + + // 跳过PAT(PID=0)和PMT(PID=256)包 + if pid == mpegts.PID_PAT || pid == mpegts.PID_PMT { + continue + } + + // 写入媒体数据包 + _, err = writer.Write(buffer) + if err != nil { + return err + } + } + + return nil +} + +// download 下载处理函数 +func (plugin *HLSPlugin) download(w http.ResponseWriter, r *http.Request) { + // 解析请求参数 + params, err := plugin.parseRequestParams(r) + if err != nil { + plugin.Error("Failed to parse request params", "err", err) + http.Error(w, "Invalid parameters", http.StatusBadRequest) + return + } + + plugin.Info("TS download request", "streamPath", params.streamPath, "timeRange", params.timeRange) + + // 查询录像记录 + recordStreams, err := plugin.queryRecordStreams(params) + if err != nil { + plugin.Error("Failed to query record streams", "err", err) + http.Error(w, "Database error", http.StatusInternalServerError) + return + } + + if len(recordStreams) == 0 { + plugin.Warn("No records found", "streamPath", params.streamPath) + http.Error(w, "No records found", http.StatusNotFound) + return + } + + // 构建文件信息列表 + fileInfoList, found := plugin.buildFileInfoList(recordStreams, params.startTime, params.endTime) + if !found { + plugin.Warn("No valid files found", "streamPath", params.streamPath) + http.Error(w, "No valid files found", http.StatusNotFound) + return + } + + // 检查文件类型并处理 + if plugin.hasOnlyMp4Records(fileInfoList) { + // 只有MP4记录,转换为TS + mp4Files := plugin.filterMp4Files(fileInfoList) + plugin.processMp4ToTs(w, r, mp4Files, params) + } else { + // 有TS记录,优先使用TS文件 + tsFiles := plugin.filterTsFiles(fileInfoList) + if len(tsFiles) > 0 { + plugin.processTsFiles(w, r, tsFiles, params) + } else { + // 没有TS文件,使用MP4转换 + mp4Files := plugin.filterMp4Files(fileInfoList) + plugin.processMp4ToTs(w, r, mp4Files, params) + } + } +} diff --git a/plugin/hls/index.go b/plugin/hls/index.go index 6e75bca..92e8ac6 100644 --- a/plugin/hls/index.go +++ b/plugin/hls/index.go @@ -59,6 +59,7 @@ func (p *HLSPlugin) OnInit() (err error) { func (p *HLSPlugin) RegisterHandler() map[string]http.HandlerFunc { return map[string]http.HandlerFunc{ "/vod/{streamPath...}": p.vod, + "/download/{streamPath...}": p.download, "/api/record/start/{streamPath...}": p.API_record_start, "/api/record/stop/{id}": p.API_record_stop, } diff --git a/plugin/mp4/pkg/box/trak.go b/plugin/mp4/pkg/box/trak.go index 7bb416b..b451caf 100644 --- a/plugin/mp4/pkg/box/trak.go +++ b/plugin/mp4/pkg/box/trak.go @@ -54,8 +54,16 @@ func (t *TrakBox) Unmarshal(buf []byte) (b IBox, err error) { return t, err } +// SampleCallback 定义样本处理回调函数类型 +type SampleCallback func(sample *Sample, sampleIndex int) error + // ParseSamples parses the sample table and builds the sample list func (t *TrakBox) ParseSamples() (samplelist []Sample) { + return t.ParseSamplesWithCallback(nil) +} + +// ParseSamplesWithCallback parses the sample table and builds the sample list with optional callback +func (t *TrakBox) ParseSamplesWithCallback(callback SampleCallback) (samplelist []Sample) { stbl := t.MDIA.MINF.STBL var chunkOffsets []uint64 if stbl.STCO != nil { @@ -150,6 +158,17 @@ func (t *TrakBox) ParseSamples() (samplelist []Sample) { } } + // 调用回调函数处理每个样本 + if callback != nil { + for i := range samplelist { + if err := callback(&samplelist[i], i); err != nil { + // 如果回调返回错误,可以选择记录或处理,但不中断解析 + // 这里为了保持向后兼容性,我们继续处理 + continue + } + } + } + return samplelist } diff 
--git a/plugin/mp4/pkg/demuxer.go b/plugin/mp4/pkg/demuxer.go index 442acc4..65de8d9 100644 --- a/plugin/mp4/pkg/demuxer.go +++ b/plugin/mp4/pkg/demuxer.go @@ -6,8 +6,10 @@ import ( "slices" "m7s.live/v5/pkg" + "m7s.live/v5/pkg/codec" + "m7s.live/v5/pkg/util" "m7s.live/v5/plugin/mp4/pkg/box" - . "m7s.live/v5/plugin/mp4/pkg/box" + rtmp "m7s.live/v5/plugin/rtmp/pkg" ) type ( @@ -30,7 +32,7 @@ type ( Number uint32 CryptByteBlock uint8 SkipByteBlock uint8 - PsshBoxes []*PsshBox + PsshBoxes []*box.PsshBox } SubSamplePattern struct { BytesClear uint16 @@ -43,16 +45,28 @@ type ( chunkoffset uint64 } + RTMPFrame struct { + Frame any // 可以是 *rtmp.RTMPVideo 或 *rtmp.RTMPAudio + } + Demuxer struct { reader io.ReadSeeker Tracks []*Track ReadSampleIdx []uint32 IsFragment bool - // pssh []*PsshBox - moov *MoovBox - mdat *MediaDataBox + // pssh []*box.PsshBox + moov *box.MoovBox + mdat *box.MediaDataBox mdatOffset uint64 QuicTime bool + + // 预生成的 RTMP 帧 + RTMPVideoSequence *rtmp.RTMPVideo + RTMPAudioSequence *rtmp.RTMPAudio + RTMPFrames []RTMPFrame + + // RTMP 帧生成配置 + RTMPAllocator *util.ScalableMemoryAllocator } ) @@ -63,6 +77,10 @@ func NewDemuxer(r io.ReadSeeker) *Demuxer { } func (d *Demuxer) Demux() (err error) { + return d.DemuxWithAllocator(nil) +} + +func (d *Demuxer) DemuxWithAllocator(allocator *util.ScalableMemoryAllocator) (err error) { // decodeVisualSampleEntry := func() (offset int, err error) { // var encv VisualSampleEntry @@ -96,7 +114,7 @@ func (d *Demuxer) Demux() (err error) { // } // return // } - var b IBox + var b box.IBox var offset uint64 for { b, err = box.ReadFrom(d.reader) @@ -107,53 +125,59 @@ func (d *Demuxer) Demux() (err error) { return err } offset += b.Size() - switch box := b.(type) { - case *FileTypeBox: - if slices.Contains(box.CompatibleBrands, [4]byte{'q', 't', ' ', ' '}) { + switch boxData := b.(type) { + case *box.FileTypeBox: + if slices.Contains(boxData.CompatibleBrands, [4]byte{'q', 't', ' ', ' '}) { d.QuicTime = true } - case *FreeBox: - case *MediaDataBox: - d.mdat = box - d.mdatOffset = offset - b.Size() + uint64(box.HeaderSize()) - case *MoovBox: - if box.MVEX != nil { + case *box.FreeBox: + case *box.MediaDataBox: + d.mdat = boxData + d.mdatOffset = offset - b.Size() + uint64(boxData.HeaderSize()) + case *box.MoovBox: + if boxData.MVEX != nil { d.IsFragment = true } - for _, trak := range box.Tracks { + for _, trak := range boxData.Tracks { track := &Track{} track.TrackId = trak.TKHD.TrackID track.Duration = uint32(trak.TKHD.Duration) track.Timescale = trak.MDIA.MDHD.Timescale - track.Samplelist = trak.ParseSamples() + // 创建RTMP样本处理回调 + var sampleCallback box.SampleCallback + if d.RTMPAllocator != nil { + sampleCallback = d.createRTMPSampleCallback(track, trak) + } + + track.Samplelist = trak.ParseSamplesWithCallback(sampleCallback) if len(trak.MDIA.MINF.STBL.STSD.Entries) > 0 { entryBox := trak.MDIA.MINF.STBL.STSD.Entries[0] switch entry := entryBox.(type) { - case *AudioSampleEntry: + case *box.AudioSampleEntry: switch entry.Type() { - case TypeMP4A: - track.Cid = MP4_CODEC_AAC - case TypeALAW: - track.Cid = MP4_CODEC_G711A - case TypeULAW: - track.Cid = MP4_CODEC_G711U - case TypeOPUS: - track.Cid = MP4_CODEC_OPUS + case box.TypeMP4A: + track.Cid = box.MP4_CODEC_AAC + case box.TypeALAW: + track.Cid = box.MP4_CODEC_G711A + case box.TypeULAW: + track.Cid = box.MP4_CODEC_G711U + case box.TypeOPUS: + track.Cid = box.MP4_CODEC_OPUS } track.SampleRate = entry.Samplerate track.ChannelCount = uint8(entry.ChannelCount) track.SampleSize = entry.SampleSize switch 
extra := entry.ExtraData.(type) { - case *ESDSBox: - track.Cid, track.ExtraData = DecodeESDescriptor(extra.Data) + case *box.ESDSBox: + track.Cid, track.ExtraData = box.DecodeESDescriptor(extra.Data) } - case *VisualSampleEntry: - track.ExtraData = entry.ExtraData.(*DataBox).Data + case *box.VisualSampleEntry: + track.ExtraData = entry.ExtraData.(*box.DataBox).Data switch entry.Type() { - case TypeAVC1: - track.Cid = MP4_CODEC_H264 - case TypeHVC1, TypeHEV1: - track.Cid = MP4_CODEC_H265 + case box.TypeAVC1: + track.Cid = box.MP4_CODEC_H264 + case box.TypeHVC1, box.TypeHEV1: + track.Cid = box.MP4_CODEC_H265 } track.Width = uint32(entry.Width) track.Height = uint32(entry.Height) @@ -161,9 +185,9 @@ func (d *Demuxer) Demux() (err error) { } d.Tracks = append(d.Tracks, track) } - d.moov = box - case *MovieFragmentBox: - for _, traf := range box.TRAFs { + d.moov = boxData + case *box.MovieFragmentBox: + for _, traf := range boxData.TRAFs { track := d.Tracks[traf.TFHD.TrackID-1] track.defaultSize = traf.TFHD.DefaultSampleSize track.defaultDuration = traf.TFHD.DefaultSampleDuration @@ -171,6 +195,7 @@ func (d *Demuxer) Demux() (err error) { } } d.ReadSampleIdx = make([]uint32, len(d.Tracks)) + // for _, track := range d.Tracks { // if len(track.Samplelist) > 0 { // track.StartDts = uint64(track.Samplelist[0].DTS) * 1000 / uint64(track.Timescale) @@ -180,7 +205,7 @@ func (d *Demuxer) Demux() (err error) { return nil } -func (d *Demuxer) SeekTime(dts uint64) (sample *Sample, err error) { +func (d *Demuxer) SeekTime(dts uint64) (sample *box.Sample, err error) { var audioTrack, videoTrack *Track for _, track := range d.Tracks { if track.Cid.IsAudio() { @@ -425,10 +450,10 @@ func (d *Demuxer) SeekTimePreIDR(dts uint64) (sample *Sample, err error) { // return nil // } -func (d *Demuxer) ReadSample(yield func(*Track, Sample) bool) { +func (d *Demuxer) ReadSample(yield func(*Track, box.Sample) bool) { for { maxdts := int64(-1) - minTsSample := Sample{Timestamp: uint32(maxdts)} + minTsSample := box.Sample{Timestamp: uint32(maxdts)} var whichTrack *Track whichTracki := 0 for i, track := range d.Tracks { @@ -462,9 +487,9 @@ func (d *Demuxer) ReadSample(yield func(*Track, Sample) bool) { } } -func (d *Demuxer) RangeSample(yield func(*Track, *Sample) bool) { +func (d *Demuxer) RangeSample(yield func(*Track, *box.Sample) bool) { for { - var minTsSample *Sample + var minTsSample *box.Sample var whichTrack *Track whichTracki := 0 for i, track := range d.Tracks { @@ -496,6 +521,244 @@ func (d *Demuxer) RangeSample(yield func(*Track, *Sample) bool) { } // GetMoovBox returns the Movie Box from the demuxer -func (d *Demuxer) GetMoovBox() *MoovBox { +func (d *Demuxer) GetMoovBox() *box.MoovBox { return d.moov } + +// CreateRTMPSequenceFrame 创建 RTMP 序列帧 +func (d *Demuxer) CreateRTMPSequenceFrame(track *Track, allocator *util.ScalableMemoryAllocator) (videoSeq *rtmp.RTMPVideo, audioSeq *rtmp.RTMPAudio, err error) { + switch track.Cid { + case box.MP4_CODEC_H264: + videoSeq = &rtmp.RTMPVideo{} + videoSeq.SetAllocator(allocator) + videoSeq.Append([]byte{0x17, 0x00, 0x00, 0x00, 0x00}, track.ExtraData) + case box.MP4_CODEC_H265: + videoSeq = &rtmp.RTMPVideo{} + videoSeq.SetAllocator(allocator) + videoSeq.Append([]byte{0b1001_0000 | rtmp.PacketTypeSequenceStart}, codec.FourCC_H265[:], track.ExtraData) + case box.MP4_CODEC_AAC: + audioSeq = &rtmp.RTMPAudio{} + audioSeq.SetAllocator(allocator) + audioSeq.Append([]byte{0xaf, 0x00}, track.ExtraData) + } + return +} + +// ConvertSampleToRTMP 将 MP4 sample 转换为 RTMP 格式 +func (d 
*Demuxer) ConvertSampleToRTMP(track *Track, sample box.Sample, allocator *util.ScalableMemoryAllocator, timestampOffset uint64) (videoFrame *rtmp.RTMPVideo, audioFrame *rtmp.RTMPAudio, err error) { + switch track.Cid { + case box.MP4_CODEC_H264: + videoFrame = &rtmp.RTMPVideo{} + videoFrame.SetAllocator(allocator) + videoFrame.CTS = sample.CTS + videoFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) + videoFrame.AppendOne([]byte{util.Conditional[byte](sample.KeyFrame, 0x17, 0x27), 0x01, byte(videoFrame.CTS >> 24), byte(videoFrame.CTS >> 8), byte(videoFrame.CTS)}) + videoFrame.AddRecycleBytes(sample.Data) + case box.MP4_CODEC_H265: + videoFrame = &rtmp.RTMPVideo{} + videoFrame.SetAllocator(allocator) + videoFrame.CTS = uint32(sample.CTS) + videoFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) + var head []byte + var b0 byte = 0b1010_0000 + if sample.KeyFrame { + b0 = 0b1001_0000 + } + if videoFrame.CTS == 0 { + head = videoFrame.NextN(5) + head[0] = b0 | rtmp.PacketTypeCodedFramesX + } else { + head = videoFrame.NextN(8) + head[0] = b0 | rtmp.PacketTypeCodedFrames + util.PutBE(head[5:8], videoFrame.CTS) // cts + } + copy(head[1:], codec.FourCC_H265[:]) + videoFrame.AddRecycleBytes(sample.Data) + case box.MP4_CODEC_AAC: + audioFrame = &rtmp.RTMPAudio{} + audioFrame.SetAllocator(allocator) + audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) + audioFrame.AppendOne([]byte{0xaf, 0x01}) + audioFrame.AddRecycleBytes(sample.Data) + case box.MP4_CODEC_G711A: + audioFrame = &rtmp.RTMPAudio{} + audioFrame.SetAllocator(allocator) + audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) + audioFrame.AppendOne([]byte{0x72}) + audioFrame.AddRecycleBytes(sample.Data) + case box.MP4_CODEC_G711U: + audioFrame = &rtmp.RTMPAudio{} + audioFrame.SetAllocator(allocator) + audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) + audioFrame.AppendOne([]byte{0x82}) + audioFrame.AddRecycleBytes(sample.Data) + } + return +} + +// GetRTMPSequenceFrames 获取预生成的 RTMP 序列帧 +func (d *Demuxer) GetRTMPSequenceFrames() (videoSeq *rtmp.RTMPVideo, audioSeq *rtmp.RTMPAudio) { + return d.RTMPVideoSequence, d.RTMPAudioSequence +} + +// IterateRTMPFrames 迭代预生成的 RTMP 帧 +func (d *Demuxer) IterateRTMPFrames(timestampOffset uint64, yield func(*RTMPFrame) bool) { + for i := range d.RTMPFrames { + frame := &d.RTMPFrames[i] + + // 应用时间戳偏移 + switch f := frame.Frame.(type) { + case *rtmp.RTMPVideo: + f.Timestamp += uint32(timestampOffset) + case *rtmp.RTMPAudio: + f.Timestamp += uint32(timestampOffset) + } + + if !yield(frame) { + return + } + } +} + +// GetMaxTimestamp 获取所有帧中的最大时间戳 +func (d *Demuxer) GetMaxTimestamp() uint64 { + var maxTimestamp uint64 + for _, frame := range d.RTMPFrames { + var timestamp uint64 + switch f := frame.Frame.(type) { + case *rtmp.RTMPVideo: + timestamp = uint64(f.Timestamp) + case *rtmp.RTMPAudio: + timestamp = uint64(f.Timestamp) + } + if timestamp > maxTimestamp { + maxTimestamp = timestamp + } + } + return maxTimestamp +} + +// generateRTMPFrames 生成RTMP序列帧和所有帧数据 +func (d *Demuxer) generateRTMPFrames(allocator *util.ScalableMemoryAllocator) (err error) { + // 生成序列帧 + for _, track := range d.Tracks { + if track.Cid.IsVideo() && d.RTMPVideoSequence == nil { + d.RTMPVideoSequence, _, err = d.CreateRTMPSequenceFrame(track, allocator) + if err != nil { + return err + } + } 
else if track.Cid.IsAudio() && d.RTMPAudioSequence == nil { + _, d.RTMPAudioSequence, err = d.CreateRTMPSequenceFrame(track, allocator) + if err != nil { + return err + } + } + } + + // 预生成所有 RTMP 帧 + d.RTMPFrames = make([]RTMPFrame, 0) + + // 收集所有样本并按时间戳排序 + type sampleInfo struct { + track *Track + sample box.Sample + sampleIndex uint32 + trackIndex int + } + + var allSamples []sampleInfo + for trackIdx, track := range d.Tracks { + for sampleIdx, sample := range track.Samplelist { + // 读取样本数据 + if _, err = d.reader.Seek(sample.Offset, io.SeekStart); err != nil { + return err + } + sample.Data = allocator.Malloc(sample.Size) + if _, err = io.ReadFull(d.reader, sample.Data); err != nil { + allocator.Free(sample.Data) + return err + } + + allSamples = append(allSamples, sampleInfo{ + track: track, + sample: sample, + sampleIndex: uint32(sampleIdx), + trackIndex: trackIdx, + }) + } + } + + // 按时间戳排序样本 + slices.SortFunc(allSamples, func(a, b sampleInfo) int { + timeA := uint64(a.sample.Timestamp) * uint64(d.moov.MVHD.Timescale) / uint64(a.track.Timescale) + timeB := uint64(b.sample.Timestamp) * uint64(d.moov.MVHD.Timescale) / uint64(b.track.Timescale) + if timeA < timeB { + return -1 + } else if timeA > timeB { + return 1 + } + return 0 + }) + + // 预生成 RTMP 帧 + for _, sampleInfo := range allSamples { + videoFrame, audioFrame, err := d.ConvertSampleToRTMP(sampleInfo.track, sampleInfo.sample, allocator, 0) + if err != nil { + return err + } + + if videoFrame != nil { + d.RTMPFrames = append(d.RTMPFrames, RTMPFrame{Frame: videoFrame}) + } + + if audioFrame != nil { + d.RTMPFrames = append(d.RTMPFrames, RTMPFrame{Frame: audioFrame}) + } + } + + return nil +} + +// createRTMPSampleCallback 创建RTMP样本处理回调函数 +func (d *Demuxer) createRTMPSampleCallback(track *Track, trak *box.TrakBox) box.SampleCallback { + // 首先生成序列帧 + if track.Cid.IsVideo() && d.RTMPVideoSequence == nil { + videoSeq, _, err := d.CreateRTMPSequenceFrame(track, d.RTMPAllocator) + if err == nil { + d.RTMPVideoSequence = videoSeq + } + } else if track.Cid.IsAudio() && d.RTMPAudioSequence == nil { + _, audioSeq, err := d.CreateRTMPSequenceFrame(track, d.RTMPAllocator) + if err == nil { + d.RTMPAudioSequence = audioSeq + } + } + + return func(sample *box.Sample, sampleIndex int) error { + // 读取样本数据 + if _, err := d.reader.Seek(sample.Offset, io.SeekStart); err != nil { + return err + } + sample.Data = d.RTMPAllocator.Malloc(sample.Size) + if _, err := io.ReadFull(d.reader, sample.Data); err != nil { + d.RTMPAllocator.Free(sample.Data) + return err + } + + // 转换为 RTMP 格式 + videoFrame, audioFrame, err := d.ConvertSampleToRTMP(track, *sample, d.RTMPAllocator, 0) + if err != nil { + return err + } + + // 内部收集RTMP帧 + if videoFrame != nil { + d.RTMPFrames = append(d.RTMPFrames, RTMPFrame{Frame: videoFrame}) + } + if audioFrame != nil { + d.RTMPFrames = append(d.RTMPFrames, RTMPFrame{Frame: audioFrame}) + } + + return nil + } +} diff --git a/plugin/mp4/pkg/pull-httpfile.go b/plugin/mp4/pkg/pull-httpfile.go index 6f28a97..febd602 100644 --- a/plugin/mp4/pkg/pull-httpfile.go +++ b/plugin/mp4/pkg/pull-httpfile.go @@ -3,13 +3,12 @@ package mp4 import ( "errors" "io" + "slices" "strings" "time" m7s "m7s.live/v5" - "m7s.live/v5/pkg/codec" "m7s.live/v5/pkg/util" - "m7s.live/v5/plugin/mp4/pkg/box" rtmp "m7s.live/v5/plugin/rtmp/pkg" ) @@ -35,9 +34,40 @@ func (p *HTTPReader) Run() (err error) { content, err = io.ReadAll(p.ReadCloser) demuxer = NewDemuxer(strings.NewReader(string(content))) } - if err = demuxer.Demux(); err != nil { + + // 
设置RTMP分配器以启用RTMP帧收集 + demuxer.RTMPAllocator = allocator + + if err = demuxer.DemuxWithAllocator(allocator); err != nil { return } + + // 获取demuxer内部收集的RTMP帧 + rtmpFrames := demuxer.RTMPFrames + + // 按时间戳排序所有帧 + slices.SortFunc(rtmpFrames, func(a, b RTMPFrame) int { + var timeA, timeB uint64 + switch f := a.Frame.(type) { + case *rtmp.RTMPVideo: + timeA = uint64(f.Timestamp) + case *rtmp.RTMPAudio: + timeA = uint64(f.Timestamp) + } + switch f := b.Frame.(type) { + case *rtmp.RTMPVideo: + timeB = uint64(f.Timestamp) + case *rtmp.RTMPAudio: + timeB = uint64(f.Timestamp) + } + if timeA < timeB { + return -1 + } else if timeA > timeB { + return 1 + } + return 0 + }) + publisher.OnSeek = func(seekTime time.Time) { p.Stop(errors.New("seek")) pullJob.Connection.Args.Set(util.StartKey, seekTime.Local().Format(util.LocalTimeFormat)) @@ -48,103 +78,61 @@ func (p *HTTPReader) Run() (err error) { seekTime, _ := time.Parse(util.LocalTimeFormat, pullJob.Connection.Args.Get(util.StartKey)) demuxer.SeekTime(uint64(seekTime.UnixMilli())) } - for _, track := range demuxer.Tracks { - switch track.Cid { - case box.MP4_CODEC_H264: - var sequence rtmp.RTMPVideo - sequence.SetAllocator(allocator) - sequence.Append([]byte{0x17, 0x00, 0x00, 0x00, 0x00}, track.ExtraData) - err = publisher.WriteVideo(&sequence) - case box.MP4_CODEC_H265: - var sequence rtmp.RTMPVideo - sequence.SetAllocator(allocator) - sequence.Append([]byte{0b1001_0000 | rtmp.PacketTypeSequenceStart}, codec.FourCC_H265[:], track.ExtraData) - err = publisher.WriteVideo(&sequence) - case box.MP4_CODEC_AAC: - var sequence rtmp.RTMPAudio - sequence.SetAllocator(allocator) - sequence.Append([]byte{0xaf, 0x00}, track.ExtraData) - err = publisher.WriteAudio(&sequence) + + // 读取预生成的 RTMP 序列帧 + videoSeq, audioSeq := demuxer.GetRTMPSequenceFrames() + if videoSeq != nil { + err = publisher.WriteVideo(videoSeq) + if err != nil { + return err + } + } + if audioSeq != nil { + err = publisher.WriteAudio(audioSeq) + if err != nil { + return err } } // 计算最大时间戳用于累计偏移 var maxTimestamp uint64 - for track, sample := range demuxer.ReadSample { - timestamp := uint64(sample.Timestamp) * 1000 / uint64(track.Timescale) + for _, frame := range rtmpFrames { + var timestamp uint64 + switch f := frame.Frame.(type) { + case *rtmp.RTMPVideo: + timestamp = uint64(f.Timestamp) + case *rtmp.RTMPAudio: + timestamp = uint64(f.Timestamp) + } if timestamp > maxTimestamp { maxTimestamp = timestamp } } + var timestampOffset uint64 loop := p.PullJob.Loop for { - demuxer.ReadSampleIdx = make([]uint32, len(demuxer.Tracks)) - for track, sample := range demuxer.ReadSample { + // 使用预生成的 RTMP 帧进行播放 + for _, frame := range rtmpFrames { if p.IsStopped() { - return + return nil } - if _, err = demuxer.reader.Seek(sample.Offset, io.SeekStart); err != nil { - return + + // 应用时间戳偏移 + switch f := frame.Frame.(type) { + case *rtmp.RTMPVideo: + f.Timestamp += uint32(timestampOffset) + err = publisher.WriteVideo(f) + case *rtmp.RTMPAudio: + f.Timestamp += uint32(timestampOffset) + err = publisher.WriteAudio(f) } - sample.Data = allocator.Malloc(sample.Size) - if _, err = io.ReadFull(demuxer.reader, sample.Data); err != nil { - allocator.Free(sample.Data) - return - } - switch track.Cid { - case box.MP4_CODEC_H264: - var videoFrame rtmp.RTMPVideo - videoFrame.SetAllocator(allocator) - videoFrame.CTS = sample.CTS - videoFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) - videoFrame.AppendOne([]byte{util.Conditional[byte](sample.KeyFrame, 0x17, 0x27), 0x01, 
byte(videoFrame.CTS >> 24), byte(videoFrame.CTS >> 8), byte(videoFrame.CTS)}) - videoFrame.AddRecycleBytes(sample.Data) - err = publisher.WriteVideo(&videoFrame) - case box.MP4_CODEC_H265: - var videoFrame rtmp.RTMPVideo - videoFrame.SetAllocator(allocator) - videoFrame.CTS = uint32(sample.CTS) - videoFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) - var head []byte - var b0 byte = 0b1010_0000 - if sample.KeyFrame { - b0 = 0b1001_0000 - } - if videoFrame.CTS == 0 { - head = videoFrame.NextN(5) - head[0] = b0 | rtmp.PacketTypeCodedFramesX - } else { - head = videoFrame.NextN(8) - head[0] = b0 | rtmp.PacketTypeCodedFrames - util.PutBE(head[5:8], videoFrame.CTS) // cts - } - copy(head[1:], codec.FourCC_H265[:]) - videoFrame.AddRecycleBytes(sample.Data) - err = publisher.WriteVideo(&videoFrame) - case box.MP4_CODEC_AAC: - var audioFrame rtmp.RTMPAudio - audioFrame.SetAllocator(allocator) - audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) - audioFrame.AppendOne([]byte{0xaf, 0x01}) - audioFrame.AddRecycleBytes(sample.Data) - err = publisher.WriteAudio(&audioFrame) - case box.MP4_CODEC_G711A: - var audioFrame rtmp.RTMPAudio - audioFrame.SetAllocator(allocator) - audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) - audioFrame.AppendOne([]byte{0x72}) - audioFrame.AddRecycleBytes(sample.Data) - err = publisher.WriteAudio(&audioFrame) - case box.MP4_CODEC_G711U: - var audioFrame rtmp.RTMPAudio - audioFrame.SetAllocator(allocator) - audioFrame.Timestamp = uint32(uint64(sample.Timestamp)*1000/uint64(track.Timescale) + timestampOffset) - audioFrame.AppendOne([]byte{0x82}) - audioFrame.AddRecycleBytes(sample.Data) - err = publisher.WriteAudio(&audioFrame) + + if err != nil { + return err } } + if loop >= 0 { loop-- if loop == -1 {