fix: g711 override memory

fix: PacketizeRTP memory leak
fix: avoid read files beyond the log path
This commit is contained in:
langhuihui
2023-05-25 19:12:01 +08:00
parent d3b26d69fc
commit fc7ac81c4e
13 changed files with 132 additions and 40 deletions

View File

@@ -102,6 +102,7 @@ type AVTrack interface {
Flush()
SetSpeedLimit(time.Duration)
GetRTPFromPool() *util.ListItem[RTPFrame]
GetFromPool(util.IBytes) *util.ListItem[util.Buffer]
}
type VideoTrack interface {
AVTrack
@@ -113,6 +114,6 @@ type VideoTrack interface {
type AudioTrack interface {
AVTrack
WriteADTS(uint32, []byte)
WriteRaw(uint32, []byte)
WriteADTS(uint32, util.IBytes)
WriteRawBytes(uint32, util.IBytes)
}

View File

@@ -6,6 +6,7 @@ import (
"github.com/yapingcat/gomedia/go-mp4"
"go.uber.org/zap"
"m7s.live/engine/v4/track"
"m7s.live/engine/v4/util"
)
type MP4Publisher struct {
@@ -52,9 +53,9 @@ func (p *MP4Publisher) ReadMP4Data(source io.ReadSeeker) error {
case mp4.MP4_CODEC_H264, mp4.MP4_CODEC_H265:
p.VideoTrack.WriteAnnexB(uint32(pkg.Pts*90), uint32(pkg.Dts*90), pkg.Data)
case mp4.MP4_CODEC_AAC:
p.AudioTrack.WriteADTS(uint32(pkg.Pts*90), pkg.Data)
p.AudioTrack.WriteADTS(uint32(pkg.Pts*90), util.Buffer(pkg.Data))
case mp4.MP4_CODEC_G711A, mp4.MP4_CODEC_G711U:
p.AudioTrack.WriteRaw(uint32(pkg.Pts*90), pkg.Data)
p.AudioTrack.WriteRawBytes(uint32(pkg.Pts*90), util.Buffer(pkg.Data))
}
}
}

View File

@@ -85,7 +85,7 @@ func (t *TSPublisher) ReadPES() {
case *track.AAC:
t.AudioTrack.WriteADTS(uint32(pes.Header.Pts), pes.Payload)
case *track.G711:
t.AudioTrack.WriteRaw(uint32(pes.Header.Pts), pes.Payload)
t.AudioTrack.WriteRawBytes(uint32(pes.Header.Pts), pes.Payload)
}
}
}

View File

