feat: 转发流的ssrc以offer sdp的为准

This commit is contained in:
ydajiang
2025-06-04 20:55:18 +08:00
parent c09a132433
commit 98afe59c67
15 changed files with 242 additions and 95 deletions

View File

@@ -153,21 +153,49 @@ func (api *ApiServer) OnGBOfferCreate(v *SourceSDP, w http.ResponseWriter, r *ht
} }
} }
func (api *ApiServer) AddForwardSink(protocol stream.TransStreamProtocol, transport stream.TransportType, sourceId string, remoteAddr string, w http.ResponseWriter, r *http.Request) { func (api *ApiServer) AddForwardSink(protocol stream.TransStreamProtocol, transport stream.TransportType, sourceId string, remoteAddr string, ssrc, sessionName string, w http.ResponseWriter, r *http.Request) {
// 解析或生成应答的ssrc
var ssrcOffer int
var ssrcAnswer string
if ssrc != "" {
var err error
ssrcOffer, err = strconv.Atoi(ssrc)
if err != nil {
log.Sugar.Errorf("解析ssrc失败 err: %s ssrc: %s", err.Error(), ssrc)
} else {
ssrcAnswer = ssrc
}
}
if ssrcAnswer == "" {
if "download" != sessionName && "playback" != sessionName {
ssrcAnswer = gb28181.GetLiveSSRC()
} else {
ssrcAnswer = gb28181.GetVodSSRC()
}
var err error
ssrcOffer, err = strconv.Atoi(ssrcAnswer)
// 严重错误, 直接panic
if err != nil {
panic(err)
}
}
var port int var port int
sink, port, err := stream.ForwardStream(protocol, transport, sourceId, r.URL.Query(), remoteAddr, gb28181.TransportManger) sink, port, err := stream.ForwardStream(protocol, transport, sourceId, r.URL.Query(), remoteAddr, gb28181.TransportManger, uint32(ssrcOffer))
if err != nil { if err != nil {
log.Sugar.Errorf("创建转发sink失败 err: %s", err.Error()) log.Sugar.Errorf("创建转发sink失败 err: %s", err.Error())
httpResponseError(w, err.Error()) httpResponseError(w, err.Error())
return return
} }
log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s", sink.GetID(), port, transport) log.Sugar.Infof("创建转发sink成功, sink: %s port: %d transport: %s ssrc: %s", sink.GetID(), port, transport, ssrcAnswer)
response := struct { response := struct {
Sink string `json:"sink"` // sink id Sink string `json:"sink"` // sink id
SDP SDP
}{Sink: stream.SinkID2String(sink.GetID()), SDP: SDP{Addr: net.JoinHostPort(stream.AppConfig.PublicIP, strconv.Itoa(port))}} }{Sink: stream.SinkID2String(sink.GetID()), SDP: SDP{Addr: net.JoinHostPort(stream.AppConfig.PublicIP, strconv.Itoa(port)), SSRC: ssrcAnswer}}
httpResponseOK(w, &response) httpResponseOK(w, &response)
} }
@@ -235,5 +263,5 @@ func (api *ApiServer) OnSinkAdd(v *GBOffer, w http.ResponseWriter, r *http.Reque
setup = gb28181.SetupTypeFromString(v.AnswerSetup) setup = gb28181.SetupTypeFromString(v.AnswerSetup)
} }
api.AddForwardSink(v.TransStreamProtocol, setup.TransportType(), v.Source, v.Addr, w, r) api.AddForwardSink(v.TransStreamProtocol, setup.TransportType(), v.Source, v.Addr, v.SSRC, v.SessionName, w, r)
} }

View File

@@ -171,7 +171,7 @@ func NewHttpTransStream(metadata *amf0.Object, prevTagSize uint32) stream.TransS
} }
} }
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
var prevTagSize uint32 var prevTagSize uint32
var metaData *amf0.Object var metaData *amf0.Object

View File

