refactor: frame converter and mp4 track improvements

- Refactor frame converter implementation
- Update mp4 track to use ICodex
- General refactoring and code improvements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
langhuihui
2025-08-04 09:17:12 +08:00
parent b6ee2843b0
commit 8a9fffb987
262 changed files with 20831 additions and 12141 deletions

View File

@@ -1,17 +1,13 @@
package rtp
import (
"encoding/base64"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"strings"
"time"
"unsafe"
"github.com/bluenviron/mediacommon/pkg/bits"
"github.com/deepch/vdk/codec/aacparser"
"github.com/pion/rtp"
"github.com/pion/webrtc/v4"
@@ -21,43 +17,24 @@ import (
)
type RTPData struct {
*webrtc.RTPCodecParameters
Packets []*rtp.Packet
util.RecyclableMemory
Sample
Packets util.ReuseArray[rtp.Packet]
}
func (r *RTPData) Dump(t byte, w io.Writer) {
m := r.GetAllocator().Borrow(3 + len(r.Packets)*2 + r.GetSize())
m[0] = t
binary.BigEndian.PutUint16(m[1:], uint16(len(r.Packets)))
offset := 3
for _, p := range r.Packets {
size := p.MarshalSize()
binary.BigEndian.PutUint16(m[offset:], uint16(size))
offset += 2
p.MarshalTo(m[offset:])
offset += size
}
w.Write(m)
func (r *RTPData) Recycle() {
r.RecyclableMemory.Recycle()
r.Packets.Reset()
}
func (r *RTPData) String() (s string) {
for _, p := range r.Packets {
for p := range r.Packets.RangePoint {
s += fmt.Sprintf("t: %d, s: %d, p: %02X %d\n", p.Timestamp, p.SequenceNumber, p.Payload[0:2], len(p.Payload))
}
return
}
func (r *RTPData) GetTimestamp() time.Duration {
return time.Duration(r.Packets[0].Timestamp) * time.Second / time.Duration(r.ClockRate)
}
func (r *RTPData) GetCTS() time.Duration {
return 0
}
func (r *RTPData) GetSize() (s int) {
for _, p := range r.Packets {
for p := range r.Packets.RangePoint {
s += p.MarshalSize()
}
return
@@ -72,19 +49,19 @@ type (
}
PCMACtx struct {
RTPCtx
codec.PCMACtx
*codec.PCMACtx
}
PCMUCtx struct {
RTPCtx
codec.PCMUCtx
*codec.PCMUCtx
}
OPUSCtx struct {
RTPCtx
codec.OPUSCtx
*codec.OPUSCtx
}
AACCtx struct {
RTPCtx
codec.AACCtx
*codec.AACCtx
SizeLength int // 通常为13
IndexLength int
IndexDeltaLength int
@@ -94,7 +71,7 @@ type (
}
)
func (r *RTPCtx) parseFmtpLine(cp *webrtc.RTPCodecParameters) {
func (r *RTPCtx) ParseFmtpLine(cp *webrtc.RTPCodecParameters) {
r.RTPCodecParameters = *cp
r.Fmtp = make(map[string]string)
kvs := strings.Split(r.SDPFmtpLine, ";")
@@ -121,9 +98,9 @@ func (r *RTPCtx) GetRTPCodecParameter() webrtc.RTPCodecParameters {
return r.RTPCodecParameters
}
func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) (lastPacket *rtp.Packet) {
func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) *rtp.Packet {
ctx.SequenceNumber++
lastPacket = &rtp.Packet{
r.Packets = append(r.Packets, rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
@@ -132,135 +109,19 @@ func (r *RTPData) Append(ctx *RTPCtx, ts uint32, payload []byte) (lastPacket *rt
PayloadType: uint8(ctx.PayloadType),
},
Payload: payload,
}
r.Packets = append(r.Packets, lastPacket)
return
})
return &r.Packets[len(r.Packets)-1]
}
func (r *RTPData) ConvertCtx(from codec.ICodecCtx) (to codec.ICodecCtx, seq IAVFrame, err error) {
switch from.FourCC() {
case codec.FourCC_H264:
var ctx H264Ctx
ctx.H264Ctx = *from.GetBase().(*codec.H264Ctx)
ctx.PayloadType = 96
ctx.MimeType = webrtc.MimeTypeH264
ctx.ClockRate = 90000
spsInfo := ctx.SPSInfo
ctx.SDPFmtpLine = fmt.Sprintf("sprop-parameter-sets=%s,%s;profile-level-id=%02x%02x%02x;level-asymmetry-allowed=1;packetization-mode=1", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), spsInfo.ProfileIdc, spsInfo.ConstraintSetFlag, spsInfo.LevelIdc)
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
to = &ctx
case codec.FourCC_H265:
var ctx H265Ctx
ctx.H265Ctx = *from.GetBase().(*codec.H265Ctx)
ctx.PayloadType = 98
ctx.MimeType = webrtc.MimeTypeH265
ctx.SDPFmtpLine = fmt.Sprintf("profile-id=1;sprop-sps=%s;sprop-pps=%s;sprop-vps=%s", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), base64.StdEncoding.EncodeToString(ctx.VPS()))
ctx.ClockRate = 90000
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
to = &ctx
case codec.FourCC_MP4A:
var ctx AACCtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.AACCtx = *from.GetBase().(*codec.AACCtx)
ctx.MimeType = "audio/MPEG4-GENERIC"
ctx.SDPFmtpLine = fmt.Sprintf("profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;config=%s", hex.EncodeToString(ctx.AACCtx.ConfigBytes))
ctx.IndexLength = 3
ctx.IndexDeltaLength = 3
ctx.SizeLength = 13
ctx.RTPCtx.Channels = uint16(ctx.AACCtx.GetChannels())
ctx.PayloadType = 97
ctx.ClockRate = uint32(ctx.CodecData.SampleRate())
to = &ctx
case codec.FourCC_ALAW:
var ctx PCMACtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.PCMACtx = *from.GetBase().(*codec.PCMACtx)
ctx.MimeType = webrtc.MimeTypePCMA
ctx.PayloadType = 8
ctx.ClockRate = uint32(ctx.SampleRate)
to = &ctx
case codec.FourCC_ULAW:
var ctx PCMUCtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.PCMUCtx = *from.GetBase().(*codec.PCMUCtx)
ctx.MimeType = webrtc.MimeTypePCMU
ctx.PayloadType = 0
ctx.ClockRate = uint32(ctx.SampleRate)
to = &ctx
case codec.FourCC_OPUS:
var ctx OPUSCtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.OPUSCtx = *from.GetBase().(*codec.OPUSCtx)
ctx.MimeType = webrtc.MimeTypeOpus
ctx.PayloadType = 111
ctx.ClockRate = uint32(ctx.CodecData.SampleRate())
to = &ctx
}
return
}
var _ IAVFrame = (*AudioFrame)(nil)
type Audio struct {
type AudioFrame struct {
RTPData
}
func (r *Audio) Parse(t *AVTrack) (err error) {
switch r.MimeType {
case webrtc.MimeTypeOpus:
var ctx OPUSCtx
ctx.parseFmtpLine(r.RTPCodecParameters)
ctx.OPUSCtx.Channels = int(ctx.RTPCodecParameters.Channels)
t.ICodecCtx = &ctx
case webrtc.MimeTypePCMA:
var ctx PCMACtx
ctx.parseFmtpLine(r.RTPCodecParameters)
ctx.AudioCtx.SampleRate = int(r.ClockRate)
ctx.AudioCtx.Channels = int(ctx.RTPCodecParameters.Channels)
t.ICodecCtx = &ctx
case webrtc.MimeTypePCMU:
var ctx PCMUCtx
ctx.parseFmtpLine(r.RTPCodecParameters)
ctx.AudioCtx.SampleRate = int(r.ClockRate)
ctx.AudioCtx.Channels = int(ctx.RTPCodecParameters.Channels)
t.ICodecCtx = &ctx
case "audio/MP4A-LATM":
var ctx *AACCtx
if t.ICodecCtx != nil {
// ctx = t.ICodecCtx.(*AACCtx)
} else {
ctx = &AACCtx{}
ctx.parseFmtpLine(r.RTPCodecParameters)
if conf, ok := ctx.Fmtp["config"]; ok {
if ctx.AACCtx.ConfigBytes, err = hex.DecodeString(conf); err == nil {
if ctx.CodecData, err = aacparser.NewCodecDataFromMPEG4AudioConfigBytes(ctx.AACCtx.ConfigBytes); err != nil {
return
}
}
}
t.ICodecCtx = ctx
}
case "audio/MPEG4-GENERIC":
var ctx *AACCtx
if t.ICodecCtx != nil {
// ctx = t.ICodecCtx.(*AACCtx)
} else {
ctx = &AACCtx{}
ctx.parseFmtpLine(r.RTPCodecParameters)
ctx.IndexLength = 3
ctx.IndexDeltaLength = 3
ctx.SizeLength = 13
if conf, ok := ctx.Fmtp["config"]; ok {
if ctx.AACCtx.ConfigBytes, err = hex.DecodeString(conf); err == nil {
if ctx.CodecData, err = aacparser.NewCodecDataFromMPEG4AudioConfigBytes(ctx.AACCtx.ConfigBytes); err != nil {
return
}
}
}
t.ICodecCtx = ctx
}
}
if len(r.Packets) == 0 {
return ErrSkip
}
func (r *AudioFrame) Parse(data IAVFrame) (err error) {
input := data.(*AudioFrame)
r.Packets = append(r.Packets[:0], input.Packets...)
return
}
@@ -286,17 +147,22 @@ func payloadLengthInfoDecode(buf []byte) (int, int, error) {
return l, n, nil
}
func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
func (r *AudioFrame) Demux() (err error) {
if len(r.Packets) == 0 {
return nil, ErrSkip
return ErrSkip
}
var data AudioData
switch r.MimeType {
data := r.GetAudioData()
// 从编解码器上下文获取 MimeType
var mimeType string
if rtpCtx, ok := r.ICodecCtx.(IRTPCtx); ok {
mimeType = rtpCtx.GetRTPCodecParameter().MimeType
}
switch mimeType {
case "audio/MP4A-LATM":
var fragments util.Memory
var fragmentsExpected int
var fragmentsSize int
for _, packet := range r.Packets {
for packet := range r.Packets.RangePoint {
if len(packet.Payload) == 0 {
continue
}
@@ -307,23 +173,23 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
if fragments.Size == 0 {
pl, n, err := payloadLengthInfoDecode(buf)
if err != nil {
return nil, err
return err
}
buf = buf[n:]
bl := len(buf)
if pl <= bl {
data.AppendOne(buf[:pl])
data.PushOne(buf[:pl])
// there could be other data, due to otherDataPresent. Ignore it.
} else {
if pl > 5*1024 {
fragments = util.Memory{} // discard pending fragments
return nil, fmt.Errorf("access unit size (%d) is too big, maximum is %d",
return fmt.Errorf("access unit size (%d) is too big, maximum is %d",
pl, 5*1024)
}
fragments.AppendOne(buf)
fragments.PushOne(buf)
fragmentsSize = pl
fragmentsExpected = pl - bl
continue
@@ -332,33 +198,33 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
bl := len(buf)
if fragmentsExpected > bl {
fragments.AppendOne(buf)
fragments.PushOne(buf)
fragmentsExpected -= bl
continue
}
fragments.AppendOne(buf[:fragmentsExpected])
fragments.PushOne(buf[:fragmentsExpected])
// there could be other data, due to otherDataPresent. Ignore it.
data.Append(fragments.Buffers...)
data.Push(fragments.Buffers...)
if fragments.Size != fragmentsSize {
return nil, fmt.Errorf("fragmented AU size is not correct %d != %d", data.Size, fragmentsSize)
return fmt.Errorf("fragmented AU size is not correct %d != %d", data.Size, fragmentsSize)
}
fragments = util.Memory{}
}
}
case "audio/MPEG4-GENERIC":
var fragments util.Memory
for _, packet := range r.Packets {
for packet := range r.Packets.RangePoint {
if len(packet.Payload) < 2 {
continue
}
auHeaderLen := util.ReadBE[int](packet.Payload[:2])
if auHeaderLen == 0 {
data.AppendOne(packet.Payload)
data.PushOne(packet.Payload)
} else {
dataLens, err := r.readAUHeaders(codexCtx.(*AACCtx), packet.Payload[2:], auHeaderLen)
dataLens, err := r.readAUHeaders(r.ICodecCtx.(*AACCtx), packet.Payload[2:], auHeaderLen)
if err != nil {
return nil, err
return err
}
payload := packet.Payload[2:]
pos := auHeaderLen >> 3
@@ -370,48 +236,65 @@ func (r *Audio) Demux(codexCtx codec.ICodecCtx) (any, error) {
if packet.Marker {
for _, dataLen := range dataLens {
if len(payload) < int(dataLen) {
return nil, fmt.Errorf("invalid data len %d", dataLen)
return fmt.Errorf("invalid data len %d", dataLen)
}
data.AppendOne(payload[:dataLen])
data.PushOne(payload[:dataLen])
payload = payload[dataLen:]
}
} else {
if len(dataLens) != 1 {
return nil, fmt.Errorf("a fragmented packet can only contain one AU")
return fmt.Errorf("a fragmented packet can only contain one AU")
}
fragments.AppendOne(payload)
fragments.PushOne(payload)
}
} else {
if len(dataLens) != 1 {
return nil, fmt.Errorf("a fragmented packet can only contain one AU")
return fmt.Errorf("a fragmented packet can only contain one AU")
}
fragments.AppendOne(payload)
fragments.PushOne(payload)
if !packet.Header.Marker {
continue
}
if uint64(fragments.Size) != dataLens[0] {
return nil, fmt.Errorf("fragmented AU size is not correct %d != %d", dataLens[0], fragments.Size)
return fmt.Errorf("fragmented AU size is not correct %d != %d", dataLens[0], fragments.Size)
}
data.Append(fragments.Buffers...)
data.Push(fragments.Buffers...)
fragments = util.Memory{}
}
}
break
}
default:
for _, packet := range r.Packets {
data.AppendOne(packet.Payload)
for packet := range r.Packets.RangePoint {
data.PushOne(packet.Payload)
}
}
return data, nil
return nil
}
func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
data := from.Raw.(AudioData)
func (r *AudioFrame) Mux(from *Sample) (err error) {
data := from.Raw.(*AudioData)
var ctx *RTPCtx
var lastPacket *rtp.Packet
switch c := codexCtx.(type) {
case *AACCtx:
switch base := from.GetBase().(type) {
case *codec.AACCtx:
var c *AACCtx
if r.ICodecCtx == nil {
c = &AACCtx{}
c.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
c.AACCtx = base
c.MimeType = "audio/MPEG4-GENERIC"
c.SDPFmtpLine = fmt.Sprintf("profile-level-id=1;mode=AAC-hbr;sizelength=13;indexlength=3;indexdeltalength=3;config=%s", hex.EncodeToString(c.ConfigBytes))
c.IndexLength = 3
c.IndexDeltaLength = 3
c.SizeLength = 13
c.RTPCtx.Channels = uint16(base.GetChannels())
c.PayloadType = 97
c.ClockRate = uint32(base.CodecData.SampleRate())
r.ICodecCtx = c
} else {
c = r.ICodecCtx.(*AACCtx)
}
ctx = &c.RTPCtx
pts := uint32(from.Timestamp * time.Duration(ctx.ClockRate) / time.Second)
//AU_HEADER_LENGTH,因为单位是bit, 除以8就是auHeader的字节长度又因为单个auheader字节长度2字节所以再除以2就是auheader的个数。
@@ -423,15 +306,35 @@ func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
}
mem := r.NextN(payloadLen)
copy(mem, auHeaderLen)
reader.ReadBytesTo(mem[4:])
reader.Read(mem[4:])
lastPacket = r.Append(ctx, pts, mem)
}
lastPacket.Header.Marker = true
return
case *PCMACtx:
ctx = &c.RTPCtx
case *PCMUCtx:
ctx = &c.RTPCtx
case *codec.PCMACtx:
if r.ICodecCtx == nil {
var ctx PCMACtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.PCMACtx = base
ctx.MimeType = webrtc.MimeTypePCMA
ctx.PayloadType = 8
ctx.ClockRate = uint32(ctx.SampleRate)
r.ICodecCtx = &ctx
} else {
ctx = &r.ICodecCtx.(*PCMACtx).RTPCtx
}
case *codec.PCMUCtx:
if r.ICodecCtx == nil {
var ctx PCMUCtx
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
ctx.PCMUCtx = base
ctx.MimeType = webrtc.MimeTypePCMU
ctx.PayloadType = 0
ctx.ClockRate = uint32(ctx.SampleRate)
r.ICodecCtx = &ctx
} else {
ctx = &r.ICodecCtx.(*PCMUCtx).RTPCtx
}
}
pts := uint32(from.Timestamp * time.Duration(ctx.ClockRate) / time.Second)
if reader := data.NewReader(); reader.Length > MTUSize {
@@ -441,18 +344,19 @@ func (r *Audio) Mux(codexCtx codec.ICodecCtx, from *AVFrame) {
payloadLen = reader.Length
}
mem := r.NextN(payloadLen)
reader.ReadBytesTo(mem)
reader.Read(mem)
lastPacket = r.Append(ctx, pts, mem)
}
} else {
mem := r.NextN(reader.Length)
reader.ReadBytesTo(mem)
reader.Read(mem)
lastPacket = r.Append(ctx, pts, mem)
}
lastPacket.Header.Marker = true
return
}
func (r *Audio) readAUHeaders(ctx *AACCtx, buf []byte, headersLen int) ([]uint64, error) {
func (r *AudioFrame) readAUHeaders(ctx *AACCtx, buf []byte, headersLen int) ([]uint64, error) {
firstRead := false
count := 0

476
plugin/rtp/pkg/forward.go Normal file
View File

@@ -0,0 +1,476 @@
package rtp
import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"time"
"github.com/pion/rtp"
"m7s.live/v5/pkg/util"
)
// ConnectionConfig 连接配置
type ConnectionConfig struct {
IP string
Port uint32
Mode StreamMode
SSRC uint32 // RTP SSRC
}
// ForwardConfig 转发配置
type ForwardConfig struct {
Source ConnectionConfig
Target ConnectionConfig
Relay bool
}
// Forwarder 转发器
type Forwarder struct {
config *ForwardConfig
source net.Conn
target net.Conn
}
// NewForwarder 创建新的转发器
func NewForwarder(config *ForwardConfig) *Forwarder {
return &Forwarder{
config: config,
}
}
// establishSourceConnection 建立源连接
func (f *Forwarder) establishSourceConnection(config ConnectionConfig) (net.Conn, error) {
switch config.Mode {
case StreamModeTCPActive:
dialer := &net.Dialer{Timeout: 10 * time.Second}
netConn, err := dialer.Dial("tcp", fmt.Sprintf("%s:%d", config.IP, config.Port))
if err != nil {
return nil, fmt.Errorf("connect failed: %v", err)
}
return netConn, nil
case StreamModeTCPPassive:
listener, err := net.Listen("tcp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
if err != nil {
return nil, fmt.Errorf("listen failed: %v", err)
}
// Set timeout for accepting connections
if tcpListener, ok := listener.(*net.TCPListener); ok {
tcpListener.SetDeadline(time.Now().Add(30 * time.Second))
}
netConn, err := listener.Accept()
if err != nil {
listener.Close()
return nil, fmt.Errorf("accept failed: %v", err)
}
return netConn, nil
case StreamModeUDP:
// Source UDP - listen
udpAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
if err != nil {
return nil, fmt.Errorf("resolve UDP address failed: %v", err)
}
netConn, err := net.ListenUDP("udp4", udpAddr)
if err != nil {
return nil, fmt.Errorf("UDP listen failed: %v", err)
}
return netConn, nil
}
return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
}
// establishTargetConnection 建立目标连接
func (f *Forwarder) establishTargetConnection(config ConnectionConfig) (net.Conn, error) {
switch config.Mode {
case StreamModeTCPActive:
dialer := &net.Dialer{Timeout: 10 * time.Second}
netConn, err := dialer.Dial("tcp", fmt.Sprintf("%s:%d", config.IP, config.Port))
if err != nil {
return nil, fmt.Errorf("connect failed: %v", err)
}
return netConn, nil
case StreamModeTCPPassive:
listener, err := net.Listen("tcp4", fmt.Sprintf("%s:%d", config.IP, config.Port))
if err != nil {
return nil, fmt.Errorf("listen failed: %v", err)
}
// Set timeout for accepting connections
if tcpListener, ok := listener.(*net.TCPListener); ok {
tcpListener.SetDeadline(time.Now().Add(30 * time.Second))
}
netConn, err := listener.Accept()
if err != nil {
listener.Close()
return nil, fmt.Errorf("accept failed: %v", err)
}
return netConn, nil
case StreamModeUDP:
// Target UDP - dial
netConn, err := net.DialUDP("udp", nil, &net.UDPAddr{
IP: net.ParseIP(config.IP),
Port: int(config.Port),
})
if err != nil {
return nil, fmt.Errorf("UDP dial failed: %v", err)
}
return netConn, nil
}
return nil, fmt.Errorf("unsupported mode: %s", config.Mode)
}
// setupConnections 建立源和目标连接
func (f *Forwarder) setupConnections() error {
var err error
// 建立源连接
f.source, err = f.establishSourceConnection(f.config.Source)
if err != nil {
return fmt.Errorf("source connection failed: %v", err)
}
// 建立目标连接
f.target, err = f.establishTargetConnection(f.config.Target)
if err != nil {
return fmt.Errorf("target connection failed: %v", err)
}
return nil
}
// cleanup 清理连接
func (f *Forwarder) cleanup() {
if f.source != nil {
f.source.Close()
}
if f.target != nil {
f.target.Close()
}
}
// createRTPReader 创建RTP读取器
func (f *Forwarder) createRTPReader() IRTPReader {
switch f.config.Source.Mode {
case StreamModeUDP:
return NewRTPUDPReader(f.source)
case StreamModeTCPActive, StreamModeTCPPassive:
return NewRTPTCPReader(f.source)
default:
return nil
}
}
// createRTPWriter 创建RTP写入器
func (f *Forwarder) createRTPWriter() RTPWriter {
return NewRTPWriter(f.target, f.config.Target.Mode)
}
// RTPWriter RTP写入器接口
type RTPWriter interface {
WritePacket(packet *rtp.Packet) error
WriteRaw(data []byte) error
}
// rtpWriter RTP写入器实现
type rtpWriter struct {
writer io.Writer
mode StreamMode
header []byte
sendBuffer util.Buffer // 可复用的发送缓冲区
}
// NewRTPWriter 创建RTP写入器
func NewRTPWriter(writer io.Writer, mode StreamMode) RTPWriter {
return &rtpWriter{
writer: writer,
mode: mode,
header: make([]byte, 2),
sendBuffer: util.Buffer{}, // 初始化可复用缓冲区
}
}
// WritePacket 写入RTP包
func (w *rtpWriter) WritePacket(packet *rtp.Packet) error {
// 复用sendBuffer避免重复创建
w.sendBuffer.Reset()
w.sendBuffer.Malloc(packet.MarshalSize())
_, err := packet.MarshalTo(w.sendBuffer)
if err != nil {
return fmt.Errorf("marshal RTP packet failed: %v", err)
}
return w.WriteRaw(w.sendBuffer)
}
// WriteRaw 写入原始数据
func (w *rtpWriter) WriteRaw(data []byte) error {
if w.mode == StreamModeUDP {
_, err := w.writer.Write(data)
return err
} else {
// TCP模式需要添加长度头
binary.BigEndian.PutUint16(w.header, uint16(len(data)))
_, err := w.writer.Write(w.header)
if err != nil {
return err
}
_, err = w.writer.Write(data)
return err
}
}
// RelayProcessor 中继处理器
type RelayProcessor struct {
reader io.Reader
writer io.Writer
sourceMode StreamMode
targetMode StreamMode
buffer []byte // 可复用的缓冲区
header []byte // 可复用的头部缓冲区
}
// NewRelayProcessor 创建中继处理器
func NewRelayProcessor(reader io.Reader, writer io.Writer, sourceMode, targetMode StreamMode) *RelayProcessor {
return &RelayProcessor{
reader: reader,
writer: writer,
sourceMode: sourceMode,
targetMode: targetMode,
buffer: make([]byte, 1460), // 初始化可复用缓冲区
header: make([]byte, 2), // 初始化可复用头部缓冲区
}
}
// Process 处理中继
func (p *RelayProcessor) Process(ctx context.Context) error {
if p.sourceMode == p.targetMode {
// 相同模式直接复制
_, err := io.Copy(p.writer, p.reader)
return err
}
// 不同模式需要转换
if p.sourceMode == StreamModeUDP && (p.targetMode == StreamModeTCPActive || p.targetMode == StreamModeTCPPassive) {
// UDP to TCP
return p.processUDPToTCP(ctx)
} else if (p.sourceMode == StreamModeTCPActive || p.sourceMode == StreamModeTCPPassive) && p.targetMode == StreamModeUDP {
// TCP to UDP
return p.processTCPToUDP(ctx)
}
return fmt.Errorf("unsupported mode combination")
}
// processUDPToTCP UDP转TCP
func (p *RelayProcessor) processUDPToTCP(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
n, err := p.reader.Read(p.buffer)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
// 添加2字节长度头
binary.BigEndian.PutUint16(p.header, uint16(n))
_, err = p.writer.Write(p.header)
if err != nil {
return err
}
_, err = p.writer.Write(p.buffer[:n])
if err != nil {
return err
}
}
}
// processTCPToUDP TCP转UDP
func (p *RelayProcessor) processTCPToUDP(ctx context.Context) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// 读取2字节长度头
_, err := io.ReadFull(p.reader, p.header)
if err != nil {
if err == io.EOF {
return nil
}
return err
}
// 获取包长度
packetLength := binary.BigEndian.Uint16(p.header)
// 如果包长度超过缓冲区大小,需要动态分配
if packetLength > uint16(len(p.buffer)) {
packetData := make([]byte, packetLength)
_, err = io.ReadFull(p.reader, packetData)
if err != nil {
return err
}
_, err = p.writer.Write(packetData)
} else {
// 使用可复用缓冲区
_, err = io.ReadFull(p.reader, p.buffer[:packetLength])
if err != nil {
return err
}
_, err = p.writer.Write(p.buffer[:packetLength])
}
if err != nil {
return err
}
}
}
// RTPProcessor RTP处理器
type RTPProcessor struct {
reader IRTPReader
writer RTPWriter
config *ForwardConfig
sendBuffer util.Buffer // 可复用的发送缓冲区
}
// NewRTPProcessor 创建RTP处理器
func NewRTPProcessor(reader IRTPReader, writer RTPWriter, config *ForwardConfig) *RTPProcessor {
return &RTPProcessor{
reader: reader,
writer: writer,
config: config,
sendBuffer: util.Buffer{}, // 初始化可复用缓冲区
}
}
// Process 处理RTP包
func (p *RTPProcessor) Process(ctx context.Context) error {
var packet rtp.Packet
var sequenceNumber uint16
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
err := p.reader.Read(&packet)
if err != nil {
if err == io.EOF {
return nil
}
return fmt.Errorf("read RTP packet failed: %v", err)
}
// 检查源SSRC过滤
if p.config.Source.SSRC != 0 && packet.SSRC != p.config.Source.SSRC {
continue
}
// 保存原始序列号用于分片包
sequenceNumber = packet.SequenceNumber
// 检查是否需要分片
if len(packet.Payload) > (1460 - packet.MarshalSize()) {
err = p.processFragmentedPacket(&packet, sequenceNumber)
} else {
err = p.processSinglePacket(&packet)
}
if err != nil {
return err
}
}
}
// processSinglePacket 处理单个包
func (p *RTPProcessor) processSinglePacket(packet *rtp.Packet) error {
if p.config.Target.SSRC != 0 {
packet.SSRC = p.config.Target.SSRC
}
return p.writer.WritePacket(packet)
}
// processFragmentedPacket 处理分片包
func (p *RTPProcessor) processFragmentedPacket(packet *rtp.Packet, sequenceNumber uint16) error {
maxPayloadSize := 1460 - 12 // RTP头通常是12字节
payload := packet.Payload
// 标记第一个包
marker := packet.Marker
packet.Marker = false
for i := 0; i < len(payload); i += int(maxPayloadSize) {
end := i + int(maxPayloadSize)
if end > len(payload) {
end = len(payload)
// 最后一个分片,恢复原始标记
packet.Marker = marker
}
// 创建包含分片的新包
fragmentPacket := *packet
if p.config.Target.SSRC != 0 {
fragmentPacket.SSRC = p.config.Target.SSRC
}
fragmentPacket.SequenceNumber = sequenceNumber
sequenceNumber++
fragmentPacket.Payload = payload[i:end]
err := p.writer.WritePacket(&fragmentPacket)
if err != nil {
return fmt.Errorf("write RTP fragment failed: %v", err)
}
}
return nil
}
// Forward 执行转发
func (f *Forwarder) Forward(ctx context.Context) error {
// 建立连接
err := f.setupConnections()
if err != nil {
return err
}
defer f.cleanup()
// 检查是否为中继模式
if f.config.Relay {
processor := NewRelayProcessor(f.source, f.target, f.config.Source.Mode, f.config.Target.Mode)
return processor.Process(ctx)
}
// RTP处理模式
reader := f.createRTPReader()
writer := f.createRTPWriter()
processor := NewRTPProcessor(reader, writer, f.config)
return processor.Process(ctx)
}

View File

@@ -0,0 +1,322 @@
package rtp
import (
"fmt"
"testing"
"github.com/pion/rtp"
)
func TestForwardConfig(t *testing.T) {
config := &ForwardConfig{
Source: ConnectionConfig{
IP: "127.0.0.1",
Port: 8080,
Mode: StreamModeUDP,
SSRC: 12345,
},
Target: ConnectionConfig{
IP: "127.0.0.1",
Port: 8081,
Mode: StreamModeTCPActive,
SSRC: 67890,
},
Relay: false,
}
if config.Source.IP != "127.0.0.1" {
t.Errorf("Expected source IP 127.0.0.1, got %s", config.Source.IP)
}
if config.Source.Port != 8080 {
t.Errorf("Expected source port 8080, got %d", config.Source.Port)
}
if config.Source.Mode != StreamModeUDP {
t.Errorf("Expected source mode UDP, got %s", config.Source.Mode)
}
if config.Source.SSRC != 12345 {
t.Errorf("Expected source SSRC 12345, got %d", config.Source.SSRC)
}
if config.Target.IP != "127.0.0.1" {
t.Errorf("Expected target IP 127.0.0.1, got %s", config.Target.IP)
}
if config.Target.Port != 8081 {
t.Errorf("Expected target port 8081, got %d", config.Target.Port)
}
if config.Target.Mode != StreamModeTCPActive {
t.Errorf("Expected target mode TCP-ACTIVE, got %s", config.Target.Mode)
}
if config.Target.SSRC != 67890 {
t.Errorf("Expected target SSRC 67890, got %d", config.Target.SSRC)
}
if config.Relay {
t.Error("Expected relay to be false")
}
}
func TestNewForwarder(t *testing.T) {
config := &ForwardConfig{
Source: ConnectionConfig{
IP: "127.0.0.1",
Port: 8080,
Mode: StreamModeUDP,
SSRC: 12345,
},
Target: ConnectionConfig{
IP: "127.0.0.1",
Port: 8081,
Mode: StreamModeTCPActive,
SSRC: 67890,
},
Relay: false,
}
forwarder := NewForwarder(config)
if forwarder.config != config {
t.Error("Expected forwarder config to match input config")
}
if forwarder.source != nil {
t.Error("Expected source connection to be nil initially")
}
if forwarder.target != nil {
t.Error("Expected target connection to be nil initially")
}
}
func TestConnectionConfig(t *testing.T) {
config := ConnectionConfig{
IP: "192.168.1.100",
Port: 9000,
Mode: StreamModeTCPPassive,
SSRC: 54321,
}
if config.IP != "192.168.1.100" {
t.Errorf("Expected IP 192.168.1.100, got %s", config.IP)
}
if config.Port != 9000 {
t.Errorf("Expected port 9000, got %d", config.Port)
}
if config.Mode != StreamModeTCPPassive {
t.Errorf("Expected mode TCP-PASSIVE, got %s", config.Mode)
}
if config.SSRC != 54321 {
t.Errorf("Expected SSRC 54321, got %d", config.SSRC)
}
}
func TestRTPWriter(t *testing.T) {
// 创建一个模拟的writer
mockWriter := &mockWriter{}
writer := NewRTPWriter(mockWriter, StreamModeUDP)
if writer == nil {
t.Error("Expected RTPWriter to be created")
}
// 测试UDP模式的WriteRaw
data := []byte{1, 2, 3, 4}
err := writer.WriteRaw(data)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(mockWriter.data) != 1 {
t.Errorf("Expected 1 write, got %d", len(mockWriter.data))
}
if len(mockWriter.data[0]) != 4 {
t.Errorf("Expected 4 bytes written, got %d", len(mockWriter.data[0]))
}
}
// mockWriter 用于测试的模拟writer
type mockWriter struct {
data [][]byte
}
func (w *mockWriter) Write(data []byte) (int, error) {
w.data = append(w.data, append([]byte{}, data...))
return len(data), nil
}
func TestRelayProcessor(t *testing.T) {
// 创建模拟的reader和writer
mockReader := &mockReader{data: [][]byte{{1, 2, 3}, {4, 5, 6}}}
mockWriter := &mockWriter{}
processor := NewRelayProcessor(mockReader, mockWriter, StreamModeUDP, StreamModeTCPActive)
if processor.reader != mockReader {
t.Error("Expected reader to match input")
}
if processor.writer != mockWriter {
t.Error("Expected writer to match input")
}
if processor.sourceMode != StreamModeUDP {
t.Errorf("Expected source mode UDP, got %s", processor.sourceMode)
}
if processor.targetMode != StreamModeTCPActive {
t.Errorf("Expected target mode TCP-ACTIVE, got %s", processor.targetMode)
}
}
// mockReader 用于测试的模拟reader
type mockReader struct {
data [][]byte
pos int
}
func (r *mockReader) Read(buf []byte) (int, error) {
if r.pos >= len(r.data) {
return 0, nil // EOF
}
data := r.data[r.pos]
r.pos++
copy(buf, data)
return len(data), nil
}
func TestConnectionTypes(t *testing.T) {
// 测试ConnectionConfig
config := ConnectionConfig{
IP: "127.0.0.1",
Port: 8080,
Mode: StreamModeUDP,
SSRC: 12345,
}
if config.Mode != StreamModeUDP {
t.Errorf("Expected mode UDP, got %s", config.Mode)
}
if config.SSRC != 12345 {
t.Errorf("Expected SSRC 12345, got %d", config.SSRC)
}
}
func TestConnectionDirection(t *testing.T) {
// 测试连接方向的概念
config := &ForwardConfig{
Source: ConnectionConfig{
IP: "127.0.0.1",
Port: 8080,
Mode: StreamModeUDP,
SSRC: 12345,
},
Target: ConnectionConfig{
IP: "127.0.0.1",
Port: 8081,
Mode: StreamModeTCPActive,
SSRC: 67890,
},
Relay: false,
}
forwarder := NewForwarder(config)
// 验证配置正确性
if forwarder.config.Source.SSRC != 12345 {
t.Errorf("Expected source SSRC 12345, got %d", forwarder.config.Source.SSRC)
}
if forwarder.config.Target.SSRC != 67890 {
t.Errorf("Expected target SSRC 67890, got %d", forwarder.config.Target.SSRC)
}
// 验证连接类型
if forwarder.source != nil {
t.Error("Expected source connection to be nil initially")
}
if forwarder.target != nil {
t.Error("Expected target connection to be nil initially")
}
}
func TestBufferReuse(t *testing.T) {
// 测试RTPWriter的buffer复用
writer := NewRTPWriter(&mockWriter{}, StreamModeUDP)
// 多次写入应该复用同一个buffer
for i := 0; i < 10; i++ {
packet := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: uint16(i),
Timestamp: uint32(i * 1000),
SSRC: uint32(i),
},
Payload: []byte(fmt.Sprintf("test packet %d", i)),
}
err := writer.WritePacket(packet)
if err != nil {
t.Errorf("WritePacket failed: %v", err)
}
}
// 测试RelayProcessor的buffer复用
processor := NewRelayProcessor(&mockReader{data: [][]byte{{1, 2, 3}, {4, 5, 6}}}, &mockWriter{}, StreamModeUDP, StreamModeTCPActive)
// 验证buffer字段存在
if processor.buffer == nil {
t.Error("Expected buffer to be initialized")
}
if len(processor.buffer) != 1460 {
t.Errorf("Expected buffer size 1460, got %d", len(processor.buffer))
}
if processor.header == nil {
t.Error("Expected header to be initialized")
}
if len(processor.header) != 2 {
t.Errorf("Expected header size 2, got %d", len(processor.header))
}
}
func TestRTPProcessorBufferReuse(t *testing.T) {
// 测试RTPProcessor的buffer复用
config := &ForwardConfig{
Source: ConnectionConfig{
IP: "127.0.0.1",
Port: 8080,
Mode: StreamModeUDP,
SSRC: 12345,
},
Target: ConnectionConfig{
IP: "127.0.0.1",
Port: 8081,
Mode: StreamModeTCPActive,
SSRC: 67890,
},
Relay: false,
}
processor := NewRTPProcessor(nil, nil, config)
// 验证sendBuffer字段存在
if processor.sendBuffer == nil {
t.Error("Expected sendBuffer to be initialized")
}
}

View File

@@ -0,0 +1,48 @@
package rtp
import (
"time"
m7s "m7s.live/v5"
"m7s.live/v5/pkg/util"
)
type DumpPuller struct {
m7s.HTTPFilePuller
}
func (p *DumpPuller) Start() (err error) {
p.PullJob.PublishConfig.PubType = m7s.PublishTypeReplay
return p.HTTPFilePuller.Start()
}
func (p *DumpPuller) Run() (err error) {
pub := p.PullJob.Publisher
var receiver PSReceiver
receiver.Publisher = pub
receiver.StreamMode = StreamModeManual
receiver.OnStart(func() {
go func() {
var t uint16
for l := make([]byte, 6); pub.State != m7s.PublisherStateDisposed; time.Sleep(time.Millisecond * time.Duration(t)) {
_, err = p.Read(l)
if err != nil {
return
}
payloadLen := util.ReadBE[int](l[:4])
payload := make([]byte, payloadLen)
t = util.ReadBE[uint16](l[4:])
_, err = p.Read(payload)
if err != nil {
return
}
select {
case receiver.RTPMouth <- payload:
case <-pub.Done():
return
}
}
}()
})
return p.RunTask(&receiver)
}

140
plugin/rtp/pkg/reader.go Normal file
View File

@@ -0,0 +1,140 @@
package rtp
import (
"errors"
"fmt"
"io"
"github.com/pion/rtp"
"m7s.live/v5/pkg/util"
)
type IRTPReader interface {
Read(packet *rtp.Packet) (err error)
}
type RTPUDPReader struct {
io.Reader
buf [MTUSize]byte
}
func NewRTPUDPReader(r io.Reader) *RTPUDPReader {
return &RTPUDPReader{Reader: r}
}
func (r *RTPUDPReader) Read(packet *rtp.Packet) (err error) {
n, err := r.Reader.Read(r.buf[:])
if err != nil {
return err
}
return packet.Unmarshal(r.buf[:n])
}
type RTPTCPReader struct {
*util.BufReader
buffer util.Buffer
}
func NewRTPTCPReader(r io.Reader) *RTPTCPReader {
return &RTPTCPReader{BufReader: util.NewBufReader(r)}
}
func (r *RTPTCPReader) Read(packet *rtp.Packet) (err error) {
var rtplen uint32
var b0, b1 byte
rtplen, err = r.ReadBE32(2)
if err != nil {
return
}
var mem util.Memory
mem, err = r.ReadBytes(int(rtplen))
if err != nil {
return
}
mr := mem.NewReader()
mr.ReadByteTo(&b0, &b1)
if b0>>6 != 2 || b0&0x0f > 15 || b1&0x7f > 127 {
// TODO:
panic(fmt.Errorf("invalid rtp packet: %x", r.buffer[:2]))
} else {
r.buffer.Relloc(int(rtplen))
mem.CopyTo(r.buffer)
err = packet.Unmarshal(r.buffer)
}
return
}
type RTPPayloadReader struct {
IRTPReader
rtp.Packet
SSRC uint32 // RTP SSRC
buffer util.MemoryReader
}
// func NewTCPRTPPayloadReaderForFeed() *RTPPayloadReader {
// r := &RTPPayloadReader{}
// r.IRTPReader = &RTPTCPReader{
// BufReader: util.NewBufReaderChan(10),
// }
// r.buffer.Memory = &util.Memory{}
// return r
// }
func NewRTPPayloadReader(t IRTPReader) *RTPPayloadReader {
r := &RTPPayloadReader{}
r.IRTPReader = t
r.buffer.Memory = &util.Memory{}
return r
}
func (r *RTPPayloadReader) Read(buf []byte) (n int, err error) {
// 如果缓冲区中有数据,先读取缓冲区中的数据
if r.buffer.Length > 0 {
n, _ = r.buffer.Read(buf)
return n, nil
}
// 读取新的RTP包
for {
lastSeq := r.SequenceNumber
err = r.IRTPReader.Read(&r.Packet)
if err != nil {
err = errors.Join(err, fmt.Errorf("failed to read RTP packet"))
return
}
// 检查SSRC是否匹配
if r.SSRC != 0 && r.SSRC != r.Packet.SSRC {
// SSRC不匹配继续读取下一个包
continue
}
// 检查序列号是否连续
if lastSeq == 0 || r.SequenceNumber == lastSeq+1 {
// 序列号连续,处理当前包的数据
if lbuf, lpayload := len(buf), len(r.Payload); lbuf >= lpayload {
// 缓冲区足够大,可以容纳整个负载
copy(buf, r.Payload)
n += lpayload
// 如果缓冲区还有剩余空间,继续读取下一个包
if lbuf > lpayload {
var nextn int
nextn, err = r.Read(buf[lpayload:])
if err != nil && err != io.EOF {
return n, err
}
n += nextn
}
return
} else {
// 缓冲区不够大,只复制部分数据,将剩余数据放入缓冲区
n += lbuf
copy(buf, r.Payload[:lbuf])
r.buffer.PushOne(r.Payload[lbuf:])
r.buffer.Length = lpayload - lbuf
return
}
}
}
}

View File

@@ -0,0 +1,113 @@
package rtp
import (
"bytes"
"fmt"
"io"
"testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/assert"
)
func TestRTPPayloadReaderDebug(t *testing.T) {
// 创建简单的测试数据
originalData := []byte("Hello World")
// 生成RTP包
packets := generateRTPPacketsForDebug(originalData, 0, 1000)
fmt.Printf("Generated %d RTP packets\n", len(packets))
for i, packet := range packets {
fmt.Printf("Packet %d: Seq=%d, Payload=%s, PayloadLen=%d\n", i, packet.SequenceNumber, string(packet.Payload), len(packet.Payload))
}
// 将RTP包序列化到缓冲区
var buf bytes.Buffer
for _, packet := range packets {
data, err := packet.Marshal()
assert.NoError(t, err)
fmt.Printf("Marshaled packet length: %d\n", len(data))
// 写入RTP包长度和数据
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
buf.Write(data)
}
fmt.Printf("Buffer size: %d\n", buf.Len())
fmt.Printf("Original data length: %d\n", len(originalData))
fmt.Printf("Original data: %s\n", string(originalData))
// 使用RTPPayloadReader读取数据
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
// 逐步读取数据
allData := make([]byte, 0)
bufSize := 3
for len(allData) < len(originalData) {
result := make([]byte, bufSize)
fmt.Printf("Buffer length before read: %d\n", reader.buffer.Length)
fmt.Printf("Buffer count before read: %d\n", reader.buffer.Count())
n, err := reader.Read(result)
fmt.Printf("Read returned: n=%d, err=%v\n", n, err)
fmt.Printf("Read data: %s\n", string(result[:n]))
fmt.Printf("Buffer length after read: %d\n", reader.buffer.Length)
fmt.Printf("Buffer count after read: %d\n", reader.buffer.Count())
if err != nil {
if err == io.EOF {
break
}
assert.NoError(t, err)
}
if n == 0 {
break
}
allData = append(allData, result[:n]...)
fmt.Printf("All data so far: %s\n", string(allData))
}
fmt.Printf("Final data length: %d\n", len(allData))
fmt.Printf("Final data: %s\n", string(allData))
// 验证数据是否匹配
assert.Equal(t, originalData, allData)
}
func generateRTPPacketsForDebug(data []byte, ssrc uint32, initialSeq uint16) []*rtp.Packet {
var packets []*rtp.Packet
seq := initialSeq
maxPayloadSize := 100
for len(data) > 0 {
// 确定当前包的负载大小
payloadSize := maxPayloadSize
if len(data) < payloadSize {
payloadSize = len(data)
}
// 创建RTP包
packet := &rtp.Packet{
Header: rtp.Header{
Version: 2,
Padding: false,
Extension: false,
Marker: false,
PayloadType: 96,
SequenceNumber: seq,
Timestamp: 123456,
SSRC: ssrc,
},
Payload: data[:payloadSize],
}
packets = append(packets, packet)
// 更新数据和序列号
data = data[payloadSize:]
seq++
}
return packets
}

View File

@@ -0,0 +1,153 @@
package rtp
import (
"bytes"
"io"
"testing"
"github.com/pion/rtp"
"github.com/stretchr/testify/assert"
)
func TestRTPPayloadReader(t *testing.T) {
// 创建测试数据
originalData := []byte("Hello, World! This is a test payload for RTP packets.")
// 生成RTP包
packets := generateRTPPackets(originalData, 0, 1000)
// 将RTP包序列化到缓冲区
var buf bytes.Buffer
for _, packet := range packets {
data, err := packet.Marshal()
assert.NoError(t, err)
// 写入RTP包长度和数据
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
buf.Write(data)
}
// 使用RTPPayloadReader读取数据
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
// 读取所有数据
result := make([]byte, len(originalData))
n, err := reader.Read(result)
assert.NoError(t, err)
assert.Equal(t, len(originalData), n)
// 验证数据是否匹配
assert.Equal(t, originalData, result)
}
func TestRTPPayloadReaderWithBuffer(t *testing.T) {
// 创建测试数据
originalData := []byte("This is a longer test payload that will be split across multiple RTP packets to test the buffering functionality of the RTPPayloadReader.")
// 生成RTP包
packets := generateRTPPackets(originalData, 0, 2000)
// 将RTP包序列化到缓冲区
var buf bytes.Buffer
for _, packet := range packets {
data, err := packet.Marshal()
assert.NoError(t, err)
// 写入RTP包长度和数据
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
buf.Write(data)
}
// 使用RTPPayloadReader读取数据
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
// 使用较小的缓冲区读取数据
allData := make([]byte, 0)
bufSize := 10
for len(allData) < len(originalData) {
result := make([]byte, bufSize)
n, err := reader.Read(result)
if err != nil {
if err == io.EOF {
break
}
assert.NoError(t, err)
}
if n == 0 {
break
}
allData = append(allData, result[:n]...)
}
// 验证数据是否匹配
assert.Equal(t, originalData, allData)
}
func TestRTPPayloadReaderSimple(t *testing.T) {
// 创建简单的测试数据
originalData := []byte("Hello World")
// 生成RTP包
packets := generateRTPPackets(originalData, 0, 1000)
// 将RTP包序列化到缓冲区
var buf bytes.Buffer
for _, packet := range packets {
data, err := packet.Marshal()
assert.NoError(t, err)
// 写入RTP包长度和数据
buf.Write([]byte{byte(len(data) >> 8), byte(len(data))})
buf.Write(data)
}
// 使用RTPPayloadReader读取数据
reader := NewRTPPayloadReader(NewRTPTCPReader(&buf))
// 读取所有数据
result := make([]byte, len(originalData))
n, err := reader.Read(result)
assert.NoError(t, err)
assert.Equal(t, len(originalData), n)
// 验证数据是否匹配
assert.Equal(t, originalData, result)
}
func generateRTPPackets(data []byte, ssrc uint32, initialSeq uint16) []*rtp.Packet {
var packets []*rtp.Packet
seq := initialSeq
maxPayloadSize := 100
for len(data) > 0 {
// 确定当前包的负载大小
payloadSize := maxPayloadSize
if len(data) < payloadSize {
payloadSize = len(data)
}
// 创建RTP包
packet := &rtp.Packet{
Header: rtp.Header{
Version: 2,
Padding: false,
Extension: false,
Marker: false,
PayloadType: 96,
SequenceNumber: seq,
Timestamp: 123456,
SSRC: ssrc,
},
Payload: data[:payloadSize],
}
packets = append(packets, packet)
// 更新数据和序列号
data = data[payloadSize:]
seq++
}
return packets
}

View File

@@ -4,13 +4,14 @@ import (
"bufio"
"encoding/binary"
"io"
"m7s.live/v5/pkg/util"
"net"
"m7s.live/v5/pkg/util"
)
type TCP net.TCPConn
func (t *TCP) Read(onRTP func(util.Buffer) error) (err error) {
func (t *TCP) ReadRTP(onRTP func(util.Buffer) error) (err error) {
reader := bufio.NewReader((*net.TCPConn)(t))
rtpLenBuf := make([]byte, 4)
buffer := make(util.Buffer, 1024)

View File

@@ -0,0 +1,148 @@
package rtp
import (
"errors"
"fmt"
"io"
"net"
"strings"
"github.com/pion/rtp"
mpegps "m7s.live/v5/pkg/format/ps"
"m7s.live/v5/pkg/task"
"m7s.live/v5/pkg/util"
)
var ErrRTPReceiveLost = errors.New("rtp receive lost")
// 数据流传输模式UDP:udp传输、TCP-ACTIVEtcp主动模式、TCP-PASSIVEtcp被动模式、MANUAL手动模式
type StreamMode string
const (
StreamModeUDP StreamMode = "UDP"
StreamModeTCPActive StreamMode = "TCP-ACTIVE"
StreamModeTCPPassive StreamMode = "TCP-PASSIVE"
StreamModeManual StreamMode = "MANUAL"
)
type ChanReader chan []byte
func (r ChanReader) Read(buf []byte) (n int, err error) {
b, ok := <-r
if !ok {
return 0, io.EOF
}
copy(buf, b)
return len(b), nil
}
type RTPChanReader chan []byte
func (r RTPChanReader) Read(packet *rtp.Packet) (err error) {
b, ok := <-r
if !ok {
return io.EOF
}
return packet.Unmarshal(b)
}
func (r RTPChanReader) Close() error {
close(r)
return nil
}
type Receiver struct {
task.Task
*util.BufReader
ListenAddr string
net.Listener
StreamMode StreamMode
SSRC uint32 // RTP SSRC
RTPMouth chan []byte
}
type PSReceiver struct {
Receiver
mpegps.MpegPsDemuxer
}
func (p *PSReceiver) Start() error {
err := p.Receiver.Start()
if err == nil {
p.Using(p.Publisher)
}
return err
}
func (p *PSReceiver) Run() error {
p.MpegPsDemuxer.Allocator = util.NewScalableMemoryAllocator(1 << util.MinPowerOf2)
p.Using(p.MpegPsDemuxer.Allocator)
return p.MpegPsDemuxer.Feed(p.BufReader)
}
func (p *Receiver) Start() (err error) {
var rtpReader *RTPPayloadReader
switch p.StreamMode {
case StreamModeTCPActive:
// TCP主动模式不需要监听直接返回
p.Info("TCP-ACTIVE mode, no need to listen")
addr := p.ListenAddr
if !strings.Contains(addr, ":") {
addr = ":" + addr
}
if strings.HasPrefix(addr, ":") {
p.Error("invalid address, missing IP", "addr", addr)
return fmt.Errorf("invalid address %s, missing IP", addr)
}
p.Info("TCP-ACTIVE mode, connecting to device", "addr", addr)
var conn net.Conn
conn, err = net.Dial("tcp", addr)
if err != nil {
p.Error("connect to device failed", "err", err)
return err
}
p.OnStop(conn.Close)
rtpReader = NewRTPPayloadReader(NewRTPTCPReader(conn))
p.BufReader = util.NewBufReader(rtpReader)
case StreamModeTCPPassive:
var conn net.Conn
if p.SSRC == 0 {
p.Info("start new listener", "addr", p.ListenAddr)
p.Listener, err = net.Listen("tcp4", p.ListenAddr)
if err != nil {
p.Error("start listen", "err", err)
return errors.New("start listen,err" + err.Error())
}
p.OnStop(p.Listener.Close)
conn, err = p.Accept()
} else {
//TODO: 公用监听端口
}
if err != nil {
p.Error("accept", "err", err)
return err
}
p.OnStop(conn.Close)
rtpReader = NewRTPPayloadReader(NewRTPTCPReader(conn))
p.BufReader = util.NewBufReader(rtpReader)
case StreamModeUDP:
var udpAddr *net.UDPAddr
udpAddr, err = net.ResolveUDPAddr("udp4", p.ListenAddr)
if err != nil {
return
}
var conn net.Conn
conn, err = net.ListenUDP("udp4", udpAddr)
if err != nil {
return
}
rtpReader = NewRTPPayloadReader(NewRTPUDPReader(conn))
p.BufReader = util.NewBufReader(rtpReader)
case StreamModeManual:
p.RTPMouth = make(chan []byte)
rtpReader = NewRTPPayloadReader((RTPChanReader)(p.RTPMouth))
p.BufReader = util.NewBufReader(rtpReader)
}
p.Using(rtpReader, p.BufReader)
return
}

View File

@@ -1,12 +1,13 @@
package rtp
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"slices"
"strings"
"time"
"unsafe"
"github.com/deepch/vdk/codec/h264parser"
"github.com/deepch/vdk/codec/h265parser"
@@ -26,29 +27,27 @@ type (
}
H264Ctx struct {
H26xCtx
codec.H264Ctx
*codec.H264Ctx
}
H265Ctx struct {
H26xCtx
codec.H265Ctx
*codec.H265Ctx
DONL bool
}
AV1Ctx struct {
RTPCtx
codec.AV1Ctx
*codec.AV1Ctx
}
VP9Ctx struct {
RTPCtx
}
Video struct {
VideoFrame struct {
RTPData
CTS time.Duration
DTS time.Duration
}
)
var (
_ IAVFrame = (*Video)(nil)
_ IAVFrame = (*VideoFrame)(nil)
_ IVideoCodecCtx = (*H264Ctx)(nil)
_ IVideoCodecCtx = (*H265Ctx)(nil)
_ IVideoCodecCtx = (*AV1Ctx)(nil)
@@ -62,207 +61,188 @@ const (
MTUSize = 1460
)
func (r *Video) Parse(t *AVTrack) (err error) {
switch r.MimeType {
case webrtc.MimeTypeH264:
var ctx *H264Ctx
if t.ICodecCtx != nil {
ctx = t.ICodecCtx.(*H264Ctx)
} else {
ctx = &H264Ctx{}
ctx.parseFmtpLine(r.RTPCodecParameters)
var sps, pps []byte
//packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029
if sprop, ok := ctx.Fmtp["sprop-parameter-sets"]; ok {
if sprops := strings.Split(sprop, ","); len(sprops) == 2 {
if sps, err = base64.StdEncoding.DecodeString(sprops[0]); err != nil {
return
}
if pps, err = base64.StdEncoding.DecodeString(sprops[1]); err != nil {
return
}
}
if ctx.CodecData, err = h264parser.NewCodecDataFromSPSAndPPS(sps, pps); err != nil {
return
}
}
t.ICodecCtx = ctx
}
if t.Value.Raw, err = r.Demux(ctx); err != nil {
return
}
pts := r.Packets[0].Timestamp
var hasSPSPPS bool
func (r *VideoFrame) Parse(data IAVFrame) (err error) {
input := data.(*VideoFrame)
r.Packets = append(r.Packets[:0], input.Packets...)
return
}
func (r *VideoFrame) Recycle() {
r.RecyclableMemory.Recycle()
r.Packets.Reset()
}
func (r *VideoFrame) CheckCodecChange() (err error) {
if len(r.Packets) == 0 {
return ErrSkip
}
old := r.ICodecCtx
// 解复用数据
if err = r.Demux(); err != nil {
return
}
// 处理时间戳和序列号
pts := r.Packets[0].Timestamp
nalus := r.Raw.(*Nalus)
switch ctx := old.(type) {
case *H264Ctx:
dts := ctx.dtsEst.Feed(pts)
r.DTS = time.Duration(dts) * time.Millisecond / 90
r.CTS = time.Duration(pts-dts) * time.Millisecond / 90
for _, nalu := range t.Value.Raw.(Nalus) {
switch codec.ParseH264NALUType(nalu.Buffers[0][0]) {
r.SetDTS(time.Duration(dts))
r.SetPTS(time.Duration(pts))
// 检查 SPS、PPS 和 IDR 帧
var sps, pps []byte
var hasSPSPPS bool
for nalu := range nalus.RangePoint {
nalType := codec.ParseH264NALUType(nalu.Buffers[0][0])
switch nalType {
case h264parser.NALU_SPS:
ctx.RecordInfo.SPS = [][]byte{nalu.ToBytes()}
if ctx.SPSInfo, err = h264parser.ParseSPS(ctx.SPS()); err != nil {
return
}
sps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h264parser.NALU_PPS:
hasSPSPPS = true
ctx.RecordInfo.PPS = [][]byte{nalu.ToBytes()}
if ctx.CodecData, err = h264parser.NewCodecDataFromSPSAndPPS(ctx.RecordInfo.SPS[0], ctx.RecordInfo.PPS[0]); err != nil {
return
}
pps = nalu.ToBytes()
defer nalus.Remove(nalu)
case codec.NALU_IDR_Picture:
t.Value.IDR = true
r.IDR = true
}
}
if len(ctx.CodecData.Record) == 0 {
return ErrSkip
}
if t.Value.IDR && !hasSPSPPS {
spsRTP := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
// 如果发现新的 SPS/PPS更新编解码器上下文
if hasSPSPPS = sps != nil && pps != nil; hasSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
var newCodecData h264parser.CodecData
if newCodecData, err = h264parser.NewCodecDataFromSPSAndPPS(sps, pps); err != nil {
return
}
newCtx := &H264Ctx{
H26xCtx: ctx.H26xCtx,
H264Ctx: &codec.H264Ctx{
CodecData: newCodecData,
},
Payload: ctx.SPS(),
}
ppsRTP := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
// 保持原有的 RTP 参数
if oldCtx, ok := old.(*H264Ctx); ok {
newCtx.RTPCtx = oldCtx.RTPCtx
}
r.ICodecCtx = newCtx
} else {
// 如果是 IDR 帧但没有 SPS/PPS需要插入
if r.IDR && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
spsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.SPS(),
}
ppsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
}
r.Packets = slices.Insert(r.Packets, 0, spsRTP, ppsRTP)
}
r.Packets = slices.Insert(r.Packets, 0, spsRTP, ppsRTP)
}
for _, p := range r.Packets {
// 更新序列号
for p := range r.Packets.RangePoint {
p.SequenceNumber = ctx.seq
ctx.seq++
}
case webrtc.MimeTypeH265:
var ctx *H265Ctx
if t.ICodecCtx != nil {
ctx = t.ICodecCtx.(*H265Ctx)
} else {
ctx = &H265Ctx{}
ctx.parseFmtpLine(r.RTPCodecParameters)
var vps, sps, pps []byte
if sprop_sps, ok := ctx.Fmtp["sprop-sps"]; ok {
if sps, err = base64.StdEncoding.DecodeString(sprop_sps); err != nil {
return
}
}
if sprop_pps, ok := ctx.Fmtp["sprop-pps"]; ok {
if pps, err = base64.StdEncoding.DecodeString(sprop_pps); err != nil {
return
}
}
if sprop_vps, ok := ctx.Fmtp["sprop-vps"]; ok {
if vps, err = base64.StdEncoding.DecodeString(sprop_vps); err != nil {
return
}
}
if len(vps) > 0 && len(sps) > 0 && len(pps) > 0 {
if ctx.CodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil {
return
}
}
if sprop_donl, ok := ctx.Fmtp["sprop-max-don-diff"]; ok {
if sprop_donl != "0" {
ctx.DONL = true
}
}
t.ICodecCtx = ctx
}
if t.Value.Raw, err = r.Demux(ctx); err != nil {
return
}
pts := r.Packets[0].Timestamp
case *H265Ctx:
dts := ctx.dtsEst.Feed(pts)
r.DTS = time.Duration(dts) * time.Millisecond / 90
r.CTS = time.Duration(pts-dts) * time.Millisecond / 90
r.SetDTS(time.Duration(dts))
r.SetPTS(time.Duration(pts))
// 检查 VPS、SPS、PPS 和 IDR 帧
var vps, sps, pps []byte
var hasVPSSPSPPS bool
for _, nalu := range t.Value.Raw.(Nalus) {
for nalu := range nalus.RangePoint {
switch codec.ParseH265NALUType(nalu.Buffers[0][0]) {
case h265parser.NAL_UNIT_VPS:
ctx = &H265Ctx{}
ctx.RecordInfo.VPS = [][]byte{nalu.ToBytes()}
ctx.RTPCodecParameters = *r.RTPCodecParameters
t.ICodecCtx = ctx
vps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_SPS:
ctx.RecordInfo.SPS = [][]byte{nalu.ToBytes()}
if ctx.SPSInfo, err = h265parser.ParseSPS(ctx.SPS()); err != nil {
return
}
sps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_PPS:
hasVPSSPSPPS = true
ctx.RecordInfo.PPS = [][]byte{nalu.ToBytes()}
if ctx.CodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(ctx.RecordInfo.VPS[0], ctx.RecordInfo.SPS[0], ctx.RecordInfo.PPS[0]); err != nil {
return
}
pps = nalu.ToBytes()
defer nalus.Remove(nalu)
case h265parser.NAL_UNIT_CODED_SLICE_BLA_W_LP,
h265parser.NAL_UNIT_CODED_SLICE_BLA_W_RADL,
h265parser.NAL_UNIT_CODED_SLICE_BLA_N_LP,
h265parser.NAL_UNIT_CODED_SLICE_IDR_W_RADL,
h265parser.NAL_UNIT_CODED_SLICE_IDR_N_LP,
h265parser.NAL_UNIT_CODED_SLICE_CRA:
t.Value.IDR = true
r.IDR = true
}
}
if len(ctx.CodecData.Record) == 0 {
return ErrSkip
// 如果发现新的 VPS/SPS/PPS更新编解码器上下文
if hasVPSSPSPPS = vps != nil && sps != nil && pps != nil; hasVPSSPSPPS && (len(ctx.Record) == 0 || !bytes.Equal(vps, ctx.VPS()) || !bytes.Equal(sps, ctx.SPS()) || !bytes.Equal(pps, ctx.PPS())) {
var newCodecData h265parser.CodecData
if newCodecData, err = h265parser.NewCodecDataFromVPSAndSPSAndPPS(vps, sps, pps); err != nil {
return
}
newCtx := &H265Ctx{
H26xCtx: ctx.H26xCtx,
H265Ctx: &codec.H265Ctx{
CodecData: newCodecData,
},
}
// 保持原有的 RTP 参数
if oldCtx, ok := old.(*H265Ctx); ok {
newCtx.RTPCtx = oldCtx.RTPCtx
}
r.ICodecCtx = newCtx
} else {
// 如果是 IDR 帧但没有 VPS/SPS/PPS需要插入
if r.IDR && len(ctx.VPS()) > 0 && len(ctx.SPS()) > 0 && len(ctx.PPS()) > 0 {
vpsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.VPS(),
}
spsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.SPS(),
}
ppsRTP := rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
}
r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP)
}
}
if t.Value.IDR && !hasVPSSPSPPS {
vpsRTP := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.VPS(),
}
spsRTP := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.SPS(),
}
ppsRTP := &rtp.Packet{
Header: rtp.Header{
Version: 2,
SequenceNumber: ctx.SequenceNumber,
Timestamp: pts,
SSRC: ctx.SSRC,
PayloadType: uint8(ctx.PayloadType),
},
Payload: ctx.PPS(),
}
r.Packets = slices.Insert(r.Packets, 0, vpsRTP, spsRTP, ppsRTP)
}
for _, p := range r.Packets {
// 更新序列号
for p := range r.Packets.RangePoint {
p.SequenceNumber = ctx.seq
ctx.seq++
}
case webrtc.MimeTypeVP9:
// var ctx RTPVP9Ctx
// ctx.RTPCodecParameters = *r.RTPCodecParameters
// codecCtx = &ctx
case webrtc.MimeTypeAV1:
var ctx AV1Ctx
ctx.RTPCodecParameters = *r.RTPCodecParameters
t.ICodecCtx = &ctx
default:
err = ErrUnsupportCodec
}
return
}
@@ -279,26 +259,45 @@ func (av1 *AV1Ctx) GetInfo() string {
return av1.SDPFmtpLine
}
func (r *Video) GetTimestamp() time.Duration {
return r.DTS
}
func (r *VideoFrame) Mux(baseFrame *Sample) error {
// 获取编解码器上下文
codecCtx := r.ICodecCtx
if codecCtx == nil {
switch base := baseFrame.GetBase().(type) {
case *codec.H264Ctx:
var ctx H264Ctx
ctx.H264Ctx = base
ctx.PayloadType = 96
ctx.MimeType = webrtc.MimeTypeH264
ctx.ClockRate = 90000
spsInfo := ctx.SPSInfo
ctx.SDPFmtpLine = fmt.Sprintf("sprop-parameter-sets=%s,%s;profile-level-id=%02x%02x%02x;level-asymmetry-allowed=1;packetization-mode=1", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), spsInfo.ProfileIdc, spsInfo.ConstraintSetFlag, spsInfo.LevelIdc)
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
codecCtx = &ctx
case *codec.H265Ctx:
var ctx H265Ctx
ctx.H265Ctx = base
ctx.PayloadType = 98
ctx.MimeType = webrtc.MimeTypeH265
ctx.SDPFmtpLine = fmt.Sprintf("profile-id=1;sprop-sps=%s;sprop-pps=%s;sprop-vps=%s", base64.StdEncoding.EncodeToString(ctx.SPS()), base64.StdEncoding.EncodeToString(ctx.PPS()), base64.StdEncoding.EncodeToString(ctx.VPS()))
ctx.ClockRate = 90000
ctx.SSRC = uint32(uintptr(unsafe.Pointer(&ctx)))
codecCtx = &ctx
}
r.ICodecCtx = codecCtx
}
// 获取时间戳信息
pts := uint32(baseFrame.GetPTS())
func (r *Video) GetCTS() time.Duration {
return r.CTS
}
func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
pts := uint32((from.Timestamp + from.CTS) * 90 / time.Millisecond)
switch c := codecCtx.(type) {
case *H264Ctx:
ctx := &c.RTPCtx
r.RTPCodecParameters = &ctx.RTPCodecParameters
var lastPacket *rtp.Packet
if from.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 {
if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 {
r.Append(ctx, pts, c.SPS())
r.Append(ctx, pts, c.PPS())
}
for _, nalu := range from.Raw.(Nalus) {
for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
if reader := nalu.NewReader(); reader.Length > MTUSize {
payloadLen := MTUSize
if reader.Length+1 < payloadLen {
@@ -306,7 +305,7 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
}
//fu-a
mem := r.NextN(payloadLen)
reader.ReadBytesTo(mem[1:])
reader.Read(mem[1:])
fuaHead, naluType := codec.NALU_FUA.Or(mem[1]&0x60), mem[1]&0x1f
mem[0], mem[1] = fuaHead, naluType|startBit
lastPacket = r.Append(ctx, pts, mem)
@@ -315,27 +314,26 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
payloadLen = reader.Length + 2
}
mem = r.NextN(payloadLen)
reader.ReadBytesTo(mem[2:])
reader.Read(mem[2:])
mem[0], mem[1] = fuaHead, naluType
}
lastPacket.Payload[1] |= endBit
} else {
mem := r.NextN(reader.Length)
reader.ReadBytesTo(mem)
reader.Read(mem)
lastPacket = r.Append(ctx, pts, mem)
}
}
lastPacket.Header.Marker = true
case *H265Ctx:
ctx := &c.RTPCtx
r.RTPCodecParameters = &ctx.RTPCodecParameters
var lastPacket *rtp.Packet
if from.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 && len(c.RecordInfo.VPS) > 0 {
if baseFrame.IDR && len(c.RecordInfo.SPS) > 0 && len(c.RecordInfo.PPS) > 0 && len(c.RecordInfo.VPS) > 0 {
r.Append(ctx, pts, c.VPS())
r.Append(ctx, pts, c.SPS())
r.Append(ctx, pts, c.PPS())
}
for _, nalu := range from.Raw.(Nalus) {
for nalu := range baseFrame.Raw.(*Nalus).RangePoint {
if reader := nalu.NewReader(); reader.Length > MTUSize {
var b0, b1 byte
_ = reader.ReadByteTo(&b0, &b1)
@@ -348,7 +346,7 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
payloadLen = reader.Length + 3
}
mem := r.NextN(payloadLen)
reader.ReadBytesTo(mem[3:])
reader.Read(mem[3:])
mem[0], mem[1], mem[2] = b0, b1, naluType|startBit
lastPacket = r.Append(ctx, pts, mem)
@@ -357,37 +355,37 @@ func (r *Video) Mux(codecCtx codec.ICodecCtx, from *AVFrame) {
payloadLen = reader.Length + 3
}
mem = r.NextN(payloadLen)
reader.ReadBytesTo(mem[3:])
reader.Read(mem[3:])
mem[0], mem[1], mem[2] = b0, b1, naluType
}
lastPacket.Payload[2] |= endBit
} else {
mem := r.NextN(reader.Length)
reader.ReadBytesTo(mem)
reader.Read(mem)
lastPacket = r.Append(ctx, pts, mem)
}
}
lastPacket.Header.Marker = true
}
return nil
}
func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
func (r *VideoFrame) Demux() (err error) {
if len(r.Packets) == 0 {
return nil, ErrSkip
return ErrSkip
}
switch c := ictx.(type) {
switch c := r.ICodecCtx.(type) {
case *H264Ctx:
var nalus Nalus
var nalu util.Memory
nalus := r.GetNalus()
nalu := nalus.GetNextPointer()
var naluType codec.H264NALUType
gotNalu := func() {
if nalu.Size > 0 {
nalus = append(nalus, nalu)
nalu = util.Memory{}
nalu = nalus.GetNextPointer()
}
}
for _, packet := range r.Packets {
if len(packet.Payload) == 0 {
for packet := range r.Packets.RangePoint {
if len(packet.Payload) < 2 {
continue
}
if packet.Padding {
@@ -395,31 +393,31 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
}
b0 := packet.Payload[0]
if t := codec.ParseH264NALUType(b0); t < 24 {
nalu.AppendOne(packet.Payload)
nalu.PushOne(packet.Payload)
gotNalu()
} else {
offset := t.Offset()
switch t {
case codec.NALU_STAPA, codec.NALU_STAPB:
if len(packet.Payload) <= offset {
return nil, fmt.Errorf("invalid nalu size %d", len(packet.Payload))
return fmt.Errorf("invalid nalu size %d", len(packet.Payload))
}
for buffer := util.Buffer(packet.Payload[offset:]); buffer.CanRead(); {
if nextSize := int(buffer.ReadUint16()); buffer.Len() >= nextSize {
nalu.AppendOne(buffer.ReadN(nextSize))
nalu.PushOne(buffer.ReadN(nextSize))
gotNalu()
} else {
return nil, fmt.Errorf("invalid nalu size %d", nextSize)
return fmt.Errorf("invalid nalu size %d", nextSize)
}
}
case codec.NALU_FUA, codec.NALU_FUB:
b1 := packet.Payload[1]
if util.Bit1(b1, 0) {
naluType.Parse(b1)
nalu.AppendOne([]byte{naluType.Or(b0 & 0x60)})
nalu.PushOne([]byte{naluType.Or(b0 & 0x60)})
}
if nalu.Size > 0 {
nalu.AppendOne(packet.Payload[offset:])
nalu.PushOne(packet.Payload[offset:])
} else {
continue
}
@@ -427,18 +425,18 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
gotNalu()
}
default:
return nil, fmt.Errorf("unsupported nalu type %d", t)
return fmt.Errorf("unsupported nalu type %d", t)
}
}
}
return nalus, nil
nalus.Reduce()
return nil
case *H265Ctx:
var nalus Nalus
var nalu util.Memory
nalus := r.GetNalus()
nalu := nalus.GetNextPointer()
gotNalu := func() {
if nalu.Size > 0 {
nalus = append(nalus, nalu)
nalu = util.Memory{}
nalu = nalus.GetNextPointer()
}
}
for _, packet := range r.Packets {
@@ -447,7 +445,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
}
b0 := packet.Payload[0]
if t := codec.ParseH265NALUType(b0); t < H265_NALU_AP {
nalu.AppendOne(packet.Payload)
nalu.PushOne(packet.Payload)
gotNalu()
} else {
var buffer = util.Buffer(packet.Payload)
@@ -458,7 +456,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
buffer.ReadUint16()
}
for buffer.CanRead() {
nalu.AppendOne(buffer.ReadN(int(buffer.ReadUint16())))
nalu.PushOne(buffer.ReadN(int(buffer.ReadUint16())))
gotNalu()
}
if c.DONL {
@@ -466,7 +464,7 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
}
case H265_NALU_FU:
if buffer.Len() < 3 {
return nil, io.ErrShortBuffer
return io.ErrShortBuffer
}
first3 := buffer.ReadN(3)
fuHeader := first3[2]
@@ -474,18 +472,19 @@ func (r *Video) Demux(ictx codec.ICodecCtx) (any, error) {
buffer.ReadUint16()
}
if naluType := fuHeader & 0b00111111; util.Bit1(fuHeader, 0) {
nalu.AppendOne([]byte{first3[0]&0b10000001 | (naluType << 1), first3[1]})
nalu.PushOne([]byte{first3[0]&0b10000001 | (naluType << 1), first3[1]})
}
nalu.AppendOne(buffer)
nalu.PushOne(buffer)
if util.Bit1(fuHeader, 1) {
gotNalu()
}
default:
return nil, fmt.Errorf("unsupported nalu type %d", t)
return fmt.Errorf("unsupported nalu type %d", t)
}
}
}
return nalus, nil
nalus.Reduce()
return nil
}
return nil, nil
return ErrUnsupportCodec
}

View File

@@ -1,37 +0,0 @@
package rtp
import (
"testing"
"github.com/pion/webrtc/v4"
"m7s.live/v5/pkg"
"m7s.live/v5/pkg/util"
)
func TestRTPH264Ctx_CreateFrame(t *testing.T) {
var ctx = &H264Ctx{}
ctx.RTPCodecParameters = webrtc.RTPCodecParameters{
PayloadType: 96,
RTPCodecCapability: webrtc.RTPCodecCapability{
MimeType: webrtc.MimeTypeH264,
ClockRate: 90000,
SDPFmtpLine: "packetization-mode=1; sprop-parameter-sets=J2QAKaxWgHgCJ+WagICAgQ==,KO48sA==; profile-level-id=640029",
},
}
var randStr = util.RandomString(1500)
var avFrame = &pkg.AVFrame{}
var mem util.Memory
mem.Append([]byte(randStr))
avFrame.Raw = []util.Memory{mem}
frame := new(Video)
frame.Mux(ctx, avFrame)
var track = &pkg.AVTrack{}
err := frame.Parse(track)
if err != nil {
t.Error(err)
return
}
if s := string(track.Value.Raw.(pkg.Nalus)[0].ToBytes()); s != randStr {
t.Error("not equal", len(s), len(randStr))
}
}