feat: 转发流的ssrc以offer sdp的为准

This commit is contained in:
ydajiang
2025-06-04 20:55:18 +08:00
parent c09a132433
commit 98afe59c67
15 changed files with 242 additions and 95 deletions

View File

@@ -153,21 +153,49 @@ func (api *ApiServer) OnGBOfferCreate(v *SourceSDP, w http.ResponseWriter, r *ht
}
}
func (api *ApiServer) AddForwardSink(protocol stream.TransStreamProtocol, transport stream.TransportType, sourceId string, remoteAddr string, w http.ResponseWriter, r *http.Request) {
func (api *ApiServer) AddForwardSink(protocol stream.TransStreamProtocol, transport stream.TransportType, sourceId string, remoteAddr string, ssrc, sessionName string, w http.ResponseWriter, r *http.Request) {
// 解析或生成应答的ssrc
var ssrcOffer int
var ssrcAnswer string
if ssrc != "" {
var err error
ssrcOffer, err = strconv.Atoi(ssrc)
if err != nil {
log.Sugar.Errorf("解析ssrc失败 err: %s ssrc: %s", err.Error(), ssrc)
} else {
ssrcAnswer = ssrc
}
}
if ssrcAnswer == "" {
if "download" != sessionName && "playback" != sessionName {
ssrcAnswer = gb28181.GetLiveSSRC()
} else {
ssrcAnswer = gb28181.GetVodSSRC()
}
var err error
ssrcOffer, err = strconv.Atoi(ssrcAnswer)
// 严重错误, 直接panic
if err != nil {
panic(err)
}
}
var port int
sink, port, err := stream.ForwardStream(protocol, transport, sourceId, r.URL.Query(), remoteAddr, gb28181.TransportManger)
sink, port, err := stream.ForwardStream(protocol, transport, sourceId, r.URL.Query(), remoteAddr, gb28181.TransportManger, uint32(ssrcOffer))
if err != nil {
log.Sugar.Errorf("创建转发sink失败 err: %s", err.Error())
httpResponseError(w, err.Error())
return
}
log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s", sink.GetID(), port, transport)
log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s ssrc: %s", sink.GetID(), port, transport, ssrcAnswer)
response := struct {
Sink string `json:"sink"` // sink id
SDP
}{Sink: stream.SinkID2String(sink.GetID()), SDP: SDP{Addr: net.JoinHostPort(stream.AppConfig.PublicIP, strconv.Itoa(port))}}
}{Sink: stream.SinkID2String(sink.GetID()), SDP: SDP{Addr: net.JoinHostPort(stream.AppConfig.PublicIP, strconv.Itoa(port)), SSRC: ssrcAnswer}}
httpResponseOK(w, &response)
}
@@ -235,5 +263,5 @@ func (api *ApiServer) OnSinkAdd(v *GBOffer, w http.ResponseWriter, r *http.Reque
setup = gb28181.SetupTypeFromString(v.AnswerSetup)
}
api.AddForwardSink(v.TransStreamProtocol, setup.TransportType(), v.Source, v.Addr, w, r)
api.AddForwardSink(v.TransStreamProtocol, setup.TransportType(), v.Source, v.Addr, v.SSRC, v.SessionName, w, r)
}

View File

@@ -171,7 +171,7 @@ func NewHttpTransStream(metadata *amf0.Object, prevTagSize uint32) stream.TransS
}
}
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
var prevTagSize uint32
var metaData *amf0.Object

View File

@@ -18,6 +18,7 @@ type GBGateway struct {
rtp rtp.Muxer
psBuffer []byte
tracks map[utils.AVCodecID]int // codec->track index
rtpBuffer *stream.RtpBuffer
}
func (s *GBGateway) WriteHeader() error {
@@ -61,6 +62,7 @@ func (s *GBGateway) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
data = avformat.AVCCPacket2AnnexB(s.BaseTransStream.Tracks[packet.Index].Stream, packet)
}
// 扩容ps buffer
if cap(s.psBuffer) < len(data)+1024*64 {
s.psBuffer = make([]byte, len(data)*2)
}
@@ -69,27 +71,52 @@ func (s *GBGateway) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
var result []*collections.ReferenceCounter[[]byte]
var rtpBuffer []byte
var counter *collections.ReferenceCounter[[]byte]
s.rtp.Input(s.psBuffer[:n], uint32(dts), func() []byte {
rtpBuffer = stream.UDPReceiveBufferPool.Get().([]byte)
counter = s.rtpBuffer.Get()
counter.Refer()
rtpBuffer = counter.Get()
return rtpBuffer[2:]
}, func(bytes []byte) {
binary.BigEndian.PutUint16(rtpBuffer, uint16(len(bytes)))
refPacket := collections.NewReferenceCounter(rtpBuffer[:2+len(bytes)])
result = append(result, refPacket)
counter.ResetData(rtpBuffer[:2+len(bytes)])
result = append(result, counter)
})
// 引用计数保持为1
for _, pkt := range result {
pkt.Release()
}
return result, 0, true, nil
}
func NewGBGateway() *GBGateway {
func (s *GBGateway) Close() ([]*collections.ReferenceCounter[[]byte], int64, error) {
s.rtpBuffer.Clear()
return nil, 0, nil
}
func NewGBGateway(ssrc uint32) *GBGateway {
return &GBGateway{
ps: mpeg.NewPsMuxer(),
rtp: rtp.NewMuxer(96, 0, 0xFFFFFFFF),
rtp: rtp.NewMuxer(96, 0, ssrc),
psBuffer: make([]byte, 1024*1024*2),
tracks: make(map[utils.AVCodecID]int),
rtpBuffer: stream.NewRtpBuffer(1024),
}
}
func GatewayTransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
return NewGBGateway(), nil
func GatewayTransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, sink stream.Sink) (stream.TransStream, error) {
// 默认ssrc
var ssrc uint32 = 0xFFFFFFFF
// 优先使用sink的ssrc, 减少内存拷贝
if sink != nil {
if forwardSink, ok := sink.(*stream.ForwardSink); ok {
ssrc = forwardSink.GetSSRC()
}
}
gateway := NewGBGateway(ssrc)
return gateway, nil
}

View File

@@ -29,14 +29,22 @@ func (s *TalkStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceC
return s.RtpStream.Input(packet)
}
func NewTalkTransStream() (stream.TransStream, error) {
func NewTalkTransStream(ssrc uint32) (stream.TransStream, error) {
return &TalkStream{
RtpStream: stream.NewRtpTransStream(stream.TransStreamGBTalk, 1024),
muxer: rtp.NewMuxer(8, 0, 0xFFFFFFFF),
muxer: rtp.NewMuxer(8, 0, ssrc),
packet: make([]byte, 1500),
}, nil
}
func TalkTransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
return NewTalkTransStream()
func TalkTransStreamFactory(_ stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, sink stream.Sink) (stream.TransStream, error) {
var ssrc uint32 = 0xFFFFFFFF
if sink != nil {
forwardSink, ok := sink.(*stream.ForwardSink)
if ok {
ssrc = forwardSink.GetSSRC()
}
}
return NewTalkTransStream(ssrc)
}

View File

@@ -297,7 +297,7 @@ func NewTransStream(dir, m3u8Name, tsFormat, tsUrl string, segmentDuration, play
return transStream, nil
}
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
id := source.GetID()
var writer stream.M3U8Writer

View File

@@ -182,7 +182,7 @@ func TestPublish(t *testing.T) {
buffer: make([]byte, 1024*1024*2),
fos: openFile,
tracks: make(map[int]int),
gateway: gb28181.NewGBGateway(),
gateway: gb28181.NewGBGateway(0xFFFFFFFF),
udp: client,
})

View File

@@ -78,6 +78,6 @@ func NewTransStream() stream.TransStream {
return t
}
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
func TransStreamFactory(_ stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
return NewTransStream(), nil
}

View File