@@ -18,6 +18,7 @@ type GBGateway struct {
rtp rtp.Muxer rtp rtp.Muxer
psBuffer []byte psBuffer []byte
tracks map[utils.AVCodecID]int // codec->track index tracks map[utils.AVCodecID]int // codec->track index
rtpBuffer *stream.RtpBuffer
} }
func (s *GBGateway) WriteHeader() error { func (s *GBGateway) WriteHeader() error {
@@ -61,6 +62,7 @@ func (s *GBGateway) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
data = avformat.AVCCPacket2AnnexB(s.BaseTransStream.Tracks[packet.Index].Stream, packet) data = avformat.AVCCPacket2AnnexB(s.BaseTransStream.Tracks[packet.Index].Stream, packet)
} }
// 扩容ps buffer
if cap(s.psBuffer) < len(data)+1024*64 { if cap(s.psBuffer) < len(data)+1024*64 {
s.psBuffer = make([]byte, len(data)*2) s.psBuffer = make([]byte, len(data)*2)
} }
@@ -69,27 +71,52 @@ func (s *GBGateway) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
var result []*collections.ReferenceCounter[[]byte] var result []*collections.ReferenceCounter[[]byte]
var rtpBuffer []byte var rtpBuffer []byte
var counter *collections.ReferenceCounter[[]byte]
s.rtp.Input(s.psBuffer[:n], uint32(dts), func() []byte { s.rtp.Input(s.psBuffer[:n], uint32(dts), func() []byte {
rtpBuffer = stream.UDPReceiveBufferPool.Get().([]byte) counter = s.rtpBuffer.Get()
counter.Refer()
rtpBuffer = counter.Get()
return rtpBuffer[2:] return rtpBuffer[2:]
}, func(bytes []byte) { }, func(bytes []byte) {
binary.BigEndian.PutUint16(rtpBuffer, uint16(len(bytes))) binary.BigEndian.PutUint16(rtpBuffer, uint16(len(bytes)))
refPacket := collections.NewReferenceCounter(rtpBuffer[:2+len(bytes)]) counter.ResetData(rtpBuffer[:2+len(bytes)])
result = append(result, refPacket) result = append(result, counter)
}) })
// 引用计数保持为1
for _, pkt := range result {
pkt.Release()
}
return result, 0, true, nil return result, 0, true, nil
} }
func NewGBGateway() *GBGateway { func (s *GBGateway) Close() ([]*collections.ReferenceCounter[[]byte], int64, error) {
s.rtpBuffer.Clear()
return nil, 0, nil
}
func NewGBGateway(ssrc uint32) *GBGateway {
return &GBGateway{ return &GBGateway{
ps: mpeg.NewPsMuxer(), ps: mpeg.NewPsMuxer(),
rtp: rtp.NewMuxer(96, 0, 0xFFFFFFFF), rtp: rtp.NewMuxer(96, 0, ssrc),
psBuffer: make([]byte, 1024*1024*2), psBuffer: make([]byte, 1024*1024*2),
tracks: make(map[utils.AVCodecID]int), tracks: make(map[utils.AVCodecID]int),
rtpBuffer: stream.NewRtpBuffer(1024),
} }
} }
func GatewayTransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func GatewayTransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, sink stream.Sink) (stream.TransStream, error) {
return NewGBGateway(), nil // 默认ssrc
var ssrc uint32 = 0xFFFFFFFF
// 优先使用sink的ssrc, 减少内存拷贝
if sink != nil {
if forwardSink, ok := sink.(*stream.ForwardSink); ok {
ssrc = forwardSink.GetSSRC()
}
}
gateway := NewGBGateway(ssrc)
return gateway, nil
} }

View File

