diff --git a/common/index.go b/common/index.go index c9820e2..67f1264 100644 --- a/common/index.go +++ b/common/index.go @@ -102,6 +102,7 @@ type AVTrack interface { Flush() SetSpeedLimit(time.Duration) GetRTPFromPool() *util.ListItem[RTPFrame] + GetFromPool(util.IBytes) *util.ListItem[util.Buffer] } type VideoTrack interface { AVTrack @@ -113,6 +114,6 @@ type VideoTrack interface { type AudioTrack interface { AVTrack - WriteADTS(uint32, []byte) - WriteRaw(uint32, []byte) + WriteADTS(uint32, util.IBytes) + WriteRawBytes(uint32, util.IBytes) } diff --git a/publisher-mp4.go b/publisher-mp4.go index 2d1ca7b..db3f6b9 100644 --- a/publisher-mp4.go +++ b/publisher-mp4.go @@ -6,6 +6,7 @@ import ( "github.com/yapingcat/gomedia/go-mp4" "go.uber.org/zap" "m7s.live/engine/v4/track" + "m7s.live/engine/v4/util" ) type MP4Publisher struct { @@ -52,9 +53,9 @@ func (p *MP4Publisher) ReadMP4Data(source io.ReadSeeker) error { case mp4.MP4_CODEC_H264, mp4.MP4_CODEC_H265: p.VideoTrack.WriteAnnexB(uint32(pkg.Pts*90), uint32(pkg.Dts*90), pkg.Data) case mp4.MP4_CODEC_AAC: - p.AudioTrack.WriteADTS(uint32(pkg.Pts*90), pkg.Data) + p.AudioTrack.WriteADTS(uint32(pkg.Pts*90), util.Buffer(pkg.Data)) case mp4.MP4_CODEC_G711A, mp4.MP4_CODEC_G711U: - p.AudioTrack.WriteRaw(uint32(pkg.Pts*90), pkg.Data) + p.AudioTrack.WriteRawBytes(uint32(pkg.Pts*90), util.Buffer(pkg.Data)) } } } diff --git a/publisher-ts.go b/publisher-ts.go index 9bd20f0..eac1234 100644 --- a/publisher-ts.go +++ b/publisher-ts.go @@ -85,7 +85,7 @@ func (t *TSPublisher) ReadPES() { case *track.AAC: t.AudioTrack.WriteADTS(uint32(pes.Header.Pts), pes.Payload) case *track.G711: - t.AudioTrack.WriteRaw(uint32(pes.Header.Pts), pes.Payload) + t.AudioTrack.WriteRawBytes(uint32(pes.Header.Pts), pes.Payload) } } } diff --git a/subscriber.go b/subscriber.go index 4b7d729..e56ef9c 100644 --- a/subscriber.go +++ b/subscriber.go @@ -77,6 +77,11 @@ func (a AudioFrame) GetADTS() (r net.Buffers) { return } +func (a AudioFrame) WriteRawTo(w io.Writer) (n int64, err error) { + aulist := a.AUList.ToBuffers() + return aulist.WriteTo(w) +} + func (v VideoFrame) GetAnnexB() (r net.Buffers) { if v.IFrame { r = v.ParamaterSets.GetAnnexB() diff --git a/track/aac.go b/track/aac.go index 24f4b0c..458619f 100644 --- a/track/aac.go +++ b/track/aac.go @@ -41,7 +41,8 @@ type AAC struct { fragments *util.BLL // 用于处理不完整的AU,缺少的字节数 } -func (aac *AAC) WriteADTS(ts uint32, adts []byte) { +func (aac *AAC) WriteADTS(ts uint32, b util.IBytes) { + adts := b.Bytes() if aac.SequenceHead == nil { profile := ((adts[2] & 0xc0) >> 6) + 1 sampleRate := (adts[2] & 0x3c) >> 2 @@ -64,7 +65,7 @@ func (aac *AAC) WriteADTS(ts uint32, adts []byte) { } frameLen = (int(adts[3]&3) << 11) | (int(adts[4]) << 3) | (int(adts[5]) >> 5) } - aac.Value.ADTS = aac.BytesPool.GetShell(adts) + aac.Value.ADTS = aac.GetFromPool(b) aac.Flush() } diff --git a/track/audio.go b/track/audio.go index d929bec..0fab951 100644 --- a/track/audio.go +++ b/track/audio.go @@ -43,7 +43,7 @@ func (a *Audio) GetName() string { return a.Name } -func (av *Audio) WriteADTS(pts uint32, adts []byte) { +func (av *Audio) WriteADTS(pts uint32, adts util.IBytes) { } @@ -56,10 +56,10 @@ func (av *Audio) Flush() { av.Media.Flush() } -func (av *Audio) WriteRaw(pts uint32, raw []byte) { +func (av *Audio) WriteRawBytes(pts uint32, raw util.IBytes) { curValue := &av.Value - curValue.BytesIn += len(raw) - av.AppendAuBytes(raw) + curValue.BytesIn += raw.Len() + av.Value.AUList.Push(av.GetFromPool(raw)) av.generateTimestamp(pts) av.Flush() } diff --git a/track/base.go b/track/base.go index 10ca67f..c8da7c6 100644 --- a/track/base.go +++ b/track/base.go @@ -99,6 +99,16 @@ type Media struct { 流速控制 } +func (av *Media) GetFromPool(b util.IBytes) (item *util.ListItem[util.Buffer]) { + if b.Reuse() { + item = av.BytesPool.Get(b.Len()) + copy(item.Value, b.Bytes()) + } else { + return av.BytesPool.GetShell(b.Bytes()) + } + return +} + func (av *Media) GetRBSize() int { return av.RingBuffer.Size } diff --git a/track/rtp.go b/track/rtp.go index c52680c..bf9bd9b 100644 --- a/track/rtp.go +++ b/track/rtp.go @@ -52,7 +52,8 @@ func (av *Media) PacketizeRTP(payloads ...[][]byte) { for _, pp := range payloads { rtpItem = av.GetRTPFromPool() packet := &rtpItem.Value - packet.Payload = packet.Payload[:0] + br := util.LimitBuffer{Buffer: packet.Payload} + br.Reset() if av.SampleRate != 90000 { packet.Timestamp = uint32(time.Duration(av.SampleRate) * av.Value.PTS / 90000) } else { @@ -60,8 +61,9 @@ func (av *Media) PacketizeRTP(payloads ...[][]byte) { } packet.Marker = false for _, p := range pp { - packet.Payload = append(packet.Payload, p...) + br.Write(p) } + packet.Payload = br.Bytes() av.Value.RTP.Push(rtpItem) } // 最后一个rtp包标记为true diff --git a/track/video.go b/track/video.go index 5de4667..2e29a9b 100644 --- a/track/video.go +++ b/track/video.go @@ -93,6 +93,7 @@ func (vt *Video) computeGOP() { func (vt *Video) writeAnnexBSlice(nalu []byte) { common.SplitAnnexB(nalu, vt.WriteSliceBytes, codec.NALU_Delimiter1) } + func (vt *Video) WriteNalu(pts uint32, dts uint32, nalu []byte) { if dts == 0 { vt.generateTimestamp(pts) @@ -104,6 +105,7 @@ func (vt *Video) WriteNalu(pts uint32, dts uint32, nalu []byte) { vt.WriteSliceBytes(nalu) vt.Flush() } + func (vt *Video) WriteAnnexB(pts uint32, dts uint32, frame []byte) { if dts == 0 { vt.generateTimestamp(pts) diff --git a/util/buffer.go b/util/buffer.go index fa5902e..e04639b 100644 --- a/util/buffer.go +++ b/util/buffer.go @@ -2,13 +2,78 @@ package util import ( "encoding/binary" + "fmt" "io" "math" "net" ) +// Buffer 用于方便自动扩容的内存写入,已经读取 type Buffer []byte +// ReuseBuffer 重用buffer,内容可能会被覆盖,要尽早复制 +type ReuseBuffer struct { + Buffer +} + +func (ReuseBuffer) Reuse() bool { + return true +} + +// LimitBuffer 限制buffer的长度,不会改变原来的buffer,防止内存泄漏 +type LimitBuffer struct { + Buffer +} + +func (b *LimitBuffer) ReadN(n int) (result LimitBuffer) { + result.Buffer = b.Buffer.ReadN(n) + return +} + +func (b LimitBuffer) Clone() (result LimitBuffer) { + result.Buffer = b.Buffer.Clone() + return +} + +func (b LimitBuffer) SubBuf(start int, length int) (result LimitBuffer) { + result.Buffer = b.Buffer.SubBuf(start, length) + return +} + +func (b *LimitBuffer) Malloc(count int) (result LimitBuffer) { + l := b.Len() + newL := l + count + if c := b.Cap(); newL > c { + panic(fmt.Sprintf("LimitBuffer Malloc %d > %d", newL, c)) + } else { + *b = b.SubBuf(0, newL) + } + return b.SubBuf(l, count) +} + +func (b *LimitBuffer) Write(a []byte) (n int, err error) { + l := b.Len() + newL := l + len(a) + if c := b.Cap(); newL > c { + panic(fmt.Sprintf("LimitBuffer Write %d > %d", newL, c)) + } else { + b.Buffer = b.Buffer.SubBuf(0, newL) + copy(b.Buffer[l:], a) + } + return len(a), nil +} + +// IBytes 用于区分传入的内存是否是复用内存,例如从网络中读取的数据,如果是复用内存,需要尽早复制 +type IBytes interface { + Len() int + Bytes() []byte + Reuse() bool +} + +func (Buffer) Reuse() bool { + return false +} + func (b *Buffer) Read(buf []byte) (n int, err error) { if !b.CanReadN(len(buf)) { copy(buf, *b) @@ -62,7 +127,14 @@ func (b *Buffer) WriteString(a string) { *b = append(*b, a...) } func (b *Buffer) Write(a []byte) (n int, err error) { - *b = append(*b, a...) + l := b.Len() + newL := l + len(a) + if newL > b.Cap() { + *b = append(*b, a...) + } else { + *b = b.SubBuf(0, newL) + copy((*b)[l:], a) + } return len(a), nil } @@ -70,6 +142,10 @@ func (b Buffer) Clone() (result Buffer) { return append(result, b...) } +func (b Buffer) Bytes() []byte { + return b +} + func (b Buffer) Len() int { return len(b) } @@ -130,16 +206,6 @@ func (b *Buffer) MarshalAMFs(v ...any) { *b = amf.Marshals(v...) } -// MallocSlice 用来对容量够的slice进行长度扩展+1,并返回新的位置的指针,用于写入 -func MallocSlice[T any](slice *[]T) *T { - oslice := *slice - if rawLen := len(oslice); cap(oslice) > rawLen { - *slice = oslice[:rawLen+1] - return &(*slice)[rawLen] - } - return nil -} - // ConcatBuffers 合并碎片内存为一个完整内存 func ConcatBuffers[T ~[]byte](input []T) (out []byte) { for _, v := range input { diff --git a/util/buffer_test.go b/util/buffer_test.go index 6a42421..383d71e 100644 --- a/util/buffer_test.go +++ b/util/buffer_test.go @@ -16,15 +16,3 @@ func TestBuffer(t *testing.T) { } }) } - -func TestMallocSlice(t *testing.T) { - t.Run(t.Name(), func(t *testing.T) { - var a [][]byte = [][]byte{} - b := MallocSlice(&a) - if *b != nil { - t.Fail() - } else if *b = []byte{1}; a[0][0] != 1 { - t.Fail() - } - }) -} diff --git a/util/index.go b/util/index.go index 73341bd..c8c77f6 100644 --- a/util/index.go +++ b/util/index.go @@ -6,6 +6,7 @@ import ( "os/signal" "path/filepath" "runtime" + "strings" "syscall" ) @@ -46,3 +47,12 @@ func WaitTerm(cancel context.CancelFunc) { <-sigc cancel() } + +// 判断目录是否是基础目录的子目录 +func IsSubdir(baseDir, joinedDir string) bool { + rel, err := filepath.Rel(baseDir, joinedDir) + if err != nil { + return false + } + return !strings.HasPrefix(rel, "..") && !strings.HasPrefix(rel, "/") +} diff --git a/util/socket.go b/util/socket.go index fd6512e..636285e 100644 --- a/util/socket.go +++ b/util/socket.go @@ -21,8 +21,11 @@ func ReturnJson[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWr return } } - } else if err := json.NewEncoder(rw).Encode(fetch()); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + rw.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(rw).Encode(fetch()); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } } } @@ -36,8 +39,11 @@ func ReturnYaml[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWr return } } - } else if err := yaml.NewEncoder(rw).Encode(fetch()); err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + } else { + rw.Header().Set("Content-Type", "application/yaml") + if err := yaml.NewEncoder(rw).Encode(fetch()); err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } } }