mirror of
https://github.com/langhuihui/monibuca.git
synced 2025-12-24 13:48:04 +08:00
refactor: frame converter and mp4 track improvements
- Refactor frame converter implementation - Update mp4 track to use ICodex - General refactoring and code improvements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,17 +1,13 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/bluenviron/mediacommon/pkg/bits"
|
||||
"github.com/deepch/vdk/codec/aacparser"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/pion/webrtc/v4"
|
||||
@@ -21,43 +17,24 @@ import (
|
||||
)
|
||||
|
||||
type RTPData struct {
|
||||
*webrtc.RTPCodecParameters
|
||||
Packets []*rtp.Packet
|
||||
util.RecyclableMemory
|
||||
Sample
|
||||
Packets util.ReuseArray[rtp.Packet]
|
||||
}
|
||||
|
||||
func (r *RTPData) Dump(t byte, w io.Writer) {
|
||||
m := r.GetAllocator().Borrow(3 + len(r.Packets)*2 + r.GetSize())
|
||||
m[0] = t
|
||||
binary.BigEndian.PutUint16(m[1:], uint16(len(r.Packets)))
|
||||
offset := 3
|
||||
for _, p := range r.Packets {
|
||||
size := p.MarshalSize()
|
||||
binary.BigEndian.PutUint16(m[offset:], uint16(size))
|
||||
offset += 2
|
||||
p.MarshalTo(m[offset:])
|
||||
offset += size
|
||||
}
|
||||
w.Write(m)
|
||||
func (r *RTPData) Recycle() {
|
||||
r.RecyclableMemory.Recycle()
|
||||
r.Packets.Reset()
|
||||
}
|
||||
|
||||
func (r *RTPData) String() (s string) {
|
||||
for _, p := range r.Packets {
|
||||
for p := range r.Packets.RangePoint {
|
||||
s += fmt.Sprintf("t: %d, s: %d, p: %02X %d\n", p.Timestamp, p.SequenceNumber, p.Payload[0:2], len(p.Payload))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (r *RTPData) GetTimestamp() time.Duration {
|
||||
return time.Duration(r.Packets[0].Timestamp) * time.Second / time.Duration(r.ClockRate)
|
||||
}
|
||||
|
||||
func (r *RTPData) GetCTS() time.Duration {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *RTPData) GetSize() (s int) {
|
||||
for _, p := range r.Packets {
|
||||
for p := range r.Packets.RangePoint {
|
||||
s += p.MarshalSize()
|
||||
}
|
||||
return
|
||||
@@ -72,19 +49,19 @@ type (
|
||||
}
|
||||
PCMACtx struct {
|
||||
RTPCtx
|
||||
codec.PCMACtx
|
||||
*codec.PCMACtx
|
||||
}
|
||||
PCMUCtx struct {
|
||||
RTPCtx
|
||||
codec.PCMUCtx
|
||||
*codec.PCMUCtx
|
||||
}
|
||||
OPUSCtx struct {
|
||||
RTPCtx
|
||||
codec.OPUSCtx
|
||||
*codec.OPUSCtx
|
||||
}
|
||||
AACCtx struct {
|
||||
RTPCtx
|
||||
codec.AACCtx
|
||||
*codec.AACCtx
|
||||
SizeLength int // 通常为13
|
||||
IndexLength int
|
||||
IndexDeltaLength int
|
||||
@@ -94,7 +71,7 @@ type (
|
||||
}
|
||||
)
|
||||
|
||||
func (r *RTPCtx) parseFmtpLine(cp *webrtc.RTPCodecParameters) {
|
||||
func (r *RTPCtx) ParseFmtpLine(cp *webrtc.RTPCodecParameters) {
|
||||
r.RTPCodecParameters = *cp
|
||||
r.Fmtp = make(map[string]string)
|
||||
kvs := strings.Split(r.SDPFmtpLine, ";")
|
||||
@@ -121,9 +98,9 @@ func (r *RTPCtx) GetRTPCodecParameter() webrtc.RTPCodecParameters {
|
||||
return r.RTPCodecParameters
|
||||
}
|
||||
|
||||
func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) (lastPacket *rtp.Packet) {
|
||||
func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) *rtp.Packet {
|
||||
ctx.SequenceNumber++
|
||||
lastPacket = &rtp.Packet{
|
||||
r.Packets = append(r.Packets, rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
@@ -132,135 +109,19 @@ func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) (lastPacket *rt
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: payload,
|
||||
}
|
||||
r.Packets = append(r.Packets, lastPacket)
|
||||
return
|
||||
})
|
||||
return &r.Packets[len(r.Packets)-1]
|
||||
}
|
||||
|
||||
func (r *RTPData) ConvertCtx(from codec.ICodecCtx) (to codec.ICodecCtx, seq IAVFrame, err error) {
|
||||
switch from.FourCC() {
|
||||
case codec.FourCC_H264:
|
||||
var ctx H264Ctx
|
||||
ctx.H264Ctx = *from.GetBase().(*codec.H264Ctx)
|
||||
ctx.PayloadType = 96
|
||||
ctx.MimeType = webrtc.MimeTypeH264
|
||||
ctx.ClockRate = 90000
|
||||
spsInfo := ctx.SPSInfo
|
||||
ctx.SDPFmtpLine = fmt.Sprintf("sprop-parameter-sets=%s,%s;profile-level-id=%02x%02x%02x;level-asymmetry-allowed=1;packetization-mode=1", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), spsInfo.ProfileIdc, spsInfo.ConstraintSetFlag, spsInfo.LevelIdc)
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
to = &ctx
|
||||
case codec.FourCC_H265:
|
||||
var ctx H265Ctx
|
||||
ctx.H265Ctx = *from.GetBase().(*codec.H265Ctx)
|
||||
ctx.PayloadType = 98
|
||||
ctx.MimeType = webrtc.MimeTypeH265
|
||||
ctx.SDPFmtpLine = fmt.Sprintf("profile-id=1;sprop-sps=%s;sprop-pps=%s;sprop-vps=%s", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), base64.StdEncoding.EncodeToString(ctx.VPS()))
|
||||
ctx.ClockRate = 90000
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
to = &ctx
|
||||
case codec.FourCC_MP4A:
|
||||
var ctx AACCtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.AACCtx = *from.GetBase().(*codec.AACCtx)
|
||||
ctx.MimeType = "audio/MPEG4-GENERIC"
|
||||
ctx.SDPFmtpLine = fmt.Sprintf("profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;config=%s", hex.EncodeToString(ctx.AACCtx.ConfigBytes))
|
||||
ctx.IndexLength = 3
|
||||
ctx.IndexDeltaLength = 3
|
||||
ctx.SizeLength = 13
|
||||
ctx.RTPCtx.Channels = uint16(ctx.AACCtx.GetChannels())
|
||||
ctx.PayloadType = 97
|
||||
ctx.ClockRate = uint32(ctx.CodecData.SampleRate())
|
||||
to = &ctx
|
||||
case codec.FourCC_ALAW:
|
||||
var ctx PCMACtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.PCMACtx = *from.GetBase().(*codec.PCMACtx)
|
||||
ctx.MimeType = webrtc.MimeTypePCMA
|
||||
ctx.PayloadType = 8
|
||||
ctx.ClockRate = uint32(ctx.SampleRate)
|
||||
to = &ctx
|
||||
case codec.FourCC_ULAW:
|
||||
var ctx PCMUCtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.PCMUCtx = *from.GetBase().(*codec.PCMUCtx)
|
||||
ctx.MimeType = webrtc.MimeTypePCMU
|
||||
ctx.PayloadType = 0
|
||||
ctx.ClockRate = uint32(ctx.SampleRate)
|
||||
to = &ctx
|
||||
case codec.FourCC_OPUS:
|
||||
var ctx OPUSCtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.OPUSCtx = *from.GetBase().(*codec.OPUSCtx)
|
||||
ctx.MimeType = webrtc.MimeTypeOpus
|
||||
ctx.PayloadType = 111
|
||||
ctx.ClockRate = uint32(ctx.CodecData.SampleRate())
|
||||
to = &ctx
|
||||
}
|
||||
return
|
||||
}
|
||||
var _ IAVFrame = (*AudioFrame)(nil)
|
||||
|
||||
type Audio struct {
|
||||
type AudioFrame struct {
|
||||
RTPData
|
||||
}
|
||||
|
||||
func (r *Audio) Parse(t *AVTrack) (err error) {
|
||||
switch r.MimeType {
|
||||
case webrtc.MimeTypeOpus:
|
||||
var ctx OPUSCtx
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
ctx.OPUSCtx.Channels = int(ctx.RTPCodecParameters.Channels)
|
||||
t.ICodecCtx = &ctx
|
||||
case webrtc.MimeTypePCMA:
|
||||
var ctx PCMACtx
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
ctx.AudioCtx.SampleRate = int(r.ClockRate)
|
||||
ctx.AudioCtx.Channels = int(ctx.RTPCodecParameters.Channels)
|
||||
t.ICodecCtx = &ctx
|
||||
case webrtc.MimeTypePCMU:
|
||||
var ctx PCMUCtx
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
ctx.AudioCtx.SampleRate = int(r.ClockRate)
|
||||
ctx.AudioCtx.Channels = int(ctx.RTPCodecParameters.Channels)
|
||||
t.ICodecCtx = &ctx
|
||||
case "audio/MP4A-LATM":
|
||||
var ctx *AACCtx
|
||||
if t.ICodecCtx != nil {
|
||||
// ctx = t.ICodecCtx.(*AACCtx)
|
||||
} else {
|
||||
ctx = &AACCtx{}
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
if conf, ok := ctx.Fmtp["config"]; ok {
|
||||
if ctx.AACCtx.ConfigBytes, err = hex.DecodeString(conf); err == nil {
|
||||
if ctx.CodecData, err = aacparser.NewCodecDataFromMPEG4AudioConfigBytes(ctx.AACCtx.ConfigBytes); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.ICodecCtx = ctx
|
||||
}
|
||||
case "audio/MPEG4-GENERIC":
|
||||
var ctx *AACCtx
|
||||
if t.ICodecCtx != nil {
|
||||
// ctx = t.ICodecCtx.(*AACCtx)
|
||||
} else {
|
||||
ctx = &AACCtx{}
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
ctx.IndexLength = 3
|
||||
ctx.IndexDeltaLength = 3
|
||||
ctx.SizeLength = 13
|
||||
if conf, ok := ctx.Fmtp["config"]; ok {
|
||||
if ctx.AACCtx.ConfigBytes, err = hex.DecodeString(conf); err == nil {
|
||||
if ctx.CodecData, err = aacparser.NewCodecDataFromMPEG4AudioConfigBytes(ctx.AACCtx.ConfigBytes); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
t.ICodecCtx = ctx
|
||||
}
|
||||
}
|
||||
if len(r.Packets) == 0 {
|
||||
return ErrSkip
|
||||
}
|
||||
func (r *AudioFrame) Parse(data IAVFrame) (err error) {
|
||||
input := data.(*AudioFrame)
|
||||
r.Packets = append(r.Packets[:0], input.Packets...)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -286,17 +147,22 @@ func payloadLengthInfoDecode(buf []byte) (int, int, error) {
|
||||
return l, n, nil
|
||||
}
|
||||
|
||||
func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
|
||||
func (r *AudioFrame) Demux() (err error) {
|
||||
if len(r.Packets) == 0 {
|
||||
return nil, ErrSkip
|
||||
return ErrSkip
|
||||
}
|
||||
var data AudioData
|
||||
switch r.MimeType {
|
||||
data := r.GetAudioData()
|
||||
// 从编解码器上下文获取 MimeType
|
||||
var mimeType string
|
||||
if rtpCtx, ok := r.ICodecCtx.(IRTPCtx); ok {
|
||||
mimeType = rtpCtx.GetRTPCodecParameter().MimeType
|
||||
}
|
||||
switch mimeType {
|
||||
case "audio/MP4A-LATM":
|
||||
var fragments util.Memory
|
||||
var fragmentsExpected int
|
||||
var fragmentsSize int
|
||||
for _, packet := range r.Packets {
|
||||
for packet := range r.Packets.RangePoint {
|
||||
if len(packet.Payload) == 0 {
|
||||
continue
|
||||
}
|
||||
@@ -307,23 +173,23 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
|
||||
if fragments.Size == 0 {
|
||||
pl, n, err := payloadLengthInfoDecode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
buf = buf[n:]
|
||||
bl := len(buf)
|
||||
|
||||
if pl <= bl {
|
||||
data.AppendOne(buf[:pl])
|
||||
data.PushOne(buf[:pl])
|
||||
// there could be other data, due to otherDataPresent. Ignore it.
|
||||
} else {
|
||||
if pl > 5*1024 {
|
||||
fragments = util.Memory{} // discard pending fragments
|
||||
return nil, fmt.Errorf("access unit size (%d) is too big, maximum is %d",
|
||||
return fmt.Errorf("access unit size (%d) is too big, maximum is %d",
|
||||
pl, 5*1024)
|
||||
}
|
||||
|
||||
fragments.AppendOne(buf)
|
||||
fragments.PushOne(buf)
|
||||
fragmentsSize = pl
|
||||
fragmentsExpected = pl - bl
|
||||
continue
|
||||
@@ -332,33 +198,33 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
|
||||
bl := len(buf)
|
||||
|
||||
if fragmentsExpected > bl {
|
||||
fragments.AppendOne(buf)
|
||||
fragments.PushOne(buf)
|
||||
fragmentsExpected -= bl
|
||||
continue
|
||||
}
|
||||
|
||||
fragments.AppendOne(buf[:fragmentsExpected])
|
||||
fragments.PushOne(buf[:fragmentsExpected])
|
||||
// there could be other data, due to otherDataPresent. Ignore it.
|
||||
data.Append(fragments.Buffers...)
|
||||
data.Push(fragments.Buffers...)
|
||||
if fragments.Size != fragmentsSize {
|
||||
return nil, fmt.Errorf("fragmented AU size is not correct %d != %d", data.Size, fragmentsSize)
|
||||
return fmt.Errorf("fragmented AU size is not correct %d != %d", data.Size, fragmentsSize)
|
||||
}
|
||||
fragments = util.Memory{}
|
||||
}
|
||||
}
|
||||
case "audio/MPEG4-GENERIC":
|
||||
var fragments util.Memory
|
||||
for _, packet := range r.Packets {
|
||||
for packet := range r.Packets.RangePoint {
|
||||
if len(packet.Payload) < 2 {
|
||||
continue
|
||||
}
|
||||
auHeaderLen := util.ReadBE[int](packet.Payload[:2])
|
||||
if auHeaderLen == 0 {
|
||||
data.AppendOne(packet.Payload)
|
||||
data.PushOne(packet.Payload)
|
||||
} else {
|
||||
dataLens, err := r.readAUHeaders(codexCtx.(*AACCtx), packet.Payload[2:], auHeaderLen)
|
||||
dataLens, err := r.readAUHeaders(r.ICodecCtx.(*AACCtx), packet.Payload[2:], auHeaderLen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
payload := packet.Payload[2:]
|
||||
pos := auHeaderLen >> 3
|
||||
@@ -370,48 +236,65 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
|
||||
if packet.Marker {
|
||||
for _, dataLen := range dataLens {
|
||||
if len(payload) < int(dataLen) {
|
||||
return nil, fmt.Errorf("invalid data len %d", dataLen)
|
||||
return fmt.Errorf("invalid data len %d", dataLen)
|
||||
}
|
||||
data.AppendOne(payload[:dataLen])
|
||||
data.PushOne(payload[:dataLen])
|
||||
payload = payload[dataLen:]
|
||||
}
|
||||
} else {
|
||||
if len(dataLens) != 1 {
|
||||
return nil, fmt.Errorf("a fragmented packet can only contain one AU")
|
||||
return fmt.Errorf("a fragmented packet can only contain one AU")
|
||||
}
|
||||
fragments.AppendOne(payload)
|
||||
fragments.PushOne(payload)
|
||||
}
|
||||
} else {
|
||||
if len(dataLens) != 1 {
|
||||
return nil, fmt.Errorf("a fragmented packet can only contain one AU")
|
||||
return fmt.Errorf("a fragmented packet can only contain one AU")
|
||||
}
|
||||
fragments.AppendOne(payload)
|
||||
fragments.PushOne(payload)
|
||||
if !packet.Header.Marker {
|
||||
continue
|
||||
}
|
||||
if uint64(fragments.Size) != dataLens[0] {
|
||||
return nil, fmt.Errorf("fragmented AU size is not correct %d != %d", dataLens[0], fragments.Size)
|
||||
return fmt.Errorf("fragmented AU size is not correct %d != %d", dataLens[0], fragments.Size)
|
||||
}
|
||||
data.Append(fragments.Buffers...)
|
||||
data.Push(fragments.Buffers...)
|
||||
fragments = util.Memory{}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
default:
|
||||
for _, packet := range r.Packets {
|
||||
data.AppendOne(packet.Payload)
|
||||
for packet := range r.Packets.RangePoint {
|
||||
data.PushOne(packet.Payload)
|
||||
}
|
||||
}
|
||||
return data, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
|
||||
data := from.Raw.(AudioData)
|
||||
func (r *AudioFrame) Mux(from *Sample) (err error) {
|
||||
data := from.Raw.(*AudioData)
|
||||
var ctx *RTPCtx
|
||||
var lastPacket *rtp.Packet
|
||||
switch c := codexCtx.(type) {
|
||||
case *AACCtx:
|
||||
switch base := from.GetBase().(type) {
|
||||
case *codec.AACCtx:
|
||||
var c *AACCtx
|
||||
if r.ICodecCtx == nil {
|
||||
c = &AACCtx{}
|
||||
c.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
c.AACCtx = base
|
||||
c.MimeType = "audio/MPEG4-GENERIC"
|
||||
c.SDPFmtpLine = fmt.Sprintf("profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;config=%s", hex.EncodeToString(c.ConfigBytes))
|
||||
c.IndexLength = 3
|
||||
c.IndexDeltaLength = 3
|
||||
c.SizeLength = 13
|
||||
c.RTPCtx.Channels = uint16(base.GetChannels())
|
||||
c.PayloadType = 97
|
||||
c.ClockRate = uint32(base.CodecData.SampleRate())
|
||||
r.ICodecCtx = c
|
||||
} else {
|
||||
c = r.ICodecCtx.(*AACCtx)
|
||||
}
|
||||
ctx = &c.RTPCtx
|
||||
pts := uint32(from.Timestamp * time.Duration(ctx.ClockRate) / time.Second)
|
||||
//AU_HEADER_LENGTH,因为单位是bit, 除以8就是auHeader的字节长度;又因为单个auheader字节长度2字节,所以再除以2就是auheader的个数。
|
||||
@@ -423,15 +306,35 @@ func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
|
||||
}
|
||||
mem := r.NextN(payloadLen)
|
||||
copy(mem, auHeaderLen)
|
||||
reader.ReadBytesTo(mem[4:])
|
||||
reader.Read(mem[4:])
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
}
|
||||
lastPacket.Header.Marker = true
|
||||
return
|
||||
case *PCMACtx:
|
||||
ctx = &c.RTPCtx
|
||||
case *PCMUCtx:
|
||||
ctx = &c.RTPCtx
|
||||
case *codec.PCMACtx:
|
||||
if r.ICodecCtx == nil {
|
||||
var ctx PCMACtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.PCMACtx = base
|
||||
ctx.MimeType = webrtc.MimeTypePCMA
|
||||
ctx.PayloadType = 8
|
||||
ctx.ClockRate = uint32(ctx.SampleRate)
|
||||
r.ICodecCtx = &ctx
|
||||
} else {
|
||||
ctx = &r.ICodecCtx.(*PCMACtx).RTPCtx
|
||||
}
|
||||
case *codec.PCMUCtx:
|
||||
if r.ICodecCtx == nil {
|
||||
var ctx PCMUCtx
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
ctx.PCMUCtx = base
|
||||
ctx.MimeType = webrtc.MimeTypePCMU
|
||||
ctx.PayloadType = 0
|
||||
ctx.ClockRate = uint32(ctx.SampleRate)
|
||||
r.ICodecCtx = &ctx
|
||||
} else {
|
||||
ctx = &r.ICodecCtx.(*PCMUCtx).RTPCtx
|
||||
}
|
||||
}
|
||||
pts := uint32(from.Timestamp * time.Duration(ctx.ClockRate) / time.Second)
|
||||
if reader := data.NewReader(); reader.Length > MTUSize {
|
||||
@@ -441,18 +344,19 @@ func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
|
||||
payloadLen = reader.Length
|
||||
}
|
||||
mem := r.NextN(payloadLen)
|
||||
reader.ReadBytesTo(mem)
|
||||
reader.Read(mem)
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
}
|
||||
} else {
|
||||
mem := r.NextN(reader.Length)
|
||||
reader.ReadBytesTo(mem)
|
||||
reader.Read(mem)
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
}
|
||||
lastPacket.Header.Marker = true
|
||||
return
|
||||
}
|
||||
|
||||
func (r *Audio) readAUHeaders(ctx *AACCtx, buf []byte, headersLen int) ([]uint64, error) {
|
||||
func (r *AudioFrame) readAUHeaders(ctx *AACCtx, buf []byte, headersLen int) ([]uint64, error) {
|
||||
firstRead := false
|
||||
|
||||
count := 0
|
||||
|
||||
476
plugin/rtp/pkg/forward.go
Normal file
476
plugin/rtp/pkg/forward.go
Normal file
@@ -0,0 +1,476 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
// ConnectionConfig 连接配置
|
||||
type ConnectionConfig struct {
|
||||
IP string
|
||||
Port uint32
|
||||
Mode StreamMode
|
||||
SSRC uint32 // RTP SSRC
|
||||
}
|
||||
|
||||
// ForwardConfig 转发配置
|
||||
type ForwardConfig struct {
|
||||
Source ConnectionConfig
|
||||
Target ConnectionConfig
|
||||
Relay bool
|
||||
}
|
||||
|
||||
// Forwarder 转发器
|
||||
type Forwarder struct {
|
||||
config *ForwardConfig
|
||||
source net.Conn
|
||||
target net.Conn
|
||||
}
|
||||
|
||||
// NewForwarder 创建新的转发器
|
||||
func NewForwarder(config *ForwardConfig) *Forwarder {
|
||||
return &Forwarder{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// establishSourceConnection 建立源连接
|
||||
func (f *Forwarder) establishSourceConnection(config ConnectionConfig) (net.Conn, error) {
|
||||
switch config.Mode {
|
||||
case StreamModeTCPActive:
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
netConn, err := dialer.Dial("tcp", fmt.Sprintf("%s:%d", config.IP, config.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect failed: %v", err)
|
||||
}
|
||||
return netConn, nil
|
||||
|
||||
case StreamModeTCPPassive:
|
||||
listener, err := net.Listen("tcp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen failed: %v", err)
|
||||
}
|
||||
|
||||
// Set timeout for accepting connections
|
||||
if tcpListener, ok := listener.(*net.TCPListener); ok {
|
||||
tcpListener.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
}
|
||||
|
||||
netConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
return nil, fmt.Errorf("accept failed: %v", err)
|
||||
}
|
||||
|
||||
return netConn, nil
|
||||
|
||||
case StreamModeUDP:
|
||||
// Source UDP - listen
|
||||
udpAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve UDP address failed: %v", err)
|
||||
}
|
||||
netConn, err := net.ListenUDP("udp4", udpAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("UDP listen failed: %v", err)
|
||||
}
|
||||
return netConn, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
|
||||
}
|
||||
|
||||
// establishTargetConnection 建立目标连接
|
||||
func (f *Forwarder) establishTargetConnection(config ConnectionConfig) (net.Conn, error) {
|
||||
switch config.Mode {
|
||||
case StreamModeTCPActive:
|
||||
dialer := &net.Dialer{Timeout: 10 * time.Second}
|
||||
netConn, err := dialer.Dial("tcp", fmt.Sprintf("%s:%d", config.IP, config.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect failed: %v", err)
|
||||
}
|
||||
return netConn, nil
|
||||
|
||||
case StreamModeTCPPassive:
|
||||
listener, err := net.Listen("tcp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("listen failed: %v", err)
|
||||
}
|
||||
|
||||
// Set timeout for accepting connections
|
||||
if tcpListener, ok := listener.(*net.TCPListener); ok {
|
||||
tcpListener.SetDeadline(time.Now().Add(30 * time.Second))
|
||||
}
|
||||
|
||||
netConn, err := listener.Accept()
|
||||
if err != nil {
|
||||
listener.Close()
|
||||
return nil, fmt.Errorf("accept failed: %v", err)
|
||||
}
|
||||
|
||||
return netConn, nil
|
||||
|
||||
case StreamModeUDP:
|
||||
// Target UDP - dial
|
||||
netConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
|
||||
IP: net.ParseIP(config.IP),
|
||||
Port: int(config.Port),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("UDP dial failed: %v", err)
|
||||
}
|
||||
return netConn, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
|
||||
}
|
||||
|
||||
// setupConnections 建立源和目标连接
|
||||
func (f *Forwarder) setupConnections() error {
|
||||
var err error
|
||||
|
||||
// 建立源连接
|
||||
f.source, err = f.establishSourceConnection(f.config.Source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("source connection failed: %v", err)
|
||||
}
|
||||
|
||||
// 建立目标连接
|
||||
f.target, err = f.establishTargetConnection(f.config.Target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("target connection failed: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanup 清理连接
|
||||
func (f *Forwarder) cleanup() {
|
||||
if f.source != nil {
|
||||
f.source.Close()
|
||||
}
|
||||
if f.target != nil {
|
||||
f.target.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// createRTPReader 创建RTP读取器
|
||||
func (f *Forwarder) createRTPReader() IRTPReader {
|
||||
switch f.config.Source.Mode {
|
||||
case StreamModeUDP:
|
||||
return NewRTPUDPReader(f.source)
|
||||
case StreamModeTCPActive, StreamModeTCPPassive:
|
||||
return NewRTPTCPReader(f.source)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// createRTPWriter 创建RTP写入器
|
||||
func (f *Forwarder) createRTPWriter() RTPWriter {
|
||||
return NewRTPWriter(f.target, f.config.Target.Mode)
|
||||
}
|
||||
|
||||
// RTPWriter RTP写入器接口
|
||||
type RTPWriter interface {
|
||||
WritePacket(packet *rtp.Packet) error
|
||||
WriteRaw(data []byte) error
|
||||
}
|
||||
|
||||
// rtpWriter RTP写入器实现
|
||||
type rtpWriter struct {
|
||||
writer io.Writer
|
||||
mode StreamMode
|
||||
header []byte
|
||||
sendBuffer util.Buffer // 可复用的发送缓冲区
|
||||
}
|
||||
|
||||
// NewRTPWriter 创建RTP写入器
|
||||
func NewRTPWriter(writer io.Writer, mode StreamMode) RTPWriter {
|
||||
return &rtpWriter{
|
||||
writer: writer,
|
||||
mode: mode,
|
||||
header: make([]byte, 2),
|
||||
sendBuffer: util.Buffer{}, // 初始化可复用缓冲区
|
||||
}
|
||||
}
|
||||
|
||||
// WritePacket 写入RTP包
|
||||
func (w *rtpWriter) WritePacket(packet *rtp.Packet) error {
|
||||
// 复用sendBuffer,避免重复创建
|
||||
w.sendBuffer.Reset()
|
||||
w.sendBuffer.Malloc(packet.MarshalSize())
|
||||
_, err := packet.MarshalTo(w.sendBuffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal RTP packet failed: %v", err)
|
||||
}
|
||||
|
||||
return w.WriteRaw(w.sendBuffer)
|
||||
}
|
||||
|
||||
// WriteRaw 写入原始数据
|
||||
func (w *rtpWriter) WriteRaw(data []byte) error {
|
||||
if w.mode == StreamModeUDP {
|
||||
_, err := w.writer.Write(data)
|
||||
return err
|
||||
} else {
|
||||
// TCP模式需要添加长度头
|
||||
binary.BigEndian.PutUint16(w.header, uint16(len(data)))
|
||||
_, err := w.writer.Write(w.header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.writer.Write(data)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// RelayProcessor 中继处理器
|
||||
type RelayProcessor struct {
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
sourceMode StreamMode
|
||||
targetMode StreamMode
|
||||
buffer []byte // 可复用的缓冲区
|
||||
header []byte // 可复用的头部缓冲区
|
||||
}
|
||||
|
||||
// NewRelayProcessor 创建中继处理器
|
||||
func NewRelayProcessor(reader io.Reader, writer io.Writer, sourceMode, targetMode StreamMode) *RelayProcessor {
|
||||
return &RelayProcessor{
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
sourceMode: sourceMode,
|
||||
targetMode: targetMode,
|
||||
buffer: make([]byte, 1460), // 初始化可复用缓冲区
|
||||
header: make([]byte, 2), // 初始化可复用头部缓冲区
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理中继
|
||||
func (p *RelayProcessor) Process(ctx context.Context) error {
|
||||
if p.sourceMode == p.targetMode {
|
||||
// 相同模式直接复制
|
||||
_, err := io.Copy(p.writer, p.reader)
|
||||
return err
|
||||
}
|
||||
|
||||
// 不同模式需要转换
|
||||
if p.sourceMode == StreamModeUDP && (p.targetMode == StreamModeTCPActive || p.targetMode == StreamModeTCPPassive) {
|
||||
// UDP to TCP
|
||||
return p.processUDPToTCP(ctx)
|
||||
} else if (p.sourceMode == StreamModeTCPActive || p.sourceMode == StreamModeTCPPassive) && p.targetMode == StreamModeUDP {
|
||||
// TCP to UDP
|
||||
return p.processTCPToUDP(ctx)
|
||||
}
|
||||
|
||||
return fmt.Errorf("unsupported mode combination")
|
||||
}
|
||||
|
||||
// processUDPToTCP UDP转TCP
|
||||
func (p *RelayProcessor) processUDPToTCP(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
n, err := p.reader.Read(p.buffer)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 添加2字节长度头
|
||||
binary.BigEndian.PutUint16(p.header, uint16(n))
|
||||
_, err = p.writer.Write(p.header)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = p.writer.Write(p.buffer[:n])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processTCPToUDP TCP转UDP
|
||||
func (p *RelayProcessor) processTCPToUDP(ctx context.Context) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 读取2字节长度头
|
||||
_, err := io.ReadFull(p.reader, p.header)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取包长度
|
||||
packetLength := binary.BigEndian.Uint16(p.header)
|
||||
|
||||
// 如果包长度超过缓冲区大小,需要动态分配
|
||||
if packetLength > uint16(len(p.buffer)) {
|
||||
packetData := make([]byte, packetLength)
|
||||
_, err = io.ReadFull(p.reader, packetData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = p.writer.Write(packetData)
|
||||
} else {
|
||||
// 使用可复用缓冲区
|
||||
_, err = io.ReadFull(p.reader, p.buffer[:packetLength])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = p.writer.Write(p.buffer[:packetLength])
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RTPProcessor RTP处理器
|
||||
type RTPProcessor struct {
|
||||
reader IRTPReader
|
||||
writer RTPWriter
|
||||
config *ForwardConfig
|
||||
sendBuffer util.Buffer // 可复用的发送缓冲区
|
||||
}
|
||||
|
||||
// NewRTPProcessor 创建RTP处理器
|
||||
func NewRTPProcessor(reader IRTPReader, writer RTPWriter, config *ForwardConfig) *RTPProcessor {
|
||||
return &RTPProcessor{
|
||||
reader: reader,
|
||||
writer: writer,
|
||||
config: config,
|
||||
sendBuffer: util.Buffer{}, // 初始化可复用缓冲区
|
||||
}
|
||||
}
|
||||
|
||||
// Process 处理RTP包
|
||||
func (p *RTPProcessor) Process(ctx context.Context) error {
|
||||
var packet rtp.Packet
|
||||
var sequenceNumber uint16
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
err := p.reader.Read(&packet)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("read RTP packet failed: %v", err)
|
||||
}
|
||||
|
||||
// 检查源SSRC过滤
|
||||
if p.config.Source.SSRC != 0 && packet.SSRC != p.config.Source.SSRC {
|
||||
continue
|
||||
}
|
||||
|
||||
// 保存原始序列号用于分片包
|
||||
sequenceNumber = packet.SequenceNumber
|
||||
|
||||
// 检查是否需要分片
|
||||
if len(packet.Payload) > (1460 - packet.MarshalSize()) {
|
||||
err = p.processFragmentedPacket(&packet, sequenceNumber)
|
||||
} else {
|
||||
err = p.processSinglePacket(&packet)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processSinglePacket 处理单个包
|
||||
func (p *RTPProcessor) processSinglePacket(packet *rtp.Packet) error {
|
||||
if p.config.Target.SSRC != 0 {
|
||||
packet.SSRC = p.config.Target.SSRC
|
||||
}
|
||||
|
||||
return p.writer.WritePacket(packet)
|
||||
}
|
||||
|
||||
// processFragmentedPacket 处理分片包
|
||||
func (p *RTPProcessor) processFragmentedPacket(packet *rtp.Packet, sequenceNumber uint16) error {
|
||||
maxPayloadSize := 1460 - 12 // RTP头通常是12字节
|
||||
payload := packet.Payload
|
||||
|
||||
// 标记第一个包
|
||||
marker := packet.Marker
|
||||
packet.Marker = false
|
||||
|
||||
for i := 0; i < len(payload); i += int(maxPayloadSize) {
|
||||
end := i + int(maxPayloadSize)
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
// 最后一个分片,恢复原始标记
|
||||
packet.Marker = marker
|
||||
}
|
||||
|
||||
// 创建包含分片的新包
|
||||
fragmentPacket := *packet
|
||||
if p.config.Target.SSRC != 0 {
|
||||
fragmentPacket.SSRC = p.config.Target.SSRC
|
||||
}
|
||||
fragmentPacket.SequenceNumber = sequenceNumber
|
||||
sequenceNumber++
|
||||
fragmentPacket.Payload = payload[i:end]
|
||||
|
||||
err := p.writer.WritePacket(&fragmentPacket)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write RTP fragment failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Forward 执行转发
|
||||
func (f *Forwarder) Forward(ctx context.Context) error {
|
||||
// 建立连接
|
||||
err := f.setupConnections()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.cleanup()
|
||||
|
||||
// 检查是否为中继模式
|
||||
if f.config.Relay {
|
||||
processor := NewRelayProcessor(f.source, f.target, f.config.Source.Mode, f.config.Target.Mode)
|
||||
return processor.Process(ctx)
|
||||
}
|
||||
|
||||
// RTP处理模式
|
||||
reader := f.createRTPReader()
|
||||
writer := f.createRTPWriter()
|
||||
processor := NewRTPProcessor(reader, writer, f.config)
|
||||
|
||||
return processor.Process(ctx)
|
||||
}
|
||||
322
plugin/rtp/pkg/forward_test.go
Normal file
322
plugin/rtp/pkg/forward_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
)
|
||||
|
||||
func TestForwardConfig(t *testing.T) {
|
||||
config := &ForwardConfig{
|
||||
Source: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Mode: StreamModeUDP,
|
||||
SSRC: 12345,
|
||||
},
|
||||
Target: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8081,
|
||||
Mode: StreamModeTCPActive,
|
||||
SSRC: 67890,
|
||||
},
|
||||
Relay: false,
|
||||
}
|
||||
|
||||
if config.Source.IP != "127.0.0.1" {
|
||||
t.Errorf("Expected source IP 127.0.0.1, got %s", config.Source.IP)
|
||||
}
|
||||
|
||||
if config.Source.Port != 8080 {
|
||||
t.Errorf("Expected source port 8080, got %d", config.Source.Port)
|
||||
}
|
||||
|
||||
if config.Source.Mode != StreamModeUDP {
|
||||
t.Errorf("Expected source mode UDP, got %s", config.Source.Mode)
|
||||
}
|
||||
|
||||
if config.Source.SSRC != 12345 {
|
||||
t.Errorf("Expected source SSRC 12345, got %d", config.Source.SSRC)
|
||||
}
|
||||
|
||||
if config.Target.IP != "127.0.0.1" {
|
||||
t.Errorf("Expected target IP 127.0.0.1, got %s", config.Target.IP)
|
||||
}
|
||||
|
||||
if config.Target.Port != 8081 {
|
||||
t.Errorf("Expected target port 8081, got %d", config.Target.Port)
|
||||
}
|
||||
|
||||
if config.Target.Mode != StreamModeTCPActive {
|
||||
t.Errorf("Expected target mode TCP-ACTIVE, got %s", config.Target.Mode)
|
||||
}
|
||||
|
||||
if config.Target.SSRC != 67890 {
|
||||
t.Errorf("Expected target SSRC 67890, got %d", config.Target.SSRC)
|
||||
}
|
||||
|
||||
if config.Relay {
|
||||
t.Error("Expected relay to be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewForwarder(t *testing.T) {
|
||||
config := &ForwardConfig{
|
||||
Source: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Mode: StreamModeUDP,
|
||||
SSRC: 12345,
|
||||
},
|
||||
Target: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8081,
|
||||
Mode: StreamModeTCPActive,
|
||||
SSRC: 67890,
|
||||
},
|
||||
Relay: false,
|
||||
}
|
||||
|
||||
forwarder := NewForwarder(config)
|
||||
|
||||
if forwarder.config != config {
|
||||
t.Error("Expected forwarder config to match input config")
|
||||
}
|
||||
|
||||
if forwarder.source != nil {
|
||||
t.Error("Expected source connection to be nil initially")
|
||||
}
|
||||
|
||||
if forwarder.target != nil {
|
||||
t.Error("Expected target connection to be nil initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionConfig(t *testing.T) {
|
||||
config := ConnectionConfig{
|
||||
IP: "192.168.1.100",
|
||||
Port: 9000,
|
||||
Mode: StreamModeTCPPassive,
|
||||
SSRC: 54321,
|
||||
}
|
||||
|
||||
if config.IP != "192.168.1.100" {
|
||||
t.Errorf("Expected IP 192.168.1.100, got %s", config.IP)
|
||||
}
|
||||
|
||||
if config.Port != 9000 {
|
||||
t.Errorf("Expected port 9000, got %d", config.Port)
|
||||
}
|
||||
|
||||
if config.Mode != StreamModeTCPPassive {
|
||||
t.Errorf("Expected mode TCP-PASSIVE, got %s", config.Mode)
|
||||
}
|
||||
|
||||
if config.SSRC != 54321 {
|
||||
t.Errorf("Expected SSRC 54321, got %d", config.SSRC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRTPWriter(t *testing.T) {
|
||||
// 创建一个模拟的writer
|
||||
mockWriter := &mockWriter{}
|
||||
writer := NewRTPWriter(mockWriter, StreamModeUDP)
|
||||
|
||||
if writer == nil {
|
||||
t.Error("Expected RTPWriter to be created")
|
||||
}
|
||||
|
||||
// 测试UDP模式的WriteRaw
|
||||
data := []byte{1, 2, 3, 4}
|
||||
err := writer.WriteRaw(data)
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(mockWriter.data) != 1 {
|
||||
t.Errorf("Expected 1 write, got %d", len(mockWriter.data))
|
||||
}
|
||||
|
||||
if len(mockWriter.data[0]) != 4 {
|
||||
t.Errorf("Expected 4 bytes written, got %d", len(mockWriter.data[0]))
|
||||
}
|
||||
}
|
||||
|
||||
// mockWriter 用于测试的模拟writer
|
||||
type mockWriter struct {
|
||||
data [][]byte
|
||||
}
|
||||
|
||||
func (w *mockWriter) Write(data []byte) (int, error) {
|
||||
w.data = append(w.data, append([]byte{}, data...))
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func TestRelayProcessor(t *testing.T) {
|
||||
// 创建模拟的reader和writer
|
||||
mockReader := &mockReader{data: [][]byte{{1, 2, 3}, {4, 5, 6}}}
|
||||
mockWriter := &mockWriter{}
|
||||
|
||||
processor := NewRelayProcessor(mockReader, mockWriter, StreamModeUDP, StreamModeTCPActive)
|
||||
|
||||
if processor.reader != mockReader {
|
||||
t.Error("Expected reader to match input")
|
||||
}
|
||||
|
||||
if processor.writer != mockWriter {
|
||||
t.Error("Expected writer to match input")
|
||||
}
|
||||
|
||||
if processor.sourceMode != StreamModeUDP {
|
||||
t.Errorf("Expected source mode UDP, got %s", processor.sourceMode)
|
||||
}
|
||||
|
||||
if processor.targetMode != StreamModeTCPActive {
|
||||
t.Errorf("Expected target mode TCP-ACTIVE, got %s", processor.targetMode)
|
||||
}
|
||||
}
|
||||
|
||||
// mockReader 用于测试的模拟reader
|
||||
type mockReader struct {
|
||||
data [][]byte
|
||||
pos int
|
||||
}
|
||||
|
||||
func (r *mockReader) Read(buf []byte) (int, error) {
|
||||
if r.pos >= len(r.data) {
|
||||
return 0, nil // EOF
|
||||
}
|
||||
|
||||
data := r.data[r.pos]
|
||||
r.pos++
|
||||
|
||||
copy(buf, data)
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func TestConnectionTypes(t *testing.T) {
|
||||
// 测试ConnectionConfig
|
||||
config := ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Mode: StreamModeUDP,
|
||||
SSRC: 12345,
|
||||
}
|
||||
|
||||
if config.Mode != StreamModeUDP {
|
||||
t.Errorf("Expected mode UDP, got %s", config.Mode)
|
||||
}
|
||||
|
||||
if config.SSRC != 12345 {
|
||||
t.Errorf("Expected SSRC 12345, got %d", config.SSRC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionDirection(t *testing.T) {
|
||||
// 测试连接方向的概念
|
||||
config := &ForwardConfig{
|
||||
Source: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Mode: StreamModeUDP,
|
||||
SSRC: 12345,
|
||||
},
|
||||
Target: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8081,
|
||||
Mode: StreamModeTCPActive,
|
||||
SSRC: 67890,
|
||||
},
|
||||
Relay: false,
|
||||
}
|
||||
|
||||
forwarder := NewForwarder(config)
|
||||
|
||||
// 验证配置正确性
|
||||
if forwarder.config.Source.SSRC != 12345 {
|
||||
t.Errorf("Expected source SSRC 12345, got %d", forwarder.config.Source.SSRC)
|
||||
}
|
||||
|
||||
if forwarder.config.Target.SSRC != 67890 {
|
||||
t.Errorf("Expected target SSRC 67890, got %d", forwarder.config.Target.SSRC)
|
||||
}
|
||||
|
||||
// 验证连接类型
|
||||
if forwarder.source != nil {
|
||||
t.Error("Expected source connection to be nil initially")
|
||||
}
|
||||
|
||||
if forwarder.target != nil {
|
||||
t.Error("Expected target connection to be nil initially")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBufferReuse(t *testing.T) {
|
||||
// 测试RTPWriter的buffer复用
|
||||
writer := NewRTPWriter(&mockWriter{}, StreamModeUDP)
|
||||
|
||||
// 多次写入应该复用同一个buffer
|
||||
for i := 0; i < 10; i++ {
|
||||
packet := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: uint16(i),
|
||||
Timestamp: uint32(i * 1000),
|
||||
SSRC: uint32(i),
|
||||
},
|
||||
Payload: []byte(fmt.Sprintf("test packet %d", i)),
|
||||
}
|
||||
|
||||
err := writer.WritePacket(packet)
|
||||
if err != nil {
|
||||
t.Errorf("WritePacket failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试RelayProcessor的buffer复用
|
||||
processor := NewRelayProcessor(&mockReader{data: [][]byte{{1, 2, 3}, {4, 5, 6}}}, &mockWriter{}, StreamModeUDP, StreamModeTCPActive)
|
||||
|
||||
// 验证buffer字段存在
|
||||
if processor.buffer == nil {
|
||||
t.Error("Expected buffer to be initialized")
|
||||
}
|
||||
|
||||
if len(processor.buffer) != 1460 {
|
||||
t.Errorf("Expected buffer size 1460, got %d", len(processor.buffer))
|
||||
}
|
||||
|
||||
if processor.header == nil {
|
||||
t.Error("Expected header to be initialized")
|
||||
}
|
||||
|
||||
if len(processor.header) != 2 {
|
||||
t.Errorf("Expected header size 2, got %d", len(processor.header))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRTPProcessorBufferReuse(t *testing.T) {
|
||||
// 测试RTPProcessor的buffer复用
|
||||
config := &ForwardConfig{
|
||||
Source: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Mode: StreamModeUDP,
|
||||
SSRC: 12345,
|
||||
},
|
||||
Target: ConnectionConfig{
|
||||
IP: "127.0.0.1",
|
||||
Port: 8081,
|
||||
Mode: StreamModeTCPActive,
|
||||
SSRC: 67890,
|
||||
},
|
||||
Relay: false,
|
||||
}
|
||||
|
||||
processor := NewRTPProcessor(nil, nil, config)
|
||||
|
||||
// 验证sendBuffer字段存在
|
||||
if processor.sendBuffer == nil {
|
||||
t.Error("Expected sendBuffer to be initialized")
|
||||
}
|
||||
}
|
||||
48
plugin/rtp/pkg/puller-dump.go
Normal file
48
plugin/rtp/pkg/puller-dump.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
m7s "m7s.live/v5"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
type DumpPuller struct {
|
||||
m7s.HTTPFilePuller
|
||||
}
|
||||
|
||||
func (p *DumpPuller) Start() (err error) {
|
||||
p.PullJob.PublishConfig.PubType = m7s.PublishTypeReplay
|
||||
return p.HTTPFilePuller.Start()
|
||||
}
|
||||
|
||||
func (p *DumpPuller) Run() (err error) {
|
||||
pub := p.PullJob.Publisher
|
||||
var receiver PSReceiver
|
||||
receiver.Publisher = pub
|
||||
receiver.StreamMode = StreamModeManual
|
||||
receiver.OnStart(func() {
|
||||
go func() {
|
||||
var t uint16
|
||||
for l := make([]byte, 6); pub.State != m7s.PublisherStateDisposed; time.Sleep(time.Millisecond * time.Duration(t)) {
|
||||
_, err = p.Read(l)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
payloadLen := util.ReadBE[int](l[:4])
|
||||
payload := make([]byte, payloadLen)
|
||||
t = util.ReadBE[uint16](l[4:])
|
||||
_, err = p.Read(payload)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case receiver.RTPMouth <- payload:
|
||||
case <-pub.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
})
|
||||
return p.RunTask(&receiver)
|
||||
}
|
||||
140
plugin/rtp/pkg/reader.go
Normal file
140
plugin/rtp/pkg/reader.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
type IRTPReader interface {
|
||||
Read(packet *rtp.Packet) (err error)
|
||||
}
|
||||
|
||||
type RTPUDPReader struct {
|
||||
io.Reader
|
||||
buf [MTUSize]byte
|
||||
}
|
||||
|
||||
func NewRTPUDPReader(r io.Reader) *RTPUDPReader {
|
||||
return &RTPUDPReader{Reader: r}
|
||||
}
|
||||
|
||||
func (r *RTPUDPReader) Read(packet *rtp.Packet) (err error) {
|
||||
n, err := r.Reader.Read(r.buf[:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return packet.Unmarshal(r.buf[:n])
|
||||
}
|
||||
|
||||
type RTPTCPReader struct {
|
||||
*util.BufReader
|
||||
buffer util.Buffer
|
||||
}
|
||||
|
||||
func NewRTPTCPReader(r io.Reader) *RTPTCPReader {
|
||||
return &RTPTCPReader{BufReader: util.NewBufReader(r)}
|
||||
}
|
||||
|
||||
func (r *RTPTCPReader) Read(packet *rtp.Packet) (err error) {
|
||||
var rtplen uint32
|
||||
var b0, b1 byte
|
||||
rtplen, err = r.ReadBE32(2)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var mem util.Memory
|
||||
mem, err = r.ReadBytes(int(rtplen))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mr := mem.NewReader()
|
||||
mr.ReadByteTo(&b0, &b1)
|
||||
if b0>>6 != 2 || b0&0x0f > 15 || b1&0x7f > 127 {
|
||||
// TODO:
|
||||
panic(fmt.Errorf("invalid rtp packet: %x", r.buffer[:2]))
|
||||
} else {
|
||||
r.buffer.Relloc(int(rtplen))
|
||||
mem.CopyTo(r.buffer)
|
||||
err = packet.Unmarshal(r.buffer)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type RTPPayloadReader struct {
|
||||
IRTPReader
|
||||
rtp.Packet
|
||||
SSRC uint32 // RTP SSRC
|
||||
buffer util.MemoryReader
|
||||
}
|
||||
|
||||
// func NewTCPRTPPayloadReaderForFeed() *RTPPayloadReader {
|
||||
// r := &RTPPayloadReader{}
|
||||
// r.IRTPReader = &RTPTCPReader{
|
||||
// BufReader: util.NewBufReaderChan(10),
|
||||
// }
|
||||
// r.buffer.Memory = &util.Memory{}
|
||||
// return r
|
||||
// }
|
||||
|
||||
func NewRTPPayloadReader(t IRTPReader) *RTPPayloadReader {
|
||||
r := &RTPPayloadReader{}
|
||||
r.IRTPReader = t
|
||||
r.buffer.Memory = &util.Memory{}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *RTPPayloadReader) Read(buf []byte) (n int, err error) {
|
||||
// 如果缓冲区中有数据,先读取缓冲区中的数据
|
||||
if r.buffer.Length > 0 {
|
||||
n, _ = r.buffer.Read(buf)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// 读取新的RTP包
|
||||
for {
|
||||
lastSeq := r.SequenceNumber
|
||||
err = r.IRTPReader.Read(&r.Packet)
|
||||
if err != nil {
|
||||
err = errors.Join(err, fmt.Errorf("failed to read RTP packet"))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查SSRC是否匹配
|
||||
if r.SSRC != 0 && r.SSRC != r.Packet.SSRC {
|
||||
// SSRC不匹配,继续读取下一个包
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查序列号是否连续
|
||||
if lastSeq == 0 || r.SequenceNumber == lastSeq+1 {
|
||||
// 序列号连续,处理当前包的数据
|
||||
if lbuf, lpayload := len(buf), len(r.Payload); lbuf >= lpayload {
|
||||
// 缓冲区足够大,可以容纳整个负载
|
||||
copy(buf, r.Payload)
|
||||
n += lpayload
|
||||
|
||||
// 如果缓冲区还有剩余空间,继续读取下一个包
|
||||
if lbuf > lpayload {
|
||||
var nextn int
|
||||
nextn, err = r.Read(buf[lpayload:])
|
||||
if err != nil && err != io.EOF {
|
||||
return n, err
|
||||
}
|
||||
n += nextn
|
||||
}
|
||||
return
|
||||
} else {
|
||||
// 缓冲区不够大,只复制部分数据,将剩余数据放入缓冲区
|
||||
n += lbuf
|
||||
copy(buf, r.Payload[:lbuf])
|
||||
r.buffer.PushOne(r.Payload[lbuf:])
|
||||
r.buffer.Length = lpayload - lbuf
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
113
plugin/rtp/pkg/reader_debug_test.go
Normal file
113
plugin/rtp/pkg/reader_debug_test.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRTPPayloadReaderDebug(t *testing.T) {
|
||||
// 创建简单的测试数据
|
||||
originalData := []byte("Hello World")
|
||||
|
||||
// 生成RTP包
|
||||
packets := generateRTPPacketsForDebug(originalData, 0, 1000)
|
||||
|
||||
fmt.Printf("Generated %d RTP packets\n", len(packets))
|
||||
for i, packet := range packets {
|
||||
fmt.Printf("Packet %d: Seq=%d, Payload=%s, PayloadLen=%d\n", i, packet.SequenceNumber, string(packet.Payload), len(packet.Payload))
|
||||
}
|
||||
|
||||
// 将RTP包序列化到缓冲区
|
||||
var buf bytes.Buffer
|
||||
for _, packet := range packets {
|
||||
data, err := packet.Marshal()
|
||||
assert.NoError(t, err)
|
||||
fmt.Printf("Marshaled packet length: %d\n", len(data))
|
||||
|
||||
// 写入RTP包长度和数据
|
||||
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
|
||||
buf.Write(data)
|
||||
}
|
||||
|
||||
fmt.Printf("Buffer size: %d\n", buf.Len())
|
||||
fmt.Printf("Original data length: %d\n", len(originalData))
|
||||
fmt.Printf("Original data: %s\n", string(originalData))
|
||||
|
||||
// 使用RTPPayloadReader读取数据
|
||||
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
|
||||
|
||||
// 逐步读取数据
|
||||
allData := make([]byte, 0)
|
||||
bufSize := 3
|
||||
|
||||
for len(allData) < len(originalData) {
|
||||
result := make([]byte, bufSize)
|
||||
fmt.Printf("Buffer length before read: %d\n", reader.buffer.Length)
|
||||
fmt.Printf("Buffer count before read: %d\n", reader.buffer.Count())
|
||||
n, err := reader.Read(result)
|
||||
fmt.Printf("Read returned: n=%d, err=%v\n", n, err)
|
||||
fmt.Printf("Read data: %s\n", string(result[:n]))
|
||||
fmt.Printf("Buffer length after read: %d\n", reader.buffer.Length)
|
||||
fmt.Printf("Buffer count after read: %d\n", reader.buffer.Count())
|
||||
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
allData = append(allData, result[:n]...)
|
||||
fmt.Printf("All data so far: %s\n", string(allData))
|
||||
}
|
||||
|
||||
fmt.Printf("Final data length: %d\n", len(allData))
|
||||
fmt.Printf("Final data: %s\n", string(allData))
|
||||
|
||||
// 验证数据是否匹配
|
||||
assert.Equal(t, originalData, allData)
|
||||
}
|
||||
|
||||
func generateRTPPacketsForDebug(data []byte, ssrc uint32, initialSeq uint16) []*rtp.Packet {
|
||||
var packets []*rtp.Packet
|
||||
seq := initialSeq
|
||||
maxPayloadSize := 100
|
||||
|
||||
for len(data) > 0 {
|
||||
// 确定当前包的负载大小
|
||||
payloadSize := maxPayloadSize
|
||||
if len(data) < payloadSize {
|
||||
payloadSize = len(data)
|
||||
}
|
||||
|
||||
// 创建RTP包
|
||||
packet := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Padding: false,
|
||||
Extension: false,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: seq,
|
||||
Timestamp: 123456,
|
||||
SSRC: ssrc,
|
||||
},
|
||||
Payload: data[:payloadSize],
|
||||
}
|
||||
|
||||
packets = append(packets, packet)
|
||||
|
||||
// 更新数据和序列号
|
||||
data = data[payloadSize:]
|
||||
seq++
|
||||
}
|
||||
|
||||
return packets
|
||||
}
|
||||
153
plugin/rtp/pkg/reader_test.go
Normal file
153
plugin/rtp/pkg/reader_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRTPPayloadReader(t *testing.T) {
|
||||
// 创建测试数据
|
||||
originalData := []byte("Hello, World! This is a test payload for RTP packets.")
|
||||
|
||||
// 生成RTP包
|
||||
packets := generateRTPPackets(originalData, 0, 1000)
|
||||
|
||||
// 将RTP包序列化到缓冲区
|
||||
var buf bytes.Buffer
|
||||
for _, packet := range packets {
|
||||
data, err := packet.Marshal()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 写入RTP包长度和数据
|
||||
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
|
||||
buf.Write(data)
|
||||
}
|
||||
|
||||
// 使用RTPPayloadReader读取数据
|
||||
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
|
||||
|
||||
// 读取所有数据
|
||||
result := make([]byte, len(originalData))
|
||||
n, err := reader.Read(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
// 验证数据是否匹配
|
||||
assert.Equal(t, originalData, result)
|
||||
}
|
||||
|
||||
func TestRTPPayloadReaderWithBuffer(t *testing.T) {
|
||||
// 创建测试数据
|
||||
originalData := []byte("This is a longer test payload that will be split across multiple RTP packets to test the buffering functionality of the RTPPayloadReader.")
|
||||
|
||||
// 生成RTP包
|
||||
packets := generateRTPPackets(originalData, 0, 2000)
|
||||
|
||||
// 将RTP包序列化到缓冲区
|
||||
var buf bytes.Buffer
|
||||
for _, packet := range packets {
|
||||
data, err := packet.Marshal()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 写入RTP包长度和数据
|
||||
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
|
||||
buf.Write(data)
|
||||
}
|
||||
|
||||
// 使用RTPPayloadReader读取数据
|
||||
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
|
||||
|
||||
// 使用较小的缓冲区读取数据
|
||||
allData := make([]byte, 0)
|
||||
bufSize := 10
|
||||
|
||||
for len(allData) < len(originalData) {
|
||||
result := make([]byte, bufSize)
|
||||
n, err := reader.Read(result)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
if n == 0 {
|
||||
break
|
||||
}
|
||||
allData = append(allData, result[:n]...)
|
||||
}
|
||||
|
||||
// 验证数据是否匹配
|
||||
assert.Equal(t, originalData, allData)
|
||||
}
|
||||
|
||||
func TestRTPPayloadReaderSimple(t *testing.T) {
|
||||
// 创建简单的测试数据
|
||||
originalData := []byte("Hello World")
|
||||
|
||||
// 生成RTP包
|
||||
packets := generateRTPPackets(originalData, 0, 1000)
|
||||
|
||||
// 将RTP包序列化到缓冲区
|
||||
var buf bytes.Buffer
|
||||
for _, packet := range packets {
|
||||
data, err := packet.Marshal()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 写入RTP包长度和数据
|
||||
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
|
||||
buf.Write(data)
|
||||
}
|
||||
|
||||
// 使用RTPPayloadReader读取数据
|
||||
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
|
||||
|
||||
// 读取所有数据
|
||||
result := make([]byte, len(originalData))
|
||||
n, err := reader.Read(result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
// 验证数据是否匹配
|
||||
assert.Equal(t, originalData, result)
|
||||
}
|
||||
|
||||
func generateRTPPackets(data []byte, ssrc uint32, initialSeq uint16) []*rtp.Packet {
|
||||
var packets []*rtp.Packet
|
||||
seq := initialSeq
|
||||
maxPayloadSize := 100
|
||||
|
||||
for len(data) > 0 {
|
||||
// 确定当前包的负载大小
|
||||
payloadSize := maxPayloadSize
|
||||
if len(data) < payloadSize {
|
||||
payloadSize = len(data)
|
||||
}
|
||||
|
||||
// 创建RTP包
|
||||
packet := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
Padding: false,
|
||||
Extension: false,
|
||||
Marker: false,
|
||||
PayloadType: 96,
|
||||
SequenceNumber: seq,
|
||||
Timestamp: 123456,
|
||||
SSRC: ssrc,
|
||||
},
|
||||
Payload: data[:payloadSize],
|
||||
}
|
||||
|
||||
packets = append(packets, packet)
|
||||
|
||||
// 更新数据和序列号
|
||||
data = data[payloadSize:]
|
||||
seq++
|
||||
}
|
||||
|
||||
return packets
|
||||
}
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"m7s.live/v5/pkg/util"
|
||||
"net"
|
||||
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
type TCP net.TCPConn
|
||||
|
||||
func (t *TCP) Read(onRTP func(util.Buffer) error) (err error) {
|
||||
func (t *TCP) ReadRTP(onRTP func(util.Buffer) error) (err error) {
|
||||
reader := bufio.NewReader((*net.TCPConn)(t))
|
||||
rtpLenBuf := make([]byte, 4)
|
||||
buffer := make(util.Buffer, 1024)
|
||||
|
||||
148
plugin/rtp/pkg/transceiver.go
Normal file
148
plugin/rtp/pkg/transceiver.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/pion/rtp"
|
||||
mpegps "m7s.live/v5/pkg/format/ps"
|
||||
"m7s.live/v5/pkg/task"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
var ErrRTPReceiveLost = errors.New("rtp receive lost")
|
||||
|
||||
// 数据流传输模式(UDP:udp传输、TCP-ACTIVE:tcp主动模式、TCP-PASSIVE:tcp被动模式、MANUAL:手动模式)
|
||||
type StreamMode string
|
||||
|
||||
const (
|
||||
StreamModeUDP StreamMode = "UDP"
|
||||
StreamModeTCPActive StreamMode = "TCP-ACTIVE"
|
||||
StreamModeTCPPassive StreamMode = "TCP-PASSIVE"
|
||||
StreamModeManual StreamMode = "MANUAL"
|
||||
)
|
||||
|
||||
type ChanReader chan []byte
|
||||
|
||||
func (r ChanReader) Read(buf []byte) (n int, err error) {
|
||||
b, ok := <-r
|
||||
if !ok {
|
||||
return 0, io.EOF
|
||||
}
|
||||
copy(buf, b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
type RTPChanReader chan []byte
|
||||
|
||||
func (r RTPChanReader) Read(packet *rtp.Packet) (err error) {
|
||||
b, ok := <-r
|
||||
if !ok {
|
||||
return io.EOF
|
||||
}
|
||||
return packet.Unmarshal(b)
|
||||
}
|
||||
|
||||
func (r RTPChanReader) Close() error {
|
||||
close(r)
|
||||
return nil
|
||||
}
|
||||
|
||||
type Receiver struct {
|
||||
task.Task
|
||||
*util.BufReader
|
||||
ListenAddr string
|
||||
net.Listener
|
||||
StreamMode StreamMode
|
||||
SSRC uint32 // RTP SSRC
|
||||
RTPMouth chan []byte
|
||||
}
|
||||
|
||||
type PSReceiver struct {
|
||||
Receiver
|
||||
mpegps.MpegPsDemuxer
|
||||
}
|
||||
|
||||
func (p *PSReceiver) Start() error {
|
||||
err := p.Receiver.Start()
|
||||
if err == nil {
|
||||
p.Using(p.Publisher)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *PSReceiver) Run() error {
|
||||
p.MpegPsDemuxer.Allocator = util.NewScalableMemoryAllocator(1 << util.MinPowerOf2)
|
||||
p.Using(p.MpegPsDemuxer.Allocator)
|
||||
return p.MpegPsDemuxer.Feed(p.BufReader)
|
||||
}
|
||||
|
||||
func (p *Receiver) Start() (err error) {
|
||||
var rtpReader *RTPPayloadReader
|
||||
switch p.StreamMode {
|
||||
case StreamModeTCPActive:
|
||||
// TCP主动模式不需要监听,直接返回
|
||||
p.Info("TCP-ACTIVE mode, no need to listen")
|
||||
addr := p.ListenAddr
|
||||
if !strings.Contains(addr, ":") {
|
||||
addr = ":" + addr
|
||||
}
|
||||
if strings.HasPrefix(addr, ":") {
|
||||
p.Error("invalid address, missing IP", "addr", addr)
|
||||
return fmt.Errorf("invalid address %s, missing IP", addr)
|
||||
}
|
||||
p.Info("TCP-ACTIVE mode, connecting to device", "addr", addr)
|
||||
var conn net.Conn
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
p.Error("connect to device failed", "err", err)
|
||||
return err
|
||||
}
|
||||
p.OnStop(conn.Close)
|
||||
rtpReader = NewRTPPayloadReader(NewRTPTCPReader(conn))
|
||||
p.BufReader = util.NewBufReader(rtpReader)
|
||||
case StreamModeTCPPassive:
|
||||
var conn net.Conn
|
||||
if p.SSRC == 0 {
|
||||
p.Info("start new listener", "addr", p.ListenAddr)
|
||||
p.Listener, err = net.Listen("tcp4", p.ListenAddr)
|
||||
if err != nil {
|
||||
p.Error("start listen", "err", err)
|
||||
return errors.New("start listen,err" + err.Error())
|
||||
}
|
||||
p.OnStop(p.Listener.Close)
|
||||
conn, err = p.Accept()
|
||||
} else {
|
||||
//TODO: 公用监听端口
|
||||
}
|
||||
if err != nil {
|
||||
p.Error("accept", "err", err)
|
||||
return err
|
||||
}
|
||||
p.OnStop(conn.Close)
|
||||
rtpReader = NewRTPPayloadReader(NewRTPTCPReader(conn))
|
||||
p.BufReader = util.NewBufReader(rtpReader)
|
||||
case StreamModeUDP:
|
||||
var udpAddr *net.UDPAddr
|
||||
udpAddr, err = net.ResolveUDPAddr("udp4", p.ListenAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var conn net.Conn
|
||||
conn, err = net.ListenUDP("udp4", udpAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rtpReader = NewRTPPayloadReader(NewRTPUDPReader(conn))
|
||||
p.BufReader = util.NewBufReader(rtpReader)
|
||||
case StreamModeManual:
|
||||
p.RTPMouth = make(chan []byte)
|
||||
rtpReader = NewRTPPayloadReader((RTPChanReader)(p.RTPMouth))
|
||||
p.BufReader = util.NewBufReader(rtpReader)
|
||||
}
|
||||
p.Using(rtpReader, p.BufReader)
|
||||
return
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/deepch/vdk/codec/h264parser"
|
||||
"github.com/deepch/vdk/codec/h265parser"
|
||||
@@ -26,29 +27,27 @@ type (
|
||||
}
|
||||
H264Ctx struct {
|
||||
H26xCtx
|
||||
codec.H264Ctx
|
||||
*codec.H264Ctx
|
||||
}
|
||||
H265Ctx struct {
|
||||
H26xCtx
|
||||
codec.H265Ctx
|
||||
*codec.H265Ctx
|
||||
DONL bool
|
||||
}
|
||||
AV1Ctx struct {
|
||||
RTPCtx
|
||||
codec.AV1Ctx
|
||||
*codec.AV1Ctx
|
||||
}
|
||||
VP9Ctx struct {
|
||||
RTPCtx
|
||||
}
|
||||
Video struct {
|
||||
VideoFrame struct {
|
||||
RTPData
|
||||
CTS time.Duration
|
||||
DTS time.Duration
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
_ IAVFrame = (*Video)(nil)
|
||||
_ IAVFrame = (*VideoFrame)(nil)
|
||||
_ IVideoCodecCtx = (*H264Ctx)(nil)
|
||||
_ IVideoCodecCtx = (*H265Ctx)(nil)
|
||||
_ IVideoCodecCtx = (*AV1Ctx)(nil)
|
||||
@@ -62,207 +61,188 @@ const (
|
||||
MTUSize = 1460
|
||||
)
|
||||
|
||||
func (r *Video) Parse(t *AVTrack) (err error) {
|
||||
switch r.MimeType {
|
||||
case webrtc.MimeTypeH264:
|
||||
var ctx *H264Ctx
|
||||
if t.ICodecCtx != nil {
|
||||
ctx = t.ICodecCtx.(*H264Ctx)
|
||||
} else {
|
||||
ctx = &H264Ctx{}
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
var sps, pps []byte
|
||||
//packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029
|
||||
if sprop, ok := ctx.Fmtp["sprop-parameter-sets"]; ok {
|
||||
if sprops := strings.Split(sprop, ","); len(sprops) == 2 {
|
||||
if sps, err = base64.StdEncoding.DecodeString(sprops[0]); err != nil {
|
||||
return
|
||||
}
|
||||
if pps, err = base64.StdEncoding.DecodeString(sprops[1]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if ctx.CodecData, err = h264parser.NewCodecDataFromSPSAndPPS(sps, pps); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.ICodecCtx = ctx
|
||||
}
|
||||
if t.Value.Raw, err = r.Demux(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
pts := r.Packets[0].Timestamp
|
||||
var hasSPSPPS bool
|
||||
func (r *VideoFrame) Parse(data IAVFrame) (err error) {
|
||||
input := data.(*VideoFrame)
|
||||
r.Packets = append(r.Packets[:0], input.Packets...)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *VideoFrame) Recycle() {
|
||||
r.RecyclableMemory.Recycle()
|
||||
r.Packets.Reset()
|
||||
}
|
||||
|
||||
func (r *VideoFrame) CheckCodecChange() (err error) {
|
||||
if len(r.Packets) == 0 {
|
||||
return ErrSkip
|
||||
}
|
||||
old := r.ICodecCtx
|
||||
// 解复用数据
|
||||
if err = r.Demux(); err != nil {
|
||||
return
|
||||
}
|
||||
// 处理时间戳和序列号
|
||||
pts := r.Packets[0].Timestamp
|
||||
nalus := r.Raw.(*Nalus)
|
||||
switch ctx := old.(type) {
|
||||
case *H264Ctx:
|
||||
dts := ctx.dtsEst.Feed(pts)
|
||||
r.DTS = time.Duration(dts) * time.Millisecond / 90
|
||||
r.CTS = time.Duration(pts-dts) * time.Millisecond / 90
|
||||
for _, nalu := range t.Value.Raw.(Nalus) {
|
||||
switch codec.ParseH264NALUType(nalu.Buffers[0][0]) {
|
||||
r.SetDTS(time.Duration(dts))
|
||||
r.SetPTS(time.Duration(pts))
|
||||
|
||||
// 检查 SPS、PPS 和 IDR 帧
|
||||
var sps, pps []byte
|
||||
var hasSPSPPS bool
|
||||
for nalu := range nalus.RangePoint {
|
||||
nalType := codec.ParseH264NALUType(nalu.Buffers[0][0])
|
||||
switch nalType {
|
||||
case h264parser.NALU_SPS:
|
||||
ctx.RecordInfo.SPS = [][]byte{nalu.ToBytes()}
|
||||
if ctx.SPSInfo, err = h264parser.ParseSPS(ctx.SPS()); err != nil {
|
||||
return
|
||||
}
|
||||
sps = nalu.ToBytes()
|
||||
defer nalus.Remove(nalu)
|
||||
case h264parser.NALU_PPS:
|
||||
hasSPSPPS = true
|
||||
ctx.RecordInfo.PPS = [][]byte{nalu.ToBytes()}
|
||||
if ctx.CodecData, err = h264parser.NewCodecDataFromSPSAndPPS(ctx.RecordInfo.SPS[0], ctx.RecordInfo.PPS[0]); err != nil {
|
||||
return
|
||||
}
|
||||
pps = nalu.ToBytes()
|
||||
defer nalus.Remove(nalu)
|
||||
case codec.NALU_IDR_Picture:
|
||||
t.Value.IDR = true
|
||||
r.IDR = true
|
||||
}
|
||||
}
|
||||
if len(ctx.CodecData.Record) == 0 {
|
||||
return ErrSkip
|
||||
}
|
||||
if t.Value.IDR && !hasSPSPPS {
|
||||
spsRTP := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
|
||||
// 如果发现新的 SPS/PPS,更新编解码器上下文
|
||||
if hasSPSPPS = sps != nil && pps != nil; hasSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
|
||||
var newCodecData h264parser.CodecData
|
||||
if newCodecData, err = h264parser.NewCodecDataFromSPSAndPPS(sps, pps); err != nil {
|
||||
return
|
||||
}
|
||||
newCtx := &H264Ctx{
|
||||
H26xCtx: ctx.H26xCtx,
|
||||
H264Ctx: &codec.H264Ctx{
|
||||
CodecData: newCodecData,
|
||||
},
|
||||
Payload: ctx.SPS(),
|
||||
}
|
||||
ppsRTP := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.PPS(),
|
||||
// 保持原有的 RTP 参数
|
||||
if oldCtx, ok := old.(*H264Ctx); ok {
|
||||
newCtx.RTPCtx = oldCtx.RTPCtx
|
||||
}
|
||||
r.ICodecCtx = newCtx
|
||||
} else {
|
||||
// 如果是 IDR 帧但没有 SPS/PPS,需要插入
|
||||
if r.IDR && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
|
||||
spsRTP := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.SPS(),
|
||||
}
|
||||
ppsRTP := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.PPS(),
|
||||
}
|
||||
r.Packets = slices.Insert(r.Packets, 0, spsRTP, ppsRTP)
|
||||
}
|
||||
r.Packets = slices.Insert(r.Packets, 0, spsRTP, ppsRTP)
|
||||
}
|
||||
for _, p := range r.Packets {
|
||||
|
||||
// 更新序列号
|
||||
for p := range r.Packets.RangePoint {
|
||||
p.SequenceNumber = ctx.seq
|
||||
ctx.seq++
|
||||
}
|
||||
case webrtc.MimeTypeH265:
|
||||
var ctx *H265Ctx
|
||||
if t.ICodecCtx != nil {
|
||||
ctx = t.ICodecCtx.(*H265Ctx)
|
||||
} else {
|
||||
ctx = &H265Ctx{}
|
||||
ctx.parseFmtpLine(r.RTPCodecParameters)
|
||||
var vps, sps, pps []byte
|
||||
if sprop_sps, ok := ctx.Fmtp["sprop-sps"]; ok {
|
||||
if sps, err = base64.StdEncoding.DecodeString(sprop_sps); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sprop_pps, ok := ctx.Fmtp["sprop-pps"]; ok {
|
||||
if pps, err = base64.StdEncoding.DecodeString(sprop_pps); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sprop_vps, ok := ctx.Fmtp["sprop-vps"]; ok {
|
||||
if vps, err = base64.StdEncoding.DecodeString(sprop_vps); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(vps) > 0 && len(sps) > 0 && len(pps) > 0 {
|
||||
if ctx.CodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sprop_donl, ok := ctx.Fmtp["sprop-max-don-diff"]; ok {
|
||||
if sprop_donl != "0" {
|
||||
ctx.DONL = true
|
||||
}
|
||||
}
|
||||
t.ICodecCtx = ctx
|
||||
}
|
||||
if t.Value.Raw, err = r.Demux(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
pts := r.Packets[0].Timestamp
|
||||
case *H265Ctx:
|
||||
dts := ctx.dtsEst.Feed(pts)
|
||||
r.DTS = time.Duration(dts) * time.Millisecond / 90
|
||||
r.CTS = time.Duration(pts-dts) * time.Millisecond / 90
|
||||
r.SetDTS(time.Duration(dts))
|
||||
r.SetPTS(time.Duration(pts))
|
||||
// 检查 VPS、SPS、PPS 和 IDR 帧
|
||||
var vps, sps, pps []byte
|
||||
var hasVPSSPSPPS bool
|
||||
for _, nalu := range t.Value.Raw.(Nalus) {
|
||||
for nalu := range nalus.RangePoint {
|
||||
switch codec.ParseH265NALUType(nalu.Buffers[0][0]) {
|
||||
case h265parser.NAL_UNIT_VPS:
|
||||
ctx = &H265Ctx{}
|
||||
ctx.RecordInfo.VPS = [][]byte{nalu.ToBytes()}
|
||||
ctx.RTPCodecParameters = *r.RTPCodecParameters
|
||||
t.ICodecCtx = ctx
|
||||
vps = nalu.ToBytes()
|
||||
defer nalus.Remove(nalu)
|
||||
case h265parser.NAL_UNIT_SPS:
|
||||
ctx.RecordInfo.SPS = [][]byte{nalu.ToBytes()}
|
||||
if ctx.SPSInfo, err = h265parser.ParseSPS(ctx.SPS()); err != nil {
|
||||
return
|
||||
}
|
||||
sps = nalu.ToBytes()
|
||||
defer nalus.Remove(nalu)
|
||||
case h265parser.NAL_UNIT_PPS:
|
||||
hasVPSSPSPPS = true
|
||||
ctx.RecordInfo.PPS = [][]byte{nalu.ToBytes()}
|
||||
if ctx.CodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(ctx.RecordInfo.VPS[0], ctx.RecordInfo.SPS[0], ctx.RecordInfo.PPS[0]); err != nil {
|
||||
return
|
||||
}
|
||||
pps = nalu.ToBytes()
|
||||
defer nalus.Remove(nalu)
|
||||
case h265parser.NAL_UNIT_CODED_SLICE_BLA_W_LP,
|
||||
h265parser.NAL_UNIT_CODED_SLICE_BLA_W_RADL,
|
||||
h265parser.NAL_UNIT_CODED_SLICE_BLA_N_LP,
|
||||
h265parser.NAL_UNIT_CODED_SLICE_IDR_W_RADL,
|
||||
h265parser.NAL_UNIT_CODED_SLICE_IDR_N_LP,
|
||||
h265parser.NAL_UNIT_CODED_SLICE_CRA:
|
||||
t.Value.IDR = true
|
||||
r.IDR = true
|
||||
}
|
||||
}
|
||||
if len(ctx.CodecData.Record) == 0 {
|
||||
return ErrSkip
|
||||
|
||||
// 如果发现新的 VPS/SPS/PPS,更新编解码器上下文
|
||||
if hasVPSSPSPPS = vps != nil && sps != nil && pps != nil; hasVPSSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(vps, ctx.VPS()) || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
|
||||
var newCodecData h265parser.CodecData
|
||||
if newCodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil {
|
||||
return
|
||||
}
|
||||
newCtx := &H265Ctx{
|
||||
H26xCtx: ctx.H26xCtx,
|
||||
H265Ctx: &codec.H265Ctx{
|
||||
CodecData: newCodecData,
|
||||
},
|
||||
}
|
||||
// 保持原有的 RTP 参数
|
||||
if oldCtx, ok := old.(*H265Ctx); ok {
|
||||
newCtx.RTPCtx = oldCtx.RTPCtx
|
||||
}
|
||||
r.ICodecCtx = newCtx
|
||||
} else {
|
||||
// 如果是 IDR 帧但没有 VPS/SPS/PPS,需要插入
|
||||
if r.IDR && len(ctx.VPS()) > 0 && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
|
||||
vpsRTP := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.VPS(),
|
||||
}
|
||||
spsRTP := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.SPS(),
|
||||
}
|
||||
ppsRTP := rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.PPS(),
|
||||
}
|
||||
r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP)
|
||||
}
|
||||
}
|
||||
if t.Value.IDR && !hasVPSSPSPPS {
|
||||
vpsRTP := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.VPS(),
|
||||
}
|
||||
spsRTP := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.SPS(),
|
||||
}
|
||||
ppsRTP := &rtp.Packet{
|
||||
Header: rtp.Header{
|
||||
Version: 2,
|
||||
SequenceNumber: ctx.SequenceNumber,
|
||||
Timestamp: pts,
|
||||
SSRC: ctx.SSRC,
|
||||
PayloadType: uint8(ctx.PayloadType),
|
||||
},
|
||||
Payload: ctx.PPS(),
|
||||
}
|
||||
r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP)
|
||||
}
|
||||
for _, p := range r.Packets {
|
||||
|
||||
// 更新序列号
|
||||
for p := range r.Packets.RangePoint {
|
||||
p.SequenceNumber = ctx.seq
|
||||
ctx.seq++
|
||||
}
|
||||
case webrtc.MimeTypeVP9:
|
||||
// var ctx RTPVP9Ctx
|
||||
// ctx.RTPCodecParameters = *r.RTPCodecParameters
|
||||
// codecCtx = &ctx
|
||||
case webrtc.MimeTypeAV1:
|
||||
var ctx AV1Ctx
|
||||
ctx.RTPCodecParameters = *r.RTPCodecParameters
|
||||
t.ICodecCtx = &ctx
|
||||
default:
|
||||
err = ErrUnsupportCodec
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -279,26 +259,45 @@ func (av1 *AV1Ctx) GetInfo() string {
|
||||
return av1.SDPFmtpLine
|
||||
}
|
||||
|
||||
func (r *Video) GetTimestamp() time.Duration {
|
||||
return r.DTS
|
||||
}
|
||||
func (r *VideoFrame) Mux(baseFrame *Sample) error {
|
||||
// 获取编解码器上下文
|
||||
codecCtx := r.ICodecCtx
|
||||
if codecCtx == nil {
|
||||
switch base := baseFrame.GetBase().(type) {
|
||||
case *codec.H264Ctx:
|
||||
var ctx H264Ctx
|
||||
ctx.H264Ctx = base
|
||||
ctx.PayloadType = 96
|
||||
ctx.MimeType = webrtc.MimeTypeH264
|
||||
ctx.ClockRate = 90000
|
||||
spsInfo := ctx.SPSInfo
|
||||
ctx.SDPFmtpLine = fmt.Sprintf("sprop-parameter-sets=%s,%s;profile-level-id=%02x%02x%02x;level-asymmetry-allowed=1;packetization-mode=1", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), spsInfo.ProfileIdc, spsInfo.ConstraintSetFlag, spsInfo.LevelIdc)
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
codecCtx = &ctx
|
||||
case *codec.H265Ctx:
|
||||
var ctx H265Ctx
|
||||
ctx.H265Ctx = base
|
||||
ctx.PayloadType = 98
|
||||
ctx.MimeType = webrtc.MimeTypeH265
|
||||
ctx.SDPFmtpLine = fmt.Sprintf("profile-id=1;sprop-sps=%s;sprop-pps=%s;sprop-vps=%s", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), base64.StdEncoding.EncodeToString(ctx.VPS()))
|
||||
ctx.ClockRate = 90000
|
||||
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
|
||||
codecCtx = &ctx
|
||||
}
|
||||
r.ICodecCtx = codecCtx
|
||||
}
|
||||
// 获取时间戳信息
|
||||
pts := uint32(baseFrame.GetPTS())
|
||||
|
||||
func (r *Video) GetCTS() time.Duration {
|
||||
return r.CTS
|
||||
}
|
||||
|
||||
func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
|
||||
pts := uint32((from.Timestamp + from.CTS) * 90 / time.Millisecond)
|
||||
switch c := codecCtx.(type) {
|
||||
case *H264Ctx:
|
||||
ctx := &c.RTPCtx
|
||||
r.RTPCodecParameters = &ctx.RTPCodecParameters
|
||||
var lastPacket *rtp.Packet
|
||||
if from.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 {
|
||||
if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 {
|
||||
r.Append(ctx, pts, c.SPS())
|
||||
r.Append(ctx, pts, c.PPS())
|
||||
}
|
||||
for _, nalu := range from.Raw.(Nalus) {
|
||||
for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
|
||||
if reader := nalu.NewReader(); reader.Length > MTUSize {
|
||||
payloadLen := MTUSize
|
||||
if reader.Length+1 < payloadLen {
|
||||
@@ -306,7 +305,7 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
|
||||
}
|
||||
//fu-a
|
||||
mem := r.NextN(payloadLen)
|
||||
reader.ReadBytesTo(mem[1:])
|
||||
reader.Read(mem[1:])
|
||||
fuaHead, naluType := codec.NALU_FUA.Or(mem[1]&0x60), mem[1]&0x1f
|
||||
mem[0], mem[1] = fuaHead, naluType|startBit
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
@@ -315,27 +314,26 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
|
||||
payloadLen = reader.Length + 2
|
||||
}
|
||||
mem = r.NextN(payloadLen)
|
||||
reader.ReadBytesTo(mem[2:])
|
||||
reader.Read(mem[2:])
|
||||
mem[0], mem[1] = fuaHead, naluType
|
||||
}
|
||||
lastPacket.Payload[1] |= endBit
|
||||
} else {
|
||||
mem := r.NextN(reader.Length)
|
||||
reader.ReadBytesTo(mem)
|
||||
reader.Read(mem)
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
}
|
||||
}
|
||||
lastPacket.Header.Marker = true
|
||||
case *H265Ctx:
|
||||
ctx := &c.RTPCtx
|
||||
r.RTPCodecParameters = &ctx.RTPCodecParameters
|
||||
var lastPacket *rtp.Packet
|
||||
if from.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 && len(c.RecordInfo.VPS) > 0 {
|
||||
if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 && len(c.RecordInfo.VPS) > 0 {
|
||||
r.Append(ctx, pts, c.VPS())
|
||||
r.Append(ctx, pts, c.SPS())
|
||||
r.Append(ctx, pts, c.PPS())
|
||||
}
|
||||
for _, nalu := range from.Raw.(Nalus) {
|
||||
for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
|
||||
if reader := nalu.NewReader(); reader.Length > MTUSize {
|
||||
var b0, b1 byte
|
||||
_ = reader.ReadByteTo(&b0, &b1)
|
||||
@@ -348,7 +346,7 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
|
||||
payloadLen = reader.Length + 3
|
||||
}
|
||||
mem := r.NextN(payloadLen)
|
||||
reader.ReadBytesTo(mem[3:])
|
||||
reader.Read(mem[3:])
|
||||
mem[0], mem[1], mem[2] = b0, b1, naluType|startBit
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
|
||||
@@ -357,37 +355,37 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
|
||||
payloadLen = reader.Length + 3
|
||||
}
|
||||
mem = r.NextN(payloadLen)
|
||||
reader.ReadBytesTo(mem[3:])
|
||||
reader.Read(mem[3:])
|
||||
mem[0], mem[1], mem[2] = b0, b1, naluType
|
||||
}
|
||||
lastPacket.Payload[2] |= endBit
|
||||
} else {
|
||||
mem := r.NextN(reader.Length)
|
||||
reader.ReadBytesTo(mem)
|
||||
reader.Read(mem)
|
||||
lastPacket = r.Append(ctx, pts, mem)
|
||||
}
|
||||
}
|
||||
lastPacket.Header.Marker = true
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
func (r *VideoFrame) Demux() (err error) {
|
||||
if len(r.Packets) == 0 {
|
||||
return nil, ErrSkip
|
||||
return ErrSkip
|
||||
}
|
||||
switch c := ictx.(type) {
|
||||
switch c := r.ICodecCtx.(type) {
|
||||
case *H264Ctx:
|
||||
var nalus Nalus
|
||||
var nalu util.Memory
|
||||
nalus := r.GetNalus()
|
||||
nalu := nalus.GetNextPointer()
|
||||
var naluType codec.H264NALUType
|
||||
gotNalu := func() {
|
||||
if nalu.Size > 0 {
|
||||
nalus = append(nalus, nalu)
|
||||
nalu = util.Memory{}
|
||||
nalu = nalus.GetNextPointer()
|
||||
}
|
||||
}
|
||||
for _, packet := range r.Packets {
|
||||
if len(packet.Payload) == 0 {
|
||||
for packet := range r.Packets.RangePoint {
|
||||
if len(packet.Payload) < 2 {
|
||||
continue
|
||||
}
|
||||
if packet.Padding {
|
||||
@@ -395,31 +393,31 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
}
|
||||
b0 := packet.Payload[0]
|
||||
if t := codec.ParseH264NALUType(b0); t < 24 {
|
||||
nalu.AppendOne(packet.Payload)
|
||||
nalu.PushOne(packet.Payload)
|
||||
gotNalu()
|
||||
} else {
|
||||
offset := t.Offset()
|
||||
switch t {
|
||||
case codec.NALU_STAPA, codec.NALU_STAPB:
|
||||
if len(packet.Payload) <= offset {
|
||||
return nil, fmt.Errorf("invalid nalu size %d", len(packet.Payload))
|
||||
return fmt.Errorf("invalid nalu size %d", len(packet.Payload))
|
||||
}
|
||||
for buffer := util.Buffer(packet.Payload[offset:]); buffer.CanRead(); {
|
||||
if nextSize := int(buffer.ReadUint16()); buffer.Len() >= nextSize {
|
||||
nalu.AppendOne(buffer.ReadN(nextSize))
|
||||
nalu.PushOne(buffer.ReadN(nextSize))
|
||||
gotNalu()
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid nalu size %d", nextSize)
|
||||
return fmt.Errorf("invalid nalu size %d", nextSize)
|
||||
}
|
||||
}
|
||||
case codec.NALU_FUA, codec.NALU_FUB:
|
||||
b1 := packet.Payload[1]
|
||||
if util.Bit1(b1, 0) {
|
||||
naluType.Parse(b1)
|
||||
nalu.AppendOne([]byte{naluType.Or(b0 & 0x60)})
|
||||
nalu.PushOne([]byte{naluType.Or(b0 & 0x60)})
|
||||
}
|
||||
if nalu.Size > 0 {
|
||||
nalu.AppendOne(packet.Payload[offset:])
|
||||
nalu.PushOne(packet.Payload[offset:])
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
@@ -427,18 +425,18 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
gotNalu()
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported nalu type %d", t)
|
||||
return fmt.Errorf("unsupported nalu type %d", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nalus, nil
|
||||
nalus.Reduce()
|
||||
return nil
|
||||
case *H265Ctx:
|
||||
var nalus Nalus
|
||||
var nalu util.Memory
|
||||
nalus := r.GetNalus()
|
||||
nalu := nalus.GetNextPointer()
|
||||
gotNalu := func() {
|
||||
if nalu.Size > 0 {
|
||||
nalus = append(nalus, nalu)
|
||||
nalu = util.Memory{}
|
||||
nalu = nalus.GetNextPointer()
|
||||
}
|
||||
}
|
||||
for _, packet := range r.Packets {
|
||||
@@ -447,7 +445,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
}
|
||||
b0 := packet.Payload[0]
|
||||
if t := codec.ParseH265NALUType(b0); t < H265_NALU_AP {
|
||||
nalu.AppendOne(packet.Payload)
|
||||
nalu.PushOne(packet.Payload)
|
||||
gotNalu()
|
||||
} else {
|
||||
var buffer = util.Buffer(packet.Payload)
|
||||
@@ -458,7 +456,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
buffer.ReadUint16()
|
||||
}
|
||||
for buffer.CanRead() {
|
||||
nalu.AppendOne(buffer.ReadN(int(buffer.ReadUint16())))
|
||||
nalu.PushOne(buffer.ReadN(int(buffer.ReadUint16())))
|
||||
gotNalu()
|
||||
}
|
||||
if c.DONL {
|
||||
@@ -466,7 +464,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
}
|
||||
case H265_NALU_FU:
|
||||
if buffer.Len() < 3 {
|
||||
return nil, io.ErrShortBuffer
|
||||
return io.ErrShortBuffer
|
||||
}
|
||||
first3 := buffer.ReadN(3)
|
||||
fuHeader := first3[2]
|
||||
@@ -474,18 +472,19 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
|
||||
buffer.ReadUint16()
|
||||
}
|
||||
if naluType := fuHeader & 0b00111111; util.Bit1(fuHeader, 0) {
|
||||
nalu.AppendOne([]byte{first3[0]&0b10000001 | (naluType << 1), first3[1]})
|
||||
nalu.PushOne([]byte{first3[0]&0b10000001 | (naluType << 1), first3[1]})
|
||||
}
|
||||
nalu.AppendOne(buffer)
|
||||
nalu.PushOne(buffer)
|
||||
if util.Bit1(fuHeader, 1) {
|
||||
gotNalu()
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported nalu type %d", t)
|
||||
return fmt.Errorf("unsupported nalu type %d", t)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nalus, nil
|
||||
nalus.Reduce()
|
||||
return nil
|
||||
}
|
||||
return nil, nil
|
||||
return ErrUnsupportCodec
|
||||
}
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package rtp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/pion/webrtc/v4"
|
||||
"m7s.live/v5/pkg"
|
||||
"m7s.live/v5/pkg/util"
|
||||
)
|
||||
|
||||
func TestRTPH264Ctx_CreateFrame(t *testing.T) {
|
||||
var ctx = &H264Ctx{}
|
||||
ctx.RTPCodecParameters = webrtc.RTPCodecParameters{
|
||||
PayloadType: 96,
|
||||
RTPCodecCapability: webrtc.RTPCodecCapability{
|
||||
MimeType: webrtc.MimeTypeH264,
|
||||
ClockRate: 90000,
|
||||
SDPFmtpLine: "packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029",
|
||||
},
|
||||
}
|
||||
var randStr = util.RandomString(1500)
|
||||
var avFrame = &pkg.AVFrame{}
|
||||
var mem util.Memory
|
||||
mem.Append([]byte(randStr))
|
||||
avFrame.Raw = []util.Memory{mem}
|
||||
frame := new(Video)
|
||||
frame.Mux(ctx, avFrame)
|
||||
var track = &pkg.AVTrack{}
|
||||
err := frame.Parse(track)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if s := string(track.Value.Raw.(pkg.Nalus)[0].ToBytes()); s != randStr {
|
||||
t.Error("not equal", len(s), len(randStr))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user