@@ -29,14 +29,22 @@ func (s *TalkStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceC
return s.RtpStream.Input(packet) return s.RtpStream.Input(packet)
} }
func NewTalkTransStream() (stream.TransStream, error) { func NewTalkTransStream(ssrc uint32) (stream.TransStream, error) {
return &TalkStream{ return &TalkStream{
RtpStream: stream.NewRtpTransStream(stream.TransStreamGBTalk, 1024), RtpStream: stream.NewRtpTransStream(stream.TransStreamGBTalk, 1024),
muxer: rtp.NewMuxer(8, 0, 0xFFFFFFFF), muxer: rtp.NewMuxer(8, 0, ssrc),
packet: make([]byte, 1500), packet: make([]byte, 1500),
}, nil }, nil
} }
func TalkTransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TalkTransStreamFactory(_ stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, sink stream.Sink) (stream.TransStream, error) {
return NewTalkTransStream() var ssrc uint32 = 0xFFFFFFFF
if sink != nil {
forwardSink, ok := sink.(*stream.ForwardSink)
if ok {
ssrc = forwardSink.GetSSRC()
}
}
return NewTalkTransStream(ssrc)
} }

View File

@@ -297,7 +297,7 @@ func NewTransStream(dir, m3u8Name, tsFormat, tsUrl string, segmentDuration, play
return transStream, nil return transStream, nil
} }
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
id := source.GetID() id := source.GetID()
var writer stream.M3U8Writer var writer stream.M3U8Writer

View File

@@ -182,7 +182,7 @@ func TestPublish(t *testing.T) {
buffer: make([]byte, 1024*1024*2), buffer: make([]byte, 1024*1024*2),
fos: openFile, fos: openFile,
tracks: make(map[int]int), tracks: make(map[int]int),
gateway: gb28181.NewGBGateway(), gateway: gb28181.NewGBGateway(0xFFFFFFFF),
udp: client, udp: client,
}) })

View File

@@ -78,6 +78,6 @@ func NewTransStream() stream.TransStream {
return t return t
} }
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TransStreamFactory(_ stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
return NewTransStream(), nil return NewTransStream(), nil
} }

View File