@@ -77,6 +77,11 @@ func (a AudioFrame) GetADTS() (r net.Buffers) {
return
}
func (a AudioFrame) WriteRawTo(w io.Writer) (n int64, err error) {
aulist := a.AUList.ToBuffers()
return aulist.WriteTo(w)
}
func (v VideoFrame) GetAnnexB() (r net.Buffers) {
if v.IFrame {
r = v.ParamaterSets.GetAnnexB()

View File

@@ -41,7 +41,8 @@ type AAC struct {
fragments *util.BLL // 用于处理不完整的AU,缺少的字节数
}
func (aac *AAC) WriteADTS(ts uint32, adts []byte) {
func (aac *AAC) WriteADTS(ts uint32, b util.IBytes) {
adts := b.Bytes()
if aac.SequenceHead == nil {
profile := ((adts[2] & 0xc0) >> 6) + 1
sampleRate := (adts[2] & 0x3c) >> 2
@@ -64,7 +65,7 @@ func (aac *AAC) WriteADTS(ts uint32, adts []byte) {
}
frameLen = (int(adts[3]&3) << 11) | (int(adts[4]) << 3) | (int(adts[5]) >> 5)
}
aac.Value.ADTS = aac.BytesPool.GetShell(adts)
aac.Value.ADTS = aac.GetFromPool(b)
aac.Flush()
}

View File

@@ -43,7 +43,7 @@ func (a *Audio) GetName() string {
return a.Name
}
func (av *Audio) WriteADTS(pts uint32, adts []byte) {
func (av *Audio) WriteADTS(pts uint32, adts util.IBytes) {
}
@@ -56,10 +56,10 @@ func (av *Audio) Flush() {
av.Media.Flush()
}
func (av *Audio) WriteRaw(pts uint32, raw []byte) {
func (av *Audio) WriteRawBytes(pts uint32, raw util.IBytes) {
curValue := &av.Value
curValue.BytesIn += len(raw)
av.AppendAuBytes(raw)
curValue.BytesIn += raw.Len()
av.Value.AUList.Push(av.GetFromPool(raw))
av.generateTimestamp(pts)
av.Flush()
}

View File

@@ -99,6 +99,16 @@ type Media struct {
流速控制
}
func (av *Media) GetFromPool(b util.IBytes) (item *util.ListItem[util.Buffer]) {
if b.Reuse() {
item = av.BytesPool.Get(b.Len())
copy(item.Value, b.Bytes())
} else {
return av.BytesPool.GetShell(b.Bytes())
}
return
}
func (av *Media) GetRBSize() int {
return av.RingBuffer.Size
}

View File

@@ -52,7 +52,8 @@ func (av *Media) PacketizeRTP(payloads ...[][]byte) {
for _, pp := range payloads {
rtpItem = av.GetRTPFromPool()
packet := &rtpItem.Value
packet.Payload = packet.Payload[:0]
br := util.LimitBuffer{Buffer: packet.Payload}
br.Reset()
if av.SampleRate != 90000 {
packet.Timestamp = uint32(time.Duration(av.SampleRate) * av.Value.PTS / 90000)
} else {
@@ -60,8 +61,9 @@ func (av *Media) PacketizeRTP(payloads ...[][]byte) {
}
packet.Marker = false
for _, p := range pp {
packet.Payload = append(packet.Payload, p...)
br.Write(p)
}
packet.Payload = br.Bytes()
av.Value.RTP.Push(rtpItem)
}
// 最后一个rtp包标记为true

View File

@@ -93,6 +93,7 @@ func (vt *Video) computeGOP() {
func (vt *Video) writeAnnexBSlice(nalu []byte) {
common.SplitAnnexB(nalu, vt.WriteSliceBytes, codec.NALU_Delimiter1)
}
func (vt *Video) WriteNalu(pts uint32, dts uint32, nalu []byte) {
if dts == 0 {
vt.generateTimestamp(pts)
@@ -104,6 +105,7 @@ func (vt *Video) WriteNalu(pts uint32, dts uint32, nalu []byte) {
vt.WriteSliceBytes(nalu)
vt.Flush()
}
func (vt *Video) WriteAnnexB(pts uint32, dts uint32, frame []byte) {
if dts == 0 {
vt.generateTimestamp(pts)

View File

@@ -2,13 +2,78 @@ package util
import (
"encoding/binary"
"fmt"
"io"
"math"
"net"
)
// Buffer 用于方便自动扩容的内存写入,已经读取
type Buffer []byte
// ReuseBuffer 重用buffer内容可能会被覆盖要尽早复制
type ReuseBuffer struct {
Buffer
}
func (ReuseBuffer) Reuse() bool {
return true
}
// LimitBuffer 限制buffer的长度不会改变原来的buffer防止内存泄漏
type LimitBuffer struct {
Buffer
}
func (b *LimitBuffer) ReadN(n int) (result LimitBuffer) {
result.Buffer = b.Buffer.ReadN(n)
return
}
func (b LimitBuffer) Clone() (result LimitBuffer) {
result.Buffer = b.Buffer.Clone()
return
}
func (b LimitBuffer) SubBuf(start int, length int) (result LimitBuffer) {
result.Buffer = b.Buffer.SubBuf(start, length)
return
}
func (b *LimitBuffer) Malloc(count int) (result LimitBuffer) {
l := b.Len()
newL := l + count
if c := b.Cap(); newL > c {
panic(fmt.Sprintf("LimitBuffer Malloc %d > %d", newL, c))
} else {
*b = b.SubBuf(0, newL)
}
return b.SubBuf(l, count)
}
func (b *LimitBuffer) Write(a []byte) (n int, err error) {
l := b.Len()
newL := l + len(a)
if c := b.Cap(); newL > c {
panic(fmt.Sprintf("LimitBuffer Write %d > %d", newL, c))
} else {
b.Buffer = b.Buffer.SubBuf(0, newL)
copy(b.Buffer[l:], a)
}
return len(a), nil
}
// IBytes 用于区分传入的内存是否是复用内存,例如从网络中读取的数据,如果是复用内存,需要尽早复制
type IBytes interface {
Len() int
Bytes() []byte
Reuse() bool
}
func (Buffer) Reuse() bool {
return false
}
func (b *Buffer) Read(buf []byte) (n int, err error) {
if !b.CanReadN(len(buf)) {
copy(buf, *b)
@@ -62,7 +127,14 @@ func (b *Buffer) WriteString(a string) {
*b = append(*b, a...)
}
func (b *Buffer) Write(a []byte) (n int, err error) {
l := b.Len()
newL := l + len(a)
if newL > b.Cap() {
*b = append(*b, a...)
} else {
*b = b.SubBuf(0, newL)
copy((*b)[l:], a)
}
return len(a), nil
}
@@ -70,6 +142,10 @@ func (b Buffer) Clone() (result Buffer) {
return append(result, b...)
}
func (b Buffer) Bytes() []byte {
return b
}
func (b Buffer) Len() int {
return len(b)
}
@@ -130,16 +206,6 @@ func (b *Buffer) MarshalAMFs(v ...any) {
*b = amf.Marshals(v...)
}
// MallocSlice 用来对容量够的slice进行长度扩展+1并返回新的位置的指针用于写入
func MallocSlice[T any](slice *[]T) *T {
oslice := *slice
if rawLen := len(oslice); cap(oslice) > rawLen {
*slice = oslice[:rawLen+1]
return &(*slice)[rawLen]
}
return nil
}
// ConcatBuffers 合并碎片内存为一个完整内存
func ConcatBuffers[T ~[]byte](input []T) (out []byte) {
for _, v := range input {

View File

@@ -16,15 +16,3 @@ func TestBuffer(t *testing.T) {
}
})
}
func TestMallocSlice(t *testing.T) {
t.Run(t.Name(), func(t *testing.T) {
var a [][]byte = [][]byte{}
b := MallocSlice(&a)
if *b != nil {
t.Fail()
} else if *b = []byte{1}; a[0][0] != 1 {
t.Fail()
}
})
}

View File

@@ -6,6 +6,7 @@ import (
"os/signal"
"path/filepath"
"runtime"
"strings"
"syscall"
)
@@ -46,3 +47,12 @@ func WaitTerm(cancel context.CancelFunc) {
<-sigc
cancel()
}
// 判断目录是否是基础目录的子目录
func IsSubdir(baseDir, joinedDir string) bool {
rel, err := filepath.Rel(baseDir, joinedDir)
if err != nil {
return false
}
return !strings.HasPrefix(rel, "..") && !strings.HasPrefix(rel, "/")
}

View File

@@ -21,10 +21,13 @@ func ReturnJson[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWr
return
}
}
} else if err := json.NewEncoder(rw).Encode(fetch()); err != nil {
} else {
rw.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(rw).Encode(fetch()); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
}
func ReturnYaml[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWriter, r *http.Request) {
if r.Header.Get("Accept") == "text/event-stream" {
@@ -36,10 +39,13 @@ func ReturnYaml[T any](fetch func() T, tickDur time.Duration, rw http.ResponseWr
return
}
}
} else if err := yaml.NewEncoder(rw).Encode(fetch()); err != nil {
} else {
rw.Header().Set("Content-Type", "application/yaml")
if err := yaml.NewEncoder(rw).Encode(fetch()); err != nil {
http.Error(rw, err.Error(), http.StatusInternalServerError)
}
}
}
func ListenUDP(address string, networkBuffer int) (*net.UDPConn, error) {