@@ -238,7 +238,7 @@ func NewTransStream(chunkSize int, metaData *amf0.Object) stream.TransStream {
return &transStream{chunkSize: chunkSize, metaData: metaData}
}
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
// 获取推流的元数据
var metaData *amf0.Object
if stream.SourceTypeRtmp == source.GetType() {

View File

@@ -32,7 +32,7 @@ type TransStream struct {
oldTracks map[byte]uint16
sdp string
rtpBuffers *collections.Queue[*collections.ReferenceCounter[[]byte]]
rtpBuffer *stream.RtpBuffer
}
func (t *TransStream) OverTCP(data []byte, channel int) {
@@ -42,20 +42,6 @@ func (t *TransStream) OverTCP(data []byte, channel int) {
}
func (t *TransStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCounter[[]byte], int64, bool, error) {
// 释放rtp包
for t.rtpBuffers.Size() > 0 {
rtp := t.rtpBuffers.Peek(0)
if rtp.UseCount() > 1 {
break
}
t.rtpBuffers.Pop()
// 放回池中
data := rtp.Get()
stream.UDPReceiveBufferPool.Put(data[:cap(data)])
}
var ts uint32
var result []*collections.ReferenceCounter[[]byte]
track := t.RtspTracks[packet.Index]
@@ -97,22 +83,30 @@ func (t *TransStream) ReadExtraData(ts int64) ([]*collections.ReferenceCounter[[
func (t *TransStream) PackRtpPayload(track *Track, channel int, data []byte, timestamp uint32) []*collections.ReferenceCounter[[]byte] {
var result []*collections.ReferenceCounter[[]byte]
var packet []byte
var counter *collections.ReferenceCounter[[]byte]
// 保存开始序号
track.StartSeq = track.Muxer.GetHeader().Seq
track.Muxer.Input(data, timestamp, func() []byte {
packet = stream.UDPReceiveBufferPool.Get().([]byte)
counter = t.rtpBuffer.Get()
counter.Refer()
packet = counter.Get()
return packet[OverTcpHeaderSize:]
}, func(bytes []byte) {
track.EndSeq = track.Muxer.GetHeader().Seq
overTCPPacket := packet[:OverTcpHeaderSize+len(bytes)]
t.OverTCP(overTCPPacket, channel)
refPacket := collections.NewReferenceCounter(overTCPPacket)
result = append(result, refPacket)
t.rtpBuffers.Push(refPacket)
counter.ResetData(overTCPPacket)
result = append(result, counter)
})
// 引用计数保持为1
for _, pkt := range result {
pkt.Release()
}
return result
}
@@ -154,9 +148,10 @@ func (t *TransStream) AddTrack(track *stream.Track) error {
packAndAdd := func(data []byte) {
packets := t.PackRtpPayload(rtspTrack, trackIndex, data, 0)
for _, packet := range packets {
extraDataPackets = append(extraDataPackets, packet)
// 出队列, 单独保存
t.rtpBuffers.Pop()
extra := packet.Get()
bytes := make([]byte, len(extra))
copy(bytes, extra)
extraDataPackets = append(extraDataPackets, collections.NewReferenceCounter(bytes))
}
}
@@ -277,7 +272,7 @@ func NewTransStream(addr net.IPAddr, urlFormat string, oldTracks map[byte]uint16
addr: addr,
urlFormat: urlFormat,
oldTracks: oldTracks,
rtpBuffers: collections.NewQueue[*collections.ReferenceCounter[[]byte]](512),
rtpBuffer: stream.NewRtpBuffer(512),
}
if addr.IP.To4() != nil {
@@ -289,7 +284,7 @@ func NewTransStream(addr net.IPAddr, urlFormat string, oldTracks map[byte]uint16
return t
}
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) {
func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
trackFormat := "?track=%d"
var oldTracks map[byte]uint16
if endInfo := source.GetTransStreamPublisher().GetStreamEndInfo(); endInfo != nil {

View File

@@ -1,6 +1,7 @@
package stream
import (
"encoding/binary"
"github.com/lkmio/avformat/collections"
"github.com/lkmio/lkm/log"
"github.com/lkmio/transport"
@@ -29,11 +30,15 @@ func (t TransportType) String() string {
}
}
// ForwardSink 转发流Sink, 级联/对讲广播/JT1078转GB28181均使用
type ForwardSink struct {
BaseSink
socket transport.Transport
transportType TransportType
receiveTimer *time.Timer
ssrc uint32
requireSSRCMatch bool // 如果ssrc要求一致, 发包时要检查ssrc是否一致, 不一致则重新拷贝一份
rtpBuffer *RtpBuffer
}
func (f *ForwardSink) OnConnected(conn net.Conn) []byte {
@@ -67,18 +72,68 @@ func (f *ForwardSink) Write(index int, data []*collections.ReferenceCounter[[]by
return nil
}
var processedData []*collections.ReferenceCounter[[]byte]
// ssrc不一致, 重新拷贝一份, 修改为指定的ssrc
if f.requireSSRCMatch && f.ssrc != binary.BigEndian.Uint32(data[0].Get()[2+8:]) {
if TransportTypeUDP != f.transportType {
if f.rtpBuffer == nil {
f.rtpBuffer = NewRtpBuffer(1024)
}
processedData = make([]*collections.ReferenceCounter[[]byte], 0, len(data))
} else if f.rtpBuffer == nil {
f.rtpBuffer = NewRtpBuffer(1)
}
for i, datum := range data {
src := datum.Get()
counter := f.rtpBuffer.Get()
bytes := counter.Get()
length := len(src)
copy(bytes, src[:length])
// 修改ssrc
binary.BigEndian.PutUint32(bytes[2+8:], f.ssrc)
// UDP直接发送
if TransportTypeUDP == f.transportType {
for _, datum := range data {
_ = f.socket.(*transport.UDPClient).Write(bytes[2:length])
} else {
counter.ResetData(bytes[:length])
counter.Refer()
processedData[i] = counter
}
}
// UDP已经发送, 直接返回
if processedData == nil {
return nil
} else {
// 引用计数保持为1
for _, pkt := range processedData {
pkt.Release()
}
}
}
if processedData == nil {
processedData = data
}
if TransportTypeUDP == f.transportType {
for _, datum := range processedData {
f.socket.(*transport.UDPClient).Write(datum.Get()[2:])
}
} else {
return f.BaseSink.Write(index, data, ts, keyVideo)
return f.BaseSink.Write(index, processedData, ts, keyVideo)
}
return nil
}
// Close 关闭国标转发流
// Close 关闭转发流
func (f *ForwardSink) Close() {
f.BaseSink.Close()
@@ -89,6 +144,10 @@ func (f *ForwardSink) Close() {
if f.receiveTimer != nil {
f.receiveTimer.Stop()
}
if f.rtpBuffer != nil {
f.rtpBuffer.Clear()
}
}
// StartReceiveTimer 启动tcp sever计时器, 如果计时器触发, 没有连接, 则关闭流
@@ -101,10 +160,16 @@ func (f *ForwardSink) StartReceiveTimer() {
})
}
func NewForwardSink(transportType TransportType, protocol TransStreamProtocol, sinkId SinkID, sourceId string, addr string, manager transport.Manager) (*ForwardSink, int, error) {
func (f *ForwardSink) GetSSRC() uint32 {
return f.ssrc
}
func NewForwardSink(transportType TransportType, protocol TransStreamProtocol, sinkId SinkID, sourceId string, addr string, manager transport.Manager, ssrc uint32) (*ForwardSink, int, error) {
sink := &ForwardSink{
BaseSink: BaseSink{ID: sinkId, SourceID: sourceId, State: SessionStateCreated, Protocol: protocol},
transportType: transportType,
ssrc: ssrc,
requireSSRCMatch: true, // 默认要求ssrc一致
}
if transportType == TransportTypeUDP {

39
stream/rtp_buffer.go Normal file
View File

@@ -0,0 +1,39 @@
package stream
import (
"github.com/lkmio/avformat/collections"
)
type RtpBuffer struct {
queue *collections.Queue[*collections.ReferenceCounter[[]byte]]
}
func (r *RtpBuffer) Get() *collections.ReferenceCounter[[]byte] {
if r.queue.Size() > 0 {
rtp := r.queue.Peek(0)
if rtp.UseCount() < 2 {
bytes := rtp.Get()
rtp.ResetData(bytes[:cap(bytes)])
return rtp
}
}
bytes := collections.NewReferenceCounter(UDPReceiveBufferPool.Get().([]byte))
r.queue.Push(bytes)
return bytes
}
func (r *RtpBuffer) Clear() {
for r.queue.Size() > 0 {
if r.queue.Peek(0).UseCount() > 1 {
break
}
bytes := r.queue.Pop().Get()
UDPReceiveBufferPool.Put(bytes[:cap(bytes)])
}
}
func NewRtpBuffer(capacity int) *RtpBuffer {
return &RtpBuffer{queue: collections.NewQueue[*collections.ReferenceCounter[[]byte]](capacity)}
}

View File

@@ -9,7 +9,7 @@ import (
type RtpStream struct {
BaseTransStream
rtpBuffers *collections.Queue[*collections.ReferenceCounter[[]byte]]
rtpBuffer *RtpBuffer
}
func (f *RtpStream) WriteHeader() error {
@@ -23,38 +23,23 @@ func (f *RtpStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
return nil, 0, false, nil
}
// 释放rtp包
for f.rtpBuffers.Size() > 0 {
rtp := f.rtpBuffers.Peek(0)
if rtp.UseCount() > 1 {
break
}
f.rtpBuffers.Pop()
// 放回池中
data := rtp.Get()
UDPReceiveBufferPool.Put(data[:cap(data)])
}
bytes := UDPReceiveBufferPool.Get().([]byte)
counter := f.rtpBuffer.Get()
bytes := counter.Get()
binary.BigEndian.PutUint16(bytes, size-2)
copy(bytes[2:], packet.Data)
rtp := collections.NewReferenceCounter(bytes[:size])
f.rtpBuffers.Push(rtp)
counter.ResetData(bytes)
// 每帧都当关键帧, 直接发给上级
return []*collections.ReferenceCounter[[]byte]{rtp}, -1, true, nil
return []*collections.ReferenceCounter[[]byte]{counter}, -1, true, nil
}
func NewRtpTransStream(protocol TransStreamProtocol, capacity int) *RtpStream {
return &RtpStream{
BaseTransStream: BaseTransStream{Protocol: protocol},
rtpBuffers: collections.NewQueue[*collections.ReferenceCounter[[]byte]](capacity),
rtpBuffer: NewRtpBuffer(capacity),
}
}
func GBCascadedTransStreamFactory(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) {
func GBCascadedTransStreamFactory(_ Source, _ TransStreamProtocol, _ []*Track, _ Sink) (TransStream, error) {
return NewRtpTransStream(TransStreamGBCascaded, 1024), nil
}

View File

@@ -83,7 +83,7 @@ func SubscribeStreamWithOptions(sink Sink, values url.Values, ready bool, timeou
return state
}
func ForwardStream(protocol TransStreamProtocol, transport TransportType, sourceId string, values url.Values, remoteAddr string, manager transport.Manager) (Sink, int, error) {
func ForwardStream(protocol TransStreamProtocol, transport TransportType, sourceId string, values url.Values, remoteAddr string, manager transport.Manager, ssrc uint32) (Sink, int, error) {
//source := SourceManager.Find(sourceId)
//if source == nil {
// return nil, 0, fmt.Errorf("source %s 不存在", sourceId)
@@ -91,7 +91,7 @@ func ForwardStream(protocol TransStreamProtocol, transport TransportType, source
sinkId := GenerateUint64SinkID()
var port int
sink, port, err := NewForwardSink(transport, protocol, sinkId, sourceId, remoteAddr, manager)
sink, port, err := NewForwardSink(transport, protocol, sinkId, sourceId, remoteAddr, manager, ssrc)
if err != nil {
return nil, 0, err
}

View File

@@ -4,7 +4,7 @@ import (
"fmt"
)
type TransStreamFactory func(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error)
type TransStreamFactory func(source Source, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error)
type RecordStreamFactory func(source string) (Sink, string, error)
@@ -35,13 +35,13 @@ func FindTransStreamFactory(protocol TransStreamProtocol) (TransStreamFactory, e
return f, nil
}
func CreateTransStream(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) {
func CreateTransStream(source Source, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error) {
factory, err := FindTransStreamFactory(protocol)
if err != nil {
return nil, err
}
return factory(source, protocol, tracks)
return factory(source, protocol, tracks, sink)
}
func SetRecordStreamFactory(factory RecordStreamFactory) {

View File

@@ -187,7 +187,7 @@ func (t *transStreamPublisher) CreateDefaultOutStreams() {
utils.Assert(len(streams) > 0)
id := GenerateTransStreamID(TransStreamHls, streams...)
hlsStream, err := t.CreateTransStream(id, TransStreamHls, streams)
hlsStream, err := t.CreateTransStream(id, TransStreamHls, streams, nil)
if err != nil {
panic(err)
}
@@ -206,12 +206,12 @@ func IsSupportMux(protocol TransStreamProtocol, _, _ utils.AVCodecID) bool {
return true
}
func (t *transStreamPublisher) CreateTransStream(id TransStreamID, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) {
func (t *transStreamPublisher) CreateTransStream(id TransStreamID, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error) {
log.Sugar.Infof("创建%s-stream source: %s", protocol.String(), t.source)
source := SourceManager.Find(t.source)
utils.Assert(source != nil)
transStream, err := CreateTransStream(source, protocol, tracks)
transStream, err := CreateTransStream(source, protocol, tracks, sink)
if err != nil {
log.Sugar.Errorf("创建传输流失败 err: %s source: %s", err.Error(), t.source)
return nil, err
@@ -356,7 +356,7 @@ func (t *transStreamPublisher) doAddSink(sink Sink, resume bool) bool {
transStream, exist := t.transStreams[transStreamId]
if !exist {
var err error
transStream, err = t.CreateTransStream(transStreamId, sink.GetProtocol(), tracks)
transStream, err = t.CreateTransStream(transStreamId, sink.GetProtocol(), tracks, sink)
if err != nil {
log.Sugar.Errorf("添加sink失败,创建传输流发生err: %s source: %s", err.Error(), t.source)
return false