@@ -238,7 +238,7 @@ func NewTransStream(chunkSize int, metaData *amf0.Object) stream.TransStream {
return &transStream{chunkSize: chunkSize, metaData: metaData} return &transStream{chunkSize: chunkSize, metaData: metaData}
} }
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
// 获取推流的元数据 // 获取推流的元数据
var metaData *amf0.Object var metaData *amf0.Object
if stream.SourceTypeRtmp == source.GetType() { if stream.SourceTypeRtmp == source.GetType() {

View File

@@ -32,7 +32,7 @@ type TransStream struct {
oldTracks map[byte]uint16 oldTracks map[byte]uint16
sdp string sdp string
rtpBuffers *collections.Queue[*collections.ReferenceCounter[[]byte]] rtpBuffer *stream.RtpBuffer
} }
func (t *TransStream) OverTCP(data []byte, channel int) { func (t *TransStream) OverTCP(data []byte, channel int) {
@@ -42,20 +42,6 @@ func (t *TransStream) OverTCP(data []byte, channel int) {
} }
func (t *TransStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCounter[[]byte], int64, bool, error) { func (t *TransStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCounter[[]byte], int64, bool, error) {
// 释放rtp包
for t.rtpBuffers.Size() > 0 {
rtp := t.rtpBuffers.Peek(0)
if rtp.UseCount() > 1 {
break
}
t.rtpBuffers.Pop()
// 放回池中
data := rtp.Get()
stream.UDPReceiveBufferPool.Put(data[:cap(data)])
}
var ts uint32 var ts uint32
var result []*collections.ReferenceCounter[[]byte] var result []*collections.ReferenceCounter[[]byte]
track := t.RtspTracks[packet.Index] track := t.RtspTracks[packet.Index]
@@ -97,22 +83,30 @@ func (t *TransStream) ReadExtraData(ts int64) ([]*collections.ReferenceCounter[[
func (t *TransStream) PackRtpPayload(track *Track, channel int, data []byte, timestamp uint32) []*collections.ReferenceCounter[[]byte] { func (t *TransStream) PackRtpPayload(track *Track, channel int, data []byte, timestamp uint32) []*collections.ReferenceCounter[[]byte] {
var result []*collections.ReferenceCounter[[]byte] var result []*collections.ReferenceCounter[[]byte]
var packet []byte var packet []byte
var counter *collections.ReferenceCounter[[]byte]
// 保存开始序号 // 保存开始序号
track.StartSeq = track.Muxer.GetHeader().Seq track.StartSeq = track.Muxer.GetHeader().Seq
track.Muxer.Input(data, timestamp, func() []byte { track.Muxer.Input(data, timestamp, func() []byte {
packet = stream.UDPReceiveBufferPool.Get().([]byte) counter = t.rtpBuffer.Get()
counter.Refer()
packet = counter.Get()
return packet[OverTcpHeaderSize:] return packet[OverTcpHeaderSize:]
}, func(bytes []byte) { }, func(bytes []byte) {
track.EndSeq = track.Muxer.GetHeader().Seq track.EndSeq = track.Muxer.GetHeader().Seq
overTCPPacket := packet[:OverTcpHeaderSize+len(bytes)] overTCPPacket := packet[:OverTcpHeaderSize+len(bytes)]
t.OverTCP(overTCPPacket, channel) t.OverTCP(overTCPPacket, channel)
refPacket := collections.NewReferenceCounter(overTCPPacket) counter.ResetData(overTCPPacket)
result = append(result, refPacket) result = append(result, counter)
t.rtpBuffers.Push(refPacket)
}) })
// 引用计数保持为1
for _, pkt := range result {
pkt.Release()
}
return result return result
} }
@@ -154,9 +148,10 @@ func (t *TransStream) AddTrack(track *stream.Track) error {
packAndAdd := func(data []byte) { packAndAdd := func(data []byte) {
packets := t.PackRtpPayload(rtspTrack, trackIndex, data, 0) packets := t.PackRtpPayload(rtspTrack, trackIndex, data, 0)
for _, packet := range packets { for _, packet := range packets {
extraDataPackets = append(extraDataPackets, packet) extra := packet.Get()
// 出队列, 单独保存 bytes := make([]byte, len(extra))
t.rtpBuffers.Pop() copy(bytes, extra)
extraDataPackets = append(extraDataPackets, collections.NewReferenceCounter(bytes))
} }
} }
@@ -277,7 +272,7 @@ func NewTransStream(addr net.IPAddr, urlFormat string, oldTracks map[byte]uint16
addr: addr, addr: addr,
urlFormat: urlFormat, urlFormat: urlFormat,
oldTracks: oldTracks, oldTracks: oldTracks,
rtpBuffers: collections.NewQueue[*collections.ReferenceCounter[[]byte]](512), rtpBuffer: stream.NewRtpBuffer(512),
} }
if addr.IP.To4() != nil { if addr.IP.To4() != nil {
@@ -289,7 +284,7 @@ func NewTransStream(addr net.IPAddr, urlFormat string, oldTracks map[byte]uint16
return t return t
} }
func TransStreamFactory(source stream.Source, protocol stream.TransStreamProtocol, tracks []*stream.Track) (stream.TransStream, error) { func TransStreamFactory(source stream.Source, _ stream.TransStreamProtocol, _ []*stream.Track, _ stream.Sink) (stream.TransStream, error) {
trackFormat := "?track=%d" trackFormat := "?track=%d"
var oldTracks map[byte]uint16 var oldTracks map[byte]uint16
if endInfo := source.GetTransStreamPublisher().GetStreamEndInfo(); endInfo != nil { if endInfo := source.GetTransStreamPublisher().GetStreamEndInfo(); endInfo != nil {

View File

@@ -1,6 +1,7 @@
package stream package stream
import ( import (
"encoding/binary"
"github.com/lkmio/avformat/collections" "github.com/lkmio/avformat/collections"
"github.com/lkmio/lkm/log" "github.com/lkmio/lkm/log"
"github.com/lkmio/transport" "github.com/lkmio/transport"
@@ -29,11 +30,15 @@ func (t TransportType) String() string {
} }
} }
// ForwardSink 转发流Sink, 级联/对讲广播/JT1078转GB28181均使用
type ForwardSink struct { type ForwardSink struct {
BaseSink BaseSink
socket transport.Transport socket transport.Transport
transportType TransportType transportType TransportType
receiveTimer *time.Timer receiveTimer *time.Timer
ssrc uint32
requireSSRCMatch bool // 如果ssrc要求一致, 发包时要检查ssrc是否一致, 不一致则重新拷贝一份
rtpBuffer *RtpBuffer
} }
func (f *ForwardSink) OnConnected(conn net.Conn) []byte { func (f *ForwardSink) OnConnected(conn net.Conn) []byte {
@@ -67,18 +72,68 @@ func (f *ForwardSink) Write(index int, data []*collections.ReferenceCounter[[]by
return nil return nil
} }
var processedData []*collections.ReferenceCounter[[]byte]
// ssrc不一致, 重新拷贝一份, 修改为指定的ssrc
if f.requireSSRCMatch && f.ssrc != binary.BigEndian.Uint32(data[0].Get()[2+8:]) {
if TransportTypeUDP != f.transportType {
if f.rtpBuffer == nil {
f.rtpBuffer = NewRtpBuffer(1024)
}
processedData = make([]*collections.ReferenceCounter[[]byte], 0, len(data))
} else if f.rtpBuffer == nil {
f.rtpBuffer = NewRtpBuffer(1)
}
for i, datum := range data {
src := datum.Get()
counter := f.rtpBuffer.Get()
bytes := counter.Get()
length := len(src)
copy(bytes, src[:length])
// 修改ssrc
binary.BigEndian.PutUint32(bytes[2+8:], f.ssrc)
// UDP直接发送
if TransportTypeUDP == f.transportType { if TransportTypeUDP == f.transportType {
for _, datum := range data { _ = f.socket.(*transport.UDPClient).Write(bytes[2:length])
} else {
counter.ResetData(bytes[:length])
counter.Refer()
processedData[i] = counter
}
}
// UDP已经发送, 直接返回
if processedData == nil {
return nil
} else {
// 引用计数保持为1
for _, pkt := range processedData {
pkt.Release()
}
}
}
if processedData == nil {
processedData = data
}
if TransportTypeUDP == f.transportType {
for _, datum := range processedData {
f.socket.(*transport.UDPClient).Write(datum.Get()[2:]) f.socket.(*transport.UDPClient).Write(datum.Get()[2:])
} }
} else { } else {
return f.BaseSink.Write(index, data, ts, keyVideo) return f.BaseSink.Write(index, processedData, ts, keyVideo)
} }
return nil return nil
} }
// Close 关闭国标转发流 // Close 关闭转发流
func (f *ForwardSink) Close() { func (f *ForwardSink) Close() {
f.BaseSink.Close() f.BaseSink.Close()
@@ -89,6 +144,10 @@ func (f *ForwardSink) Close() {
if f.receiveTimer != nil { if f.receiveTimer != nil {
f.receiveTimer.Stop() f.receiveTimer.Stop()
} }
if f.rtpBuffer != nil {
f.rtpBuffer.Clear()
}
} }
// StartReceiveTimer 启动tcp sever计时器, 如果计时器触发, 没有连接, 则关闭流 // StartReceiveTimer 启动tcp sever计时器, 如果计时器触发, 没有连接, 则关闭流
@@ -101,10 +160,16 @@ func (f *ForwardSink) StartReceiveTimer() {
}) })
} }
func NewForwardSink(transportType TransportType, protocol TransStreamProtocol, sinkId SinkID, sourceId string, addr string, manager transport.Manager) (*ForwardSink, int, error) { func (f *ForwardSink) GetSSRC() uint32 {
return f.ssrc
}
func NewForwardSink(transportType TransportType, protocol TransStreamProtocol, sinkId SinkID, sourceId string, addr string, manager transport.Manager, ssrc uint32) (*ForwardSink, int, error) {
sink := &ForwardSink{ sink := &ForwardSink{
BaseSink: BaseSink{ID: sinkId, SourceID: sourceId, State: SessionStateCreated, Protocol: protocol}, BaseSink: BaseSink{ID: sinkId, SourceID: sourceId, State: SessionStateCreated, Protocol: protocol},
transportType: transportType, transportType: transportType,
ssrc: ssrc,
requireSSRCMatch: true, // 默认要求ssrc一致
} }
if transportType == TransportTypeUDP { if transportType == TransportTypeUDP {

39
stream/rtp_buffer.go Normal file
View File

@@ -0,0 +1,39 @@
package stream
import (
"github.com/lkmio/avformat/collections"
)
type RtpBuffer struct {
queue *collections.Queue[*collections.ReferenceCounter[[]byte]]
}
func (r *RtpBuffer) Get() *collections.ReferenceCounter[[]byte] {
if r.queue.Size() > 0 {
rtp := r.queue.Peek(0)
if rtp.UseCount() < 2 {
bytes := rtp.Get()
rtp.ResetData(bytes[:cap(bytes)])
return rtp
}
}
bytes := collections.NewReferenceCounter(UDPReceiveBufferPool.Get().([]byte))
r.queue.Push(bytes)
return bytes
}
func (r *RtpBuffer) Clear() {
for r.queue.Size() > 0 {
if r.queue.Peek(0).UseCount() > 1 {
break
}
bytes := r.queue.Pop().Get()
UDPReceiveBufferPool.Put(bytes[:cap(bytes)])
}
}
func NewRtpBuffer(capacity int) *RtpBuffer {
return &RtpBuffer{queue: collections.NewQueue[*collections.ReferenceCounter[[]byte]](capacity)}
}

View File

@@ -9,7 +9,7 @@ import (
type RtpStream struct { type RtpStream struct {
BaseTransStream BaseTransStream
rtpBuffers *collections.Queue[*collections.ReferenceCounter[[]byte]] rtpBuffer *RtpBuffer
} }
func (f *RtpStream) WriteHeader() error { func (f *RtpStream) WriteHeader() error {
@@ -23,38 +23,23 @@ func (f *RtpStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
return nil, 0, false, nil return nil, 0, false, nil
} }
// 释放rtp包 counter := f.rtpBuffer.Get()
for f.rtpBuffers.Size() > 0 { bytes := counter.Get()
rtp := f.rtpBuffers.Peek(0)
if rtp.UseCount() > 1 {
break
}
f.rtpBuffers.Pop()
// 放回池中
data := rtp.Get()
UDPReceiveBufferPool.Put(data[:cap(data)])
}
bytes := UDPReceiveBufferPool.Get().([]byte)
binary.BigEndian.PutUint16(bytes, size-2) binary.BigEndian.PutUint16(bytes, size-2)
copy(bytes[2:], packet.Data) copy(bytes[2:], packet.Data)
counter.ResetData(bytes)
rtp := collections.NewReferenceCounter(bytes[:size])
f.rtpBuffers.Push(rtp)
// 每帧都当关键帧, 直接发给上级 // 每帧都当关键帧, 直接发给上级
return []*collections.ReferenceCounter[[]byte]{rtp}, -1, true, nil return []*collections.ReferenceCounter[[]byte]{counter}, -1, true, nil
} }
func NewRtpTransStream(protocol TransStreamProtocol, capacity int) *RtpStream { func NewRtpTransStream(protocol TransStreamProtocol, capacity int) *RtpStream {
return &RtpStream{ return &RtpStream{
BaseTransStream: BaseTransStream{Protocol: protocol}, BaseTransStream: BaseTransStream{Protocol: protocol},
rtpBuffers: collections.NewQueue[*collections.ReferenceCounter[[]byte]](capacity), rtpBuffer: NewRtpBuffer(capacity),
} }
} }
func GBCascadedTransStreamFactory(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) { func GBCascadedTransStreamFactory(_ Source, _ TransStreamProtocol, _ []*Track, _ Sink) (TransStream, error) {
return NewRtpTransStream(TransStreamGBCascaded, 1024), nil return NewRtpTransStream(TransStreamGBCascaded, 1024), nil
} }

View File

@@ -83,7 +83,7 @@ func SubscribeStreamWithOptions(sink Sink, values url.Values, ready bool, timeou
return state return state
} }
func ForwardStream(protocol TransStreamProtocol, transport TransportType, sourceId string, values url.Values, remoteAddr string, manager transport.Manager) (Sink, int, error) { func ForwardStream(protocol TransStreamProtocol, transport TransportType, sourceId string, values url.Values, remoteAddr string, manager transport.Manager, ssrc uint32) (Sink, int, error) {
//source := SourceManager.Find(sourceId) //source := SourceManager.Find(sourceId)
//if source == nil { //if source == nil {
// return nil, 0, fmt.Errorf("source %s 不存在", sourceId) // return nil, 0, fmt.Errorf("source %s 不存在", sourceId)
@@ -91,7 +91,7 @@ func ForwardStream(protocol TransStreamProtocol, transport TransportType, source
sinkId := GenerateUint64SinkID() sinkId := GenerateUint64SinkID()
var port int var port int
sink, port, err := NewForwardSink(transport, protocol, sinkId, sourceId, remoteAddr, manager) sink, port, err := NewForwardSink(transport, protocol, sinkId, sourceId, remoteAddr, manager, ssrc)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
) )
type TransStreamFactory func(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) type TransStreamFactory func(source Source, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error)
type RecordStreamFactory func(source string) (Sink, string, error) type RecordStreamFactory func(source string) (Sink, string, error)
@@ -35,13 +35,13 @@ func FindTransStreamFactory(protocol TransStreamProtocol) (TransStreamFactory, e
return f, nil return f, nil
} }
func CreateTransStream(source Source, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) { func CreateTransStream(source Source, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error) {
factory, err := FindTransStreamFactory(protocol) factory, err := FindTransStreamFactory(protocol)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return factory(source, protocol, tracks) return factory(source, protocol, tracks, sink)
} }
func SetRecordStreamFactory(factory RecordStreamFactory) { func SetRecordStreamFactory(factory RecordStreamFactory) {

View File

@@ -187,7 +187,7 @@ func (t *transStreamPublisher) CreateDefaultOutStreams() {
utils.Assert(len(streams) > 0) utils.Assert(len(streams) > 0)
id := GenerateTransStreamID(TransStreamHls, streams...) id := GenerateTransStreamID(TransStreamHls, streams...)
hlsStream, err := t.CreateTransStream(id, TransStreamHls, streams) hlsStream, err := t.CreateTransStream(id, TransStreamHls, streams, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -206,12 +206,12 @@ func IsSupportMux(protocol TransStreamProtocol, _, _ utils.AVCodecID) bool {
return true return true
} }
func (t *transStreamPublisher) CreateTransStream(id TransStreamID, protocol TransStreamProtocol, tracks []*Track) (TransStream, error) { func (t *transStreamPublisher) CreateTransStream(id TransStreamID, protocol TransStreamProtocol, tracks []*Track, sink Sink) (TransStream, error) {
log.Sugar.Infof("创建%s-stream source: %s", protocol.String(), t.source) log.Sugar.Infof("创建%s-stream source: %s", protocol.String(), t.source)
source := SourceManager.Find(t.source) source := SourceManager.Find(t.source)
utils.Assert(source != nil) utils.Assert(source != nil)
transStream, err := CreateTransStream(source, protocol, tracks) transStream, err := CreateTransStream(source, protocol, tracks, sink)
if err != nil { if err != nil {
log.Sugar.Errorf("创建传输流失败 err: %s source: %s", err.Error(), t.source) log.Sugar.Errorf("创建传输流失败 err: %s source: %s", err.Error(), t.source)
return nil, err return nil, err
@@ -356,7 +356,7 @@ func (t *transStreamPublisher) doAddSink(sink Sink, resume bool) bool {
transStream, exist := t.transStreams[transStreamId] transStream, exist := t.transStreams[transStreamId]
if !exist { if !exist {
var err error var err error
transStream, err = t.CreateTransStream(transStreamId, sink.GetProtocol(), tracks) transStream, err = t.CreateTransStream(transStreamId, sink.GetProtocol(), tracks, sink)
if err != nil { if err != nil {
log.Sugar.Errorf("添加sink失败,创建传输流发生err: %s source: %s", err.Error(), t.source) log.Sugar.Errorf("添加sink失败,创建传输流发生err: %s source: %s", err.Error(), t.source)
return false return false