refactor: 解析音视频帧不再单独占用一个协程,直接在网络收流协程完成;

This commit is contained in:
ydajiang
2025-06-07 17:32:59 +08:00
parent fd718ffec2
commit 3553a1b582
17 changed files with 206 additions and 269 deletions

View File

@@ -215,7 +215,7 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) {
id := r.FormValue("source") id := r.FormValue("source")
talkSource := gb28181.NewTalkSource(id, conn) talkSource := gb28181.NewTalkSource(id, conn)
talkSource.Init(stream.TCPReceiveBufferQueueSize) talkSource.Init()
talkSource.SetUrlValues(r.Form) talkSource.SetUrlValues(r.Form)
_, state := stream.PreparePublishSource(talkSource, true) _, state := stream.PreparePublishSource(talkSource, true)
@@ -227,7 +227,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) {
log.Sugar.Infof("ws对讲连接成功, source: %s", talkSource) log.Sugar.Infof("ws对讲连接成功, source: %s", talkSource)
go stream.LoopEvent(talkSource) stream.LoopEvent(talkSource)
data := stream.UDPReceiveBufferPool.Get().([]byte)
for { for {
_, bytes, err := conn.ReadMessage() _, bytes, err := conn.ReadMessage()
@@ -240,10 +242,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) {
} }
for i := 0; i < length; { for i := 0; i < length; {
data := stream.UDPReceiveBufferPool.Get().([]byte)
n := bufio.MinInt(stream.UDPReceiveBufferSize, length-i) n := bufio.MinInt(stream.UDPReceiveBufferSize, length-i)
copy(data, bytes[:n]) copy(data, bytes[:n])
_ = talkSource.PublishSource.Input(data[:n]) _, _ = talkSource.PublishSource.Input(data[:n])
i += n i += n
} }
} }

View File

@@ -6,6 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/lkmio/avformat" "github.com/lkmio/avformat"
"github.com/lkmio/avformat/bufio"
"github.com/lkmio/avformat/utils" "github.com/lkmio/avformat/utils"
"github.com/lkmio/mpeg" "github.com/lkmio/mpeg"
"github.com/lkmio/transport" "github.com/lkmio/transport"
@@ -69,8 +70,7 @@ func createSource(source, setup string, ssrc uint32) (string, uint16, uint32) {
panic(err) panic(err)
} }
//request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/source/create", bytes.NewBuffer(marshal)) request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/source/create", bytes.NewBuffer(marshal))
request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/offer/create", bytes.NewBuffer(marshal))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -316,3 +316,33 @@ func TestPublish(t *testing.T) {
connectSource(id, fmt.Sprintf("%s:%d", ip, port)) connectSource(id, fmt.Sprintf("%s:%d", ip, port))
}) })
} }
func TestDecode(t *testing.T) {
t.Run("decode_raw", func(t *testing.T) {
file, err2 := os.ReadFile("../dump/gb28181-192.168.2.103.37841")
if err2 != nil {
panic(err2)
}
filter := NewSingleFilter(NewPassiveSource())
session := NewTCPSession(nil, filter)
reader := bufio.NewBytesReader(file)
for {
size, err2 := reader.ReadUint32()
if err2 != nil {
break
}
bytes, err2 := reader.ReadBytes(int(size))
if err2 != nil {
break
}
err2 = session.DecodeGBRTPOverTCPPacket(bytes, filter, nil)
if err2 != nil {
break
}
}
})
}

View File

@@ -83,15 +83,16 @@ type GBSource interface {
SetSSRC(ssrc uint32) SetSSRC(ssrc uint32)
SSRC() uint32 SSRC() uint32
ProcessPacket(data []byte) error
} }
type BaseGBSource struct { type BaseGBSource struct {
stream.PublishSource stream.PublishSource
transport transport.Transport
probeBuffer *mpeg.PSProbeBuffer probeBuffer *mpeg.PSProbeBuffer
ssrc uint32
ssrc uint32
transport transport.Transport
audioTimestamp int64 audioTimestamp int64
videoTimestamp int64 videoTimestamp int64
@@ -102,7 +103,7 @@ type BaseGBSource struct {
sameTimePackets [][]byte sameTimePackets [][]byte
} }
func (source *BaseGBSource) Init(receiveQueueSize int) { func (source *BaseGBSource) Init() {
source.TransDemuxer = mpeg.NewPSDemuxer(false) source.TransDemuxer = mpeg.NewPSDemuxer(false)
source.TransDemuxer.SetHandler(source) source.TransDemuxer.SetHandler(source)
source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) { source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) {
@@ -110,12 +111,12 @@ func (source *BaseGBSource) Init(receiveQueueSize int) {
}) })
source.SetType(stream.SourceType28181) source.SetType(stream.SourceType28181)
source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize) source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize)
source.PublishSource.Init(receiveQueueSize) source.PublishSource.Init()
source.lastRtpTimestamp = -1 source.lastRtpTimestamp = -1
} }
// Input 输入rtp包, 处理PS流, 负责解析->封装->推流 // ProcessPacket 输入rtp包, 处理PS流, 负责解析->封装->推流
func (source *BaseGBSource) Input(data []byte) error { func (source *BaseGBSource) ProcessPacket(data []byte) error {
packet := rtp.Packet{} packet := rtp.Packet{}
_ = packet.Unmarshal(data) _ = packet.Unmarshal(data)
@@ -150,7 +151,7 @@ func (source *BaseGBSource) Input(data []byte) error {
var err error var err error
bytes, err = source.probeBuffer.Input(packet.Payload) bytes, err = source.probeBuffer.Input(packet.Payload)
if err == nil { if err == nil {
n, err = source.TransDemuxer.Input(bytes) n, err = source.PublishSource.Input(bytes)
} }
// 非解析缓冲区满的错误, 继续解析 // 非解析缓冲区满的错误, 继续解析
@@ -347,20 +348,13 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int,
} }
} }
var queueSize int
if active || tcp {
queueSize = stream.TCPReceiveBufferQueueSize
} else {
queueSize = stream.UDPReceiveBufferQueueSize
}
source.SetID(id) source.SetID(id)
source.SetSSRC(ssrc) source.SetSSRC(ssrc)
source.Init(queueSize) source.Init()
if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state { if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state {
return nil, 0, fmt.Errorf("error code %d", state) return nil, 0, fmt.Errorf("error code %d", state)
} }
go stream.LoopEvent(source) stream.LoopEvent(source)
return source, port, err return source, port, err
} }

View File

@@ -1,7 +1,6 @@
package gb28181 package gb28181
import ( import (
"github.com/lkmio/transport"
"net" "net"
) )
@@ -35,9 +34,6 @@ func NewActiveSource() (*ActiveSource, int, error) {
}) })
return &ActiveSource{ return &ActiveSource{
PassiveSource: PassiveSource{
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
},
port: port, port: port,
}, port, nil }, port, nil
} }

View File

@@ -1,16 +1,7 @@
package gb28181 package gb28181
import "github.com/lkmio/transport"
type PassiveSource struct { type PassiveSource struct {
BaseGBSource BaseGBSource
decoder *transport.LengthFieldFrameDecoder
}
// Input 重写stream.Source的Input函数, 主协程把推流数据交给PassiveSource处理
func (p *PassiveSource) Input(data []byte) error {
_, err := DecodeGBRTPOverTCPPacket(data, p, p.decoder, nil, p.Conn)
return err
} }
func (p *PassiveSource) SetupType() SetupType { func (p *PassiveSource) SetupType() SetupType {
@@ -18,7 +9,5 @@ func (p *PassiveSource) SetupType() SetupType {
} }
func NewPassiveSource() *PassiveSource { func NewPassiveSource() *PassiveSource {
return &PassiveSource{ return &PassiveSource{}
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
}
} }

View File

@@ -19,7 +19,8 @@ func (u *UDPSource) SetupType() SetupType {
// OnOrderedRtp 有序RTP包回调 // OnOrderedRtp 有序RTP包回调
func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) { func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) {
// 此时还在网络收流携程, 交给Source的主协程处理 // 此时还在网络收流携程, 交给Source的主协程处理
u.PublishSource.Input(packet.Raw) u.ProcessPacket(packet.Raw)
stream.UDPReceiveBufferPool.Put(packet.Raw[:cap(packet.Raw)])
} }
// InputRtpPacket 将RTP包排序后交给Source的主协程处理 // InputRtpPacket 将RTP包排序后交给Source的主协程处理

View File

@@ -46,11 +46,6 @@ type TalkSource struct {
stream.PublishSource stream.PublishSource
} }
func (s *TalkSource) Input(data []byte) error {
_, err := s.PublishSource.TransDemuxer.Input(data)
return err
}
func (s *TalkSource) Close() { func (s *TalkSource) Close() {
s.PublishSource.Close() s.PublishSource.Close()
// 关闭所有对讲设备的会话 // 关闭所有对讲设备的会话

View File

@@ -35,35 +35,21 @@ func (T *TCPServer) OnCloseSession(session *TCPSession) {
func (T *TCPServer) OnConnected(conn net.Conn) []byte { func (T *TCPServer) OnConnected(conn net.Conn) []byte {
T.StreamServer.OnConnected(conn) T.StreamServer.OnConnected(conn)
return stream.TCPReceiveBufferPool.Get().([]byte) return conn.(*transport.Conn).Data.(*TCPSession).receiveBuffer
} }
func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte {
T.StreamServer.OnPacket(conn, data) T.StreamServer.OnPacket(conn, data)
session := conn.(*transport.Conn).Data.(*TCPSession) session := conn.(*transport.Conn).Data.(*TCPSession)
// 单端口推流时, 先解析出SSRC找到GBSource. 后序将推流数据交给stream.Source处理 err := session.DecodeGBRTPOverTCPPacket(data, T.filter, conn)
if session.source == nil { if err != nil {
source, err := DecodeGBRTPOverTCPPacket(data, nil, session.decoder, T.filter, conn) log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data))
if err != nil { _ = conn.Close()
log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data)) return nil
_ = conn.Close()
return nil
}
if source != nil {
session.Init(source)
}
} else {
// 将流交给Source的主协程处理主协程最终会调用PassiveSource的Input函数处理
if session.source.SetupType() == SetupPassive {
session.source.(*PassiveSource).PublishSource.Input(data)
} else {
session.source.(*ActiveSource).PublishSource.Input(data)
}
} }
return stream.TCPReceiveBufferPool.Get().([]byte) return session.receiveBuffer
} }
func NewTCPServer(filter Filter) (*TCPServer, error) { func NewTCPServer(filter Filter) (*TCPServer, error) {

View File

@@ -10,9 +10,10 @@ import (
// TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包. // TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包.
type TCPSession struct { type TCPSession struct {
conn net.Conn conn net.Conn
source GBSource source GBSource
decoder *transport.LengthFieldFrameDecoder decoder *transport.LengthFieldFrameDecoder
receiveBuffer []byte
} }
func (t *TCPSession) Init(source GBSource) { func (t *TCPSession) Init(source GBSource) {
@@ -25,14 +26,17 @@ func (t *TCPSession) Close() {
t.source.Close() t.source.Close()
t.source = nil t.source = nil
} }
stream.TCPReceiveBufferPool.Put(t.receiveBuffer[:cap(t.receiveBuffer)])
} }
func DecodeGBRTPOverTCPPacket(data []byte, source GBSource, decoder *transport.LengthFieldFrameDecoder, filter Filter, conn net.Conn) (GBSource, error) { func (t *TCPSession) DecodeGBRTPOverTCPPacket(data []byte, filter Filter, conn net.Conn) error {
length := len(data) length := len(data)
for i := 0; i < length; { for i := 0; i < length; {
n, bytes, err := decoder.Input(data[i:]) // 解析粘包数据
n, bytes, err := t.decoder.Input(data[i:])
if err != nil { if err != nil {
return source, err return err
} }
i += n i += n
@@ -41,40 +45,38 @@ func DecodeGBRTPOverTCPPacket(data []byte, source GBSource, decoder *transport.L
} }
// 单端口模式,ssrc匹配source // 单端口模式,ssrc匹配source
if source == nil || stream.SessionStateHandshakeSuccess == source.State() { if t.source == nil || stream.SessionStateHandshakeSuccess == t.source.State() {
packet := rtp.Packet{} packet := rtp.Packet{}
if err := packet.Unmarshal(bytes); err != nil { if err = packet.Unmarshal(bytes); err != nil {
return nil, err return err
} else if source == nil { } else if t.source == nil {
source = filter.FindSource(packet.SSRC) t.source = filter.FindSource(packet.SSRC)
} }
if source == nil { if t.source == nil {
// ssrc 匹配不到Source // ssrc 匹配不到Source
return nil, fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC) return fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC)
} }
if stream.SessionStateHandshakeSuccess == source.State() { if stream.SessionStateHandshakeSuccess == t.source.State() {
source.PreparePublish(conn, packet.SSRC, source) t.source.PreparePublish(conn, packet.SSRC, t.source)
} }
} }
// 如果是单端口推流, 并且刚才与source绑定, 此时正位于网络收流协程, 否则都位于主协程 if err = t.source.ProcessPacket(bytes); err != nil {
if source.SetupType() == SetupPassive { return err
source.(*PassiveSource).BaseGBSource.Input(bytes)
} else {
source.(*ActiveSource).BaseGBSource.Input(bytes)
} }
} }
return source, nil return nil
} }
func NewTCPSession(conn net.Conn, filter Filter) *TCPSession { func NewTCPSession(conn net.Conn, filter Filter) *TCPSession {
session := &TCPSession{ session := &TCPSession{
conn: conn, conn: conn,
// filter: filter, // filter: filter,
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte),
} }
// 多端口收流, Source已知, 直接初始化Session // 多端口收流, Source已知, 直接初始化Session

View File

@@ -16,18 +16,11 @@ type UDPServer struct {
filter Filter filter Filter
} }
func (U *UDPServer) OnNewSession(conn net.Conn) *UDPSource { func (U *UDPServer) OnNewSession(_ net.Conn) *UDPSource {
return nil return nil
} }
func (U *UDPServer) OnCloseSession(session *UDPSource) { func (U *UDPServer) OnCloseSession(_ *UDPSource) {
U.filter.RemoveSource(session.SSRC())
session.Close()
if stream.AppConfig.GB28181.IsMultiPort() {
U.udp.Close()
U.Handler = nil
}
} }
func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte { func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte {
@@ -52,7 +45,7 @@ func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte {
} }
packet.Raw = data packet.Raw = data
source.(*UDPSource).InputRtpPacket(&packet) _ = source.(*UDPSource).InputRtpPacket(&packet)
return nil return nil
} }

View File

@@ -30,8 +30,8 @@ func (s *jtServer) OnCloseSession(session *Session) {
func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte { func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte {
s.StreamServer.OnPacket(conn, data) s.StreamServer.OnPacket(conn, data)
session := conn.(*transport.Conn).Data.(*Session) session := conn.(*transport.Conn).Data.(*Session)
session.PublishSource.Input(data) _, _ = session.Input(data)
return stream.TCPReceiveBufferPool.Get().([]byte) return session.receiveBuffer
} }
func (s *jtServer) Start(addr net.Addr) error { func (s *jtServer) Start(addr net.Addr) error {

View File

@@ -11,15 +11,16 @@ import (
type Session struct { type Session struct {
stream.PublishSource stream.PublishSource
decoder *transport.DelimiterFrameDecoder decoder *transport.DelimiterFrameDecoder
receiveBuffer []byte
} }
func (s *Session) Input(data []byte) error { func (s *Session) Input(data []byte) (int, error) {
var n int var n int
for length := len(data); n < length; { for length := len(data); n < length; {
i, bytes, err := s.decoder.Input(data[n:]) i, bytes, err := s.decoder.Input(data[n:])
if err != nil { if err != nil {
return err return -1, err
} else if len(bytes) < 1 { } else if len(bytes) < 1 {
break break
} }
@@ -27,9 +28,9 @@ func (s *Session) Input(data []byte) error {
n += i n += i
demuxer := s.TransDemuxer.(*Demuxer) demuxer := s.TransDemuxer.(*Demuxer)
firstOfPacket := demuxer.prevPacket == nil firstOfPacket := demuxer.prevPacket == nil
_, err = demuxer.Input(bytes) _, err = s.PublishSource.Input(bytes)
if err != nil { if err != nil {
return err return -1, err
} }
// 首包处理, hook通知 // 首包处理, hook通知
@@ -49,7 +50,7 @@ func (s *Session) Input(data []byte) error {
} }
} }
return nil return 0, nil
} }
func (s *Session) Close() { func (s *Session) Close() {
@@ -61,6 +62,7 @@ func (s *Session) Close() {
} }
s.PublishSource.Close() s.PublishSource.Close()
stream.TCPReceiveBufferPool.Put(s.receiveBuffer[:cap(s.receiveBuffer)])
} }
func NewSession(conn net.Conn) *Session { func NewSession(conn net.Conn) *Session {
@@ -72,11 +74,12 @@ func NewSession(conn net.Conn) *Session {
TransDemuxer: NewDemuxer(), TransDemuxer: NewDemuxer(),
}, },
decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]), decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]),
receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte),
} }
session.TransDemuxer.SetHandler(&session) session.TransDemuxer.SetHandler(&session)
session.Init(stream.TCPReceiveBufferQueueSize) session.Init()
go stream.LoopEvent(&session) stream.LoopEvent(&session)
return &session return &session
} }

View File

@@ -12,10 +12,6 @@ type Publisher struct {
Stack *rtmp.ServerStack Stack *rtmp.ServerStack
} }
func (p *Publisher) Input(data []byte) error {
return p.Stack.Input(p.Conn, data)
}
func (p *Publisher) Close() { func (p *Publisher) Close() {
p.PublishSource.Close() p.PublishSource.Close()
p.Stack = nil p.Stack = nil

View File

@@ -35,7 +35,7 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState {
source := NewPublisher(sourceId, s.stack, s.conn) source := NewPublisher(sourceId, s.stack, s.conn)
// 初始化放在add source前面, 以防add后再init, 空窗期拉流队列空指针. // 初始化放在add source前面, 以防add后再init, 空窗期拉流队列空指针.
source.Init(stream.TCPReceiveBufferQueueSize) source.Init()
source.SetUrlValues(values) source.SetUrlValues(values)
// 统一处理source推流事件, source是否已经存在, hook回调.... // 统一处理source推流事件, source是否已经存在, hook回调....
@@ -46,7 +46,7 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState {
s.handle = source s.handle = source
s.isPublisher = true s.isPublisher = true
go stream.LoopEvent(source) stream.LoopEvent(source)
} }
return state return state
@@ -73,7 +73,14 @@ func (s *Session) OnPlay(app, stream_ string) utils.HookState {
func (s *Session) Input(data []byte) error { func (s *Session) Input(data []byte) error {
// 推流会话, 收到的包都将交由主协程处理 // 推流会话, 收到的包都将交由主协程处理
if s.isPublisher { if s.isPublisher {
return s.handle.(*Publisher).PublishSource.Input(data) s.handle.(*Publisher).UpdateReceiveStats(len(data))
var err error
s.handle.(*Publisher).ExecuteSyncEvent(func() {
err = s.stack.Input(s.conn, data)
})
return err
} else { } else {
return s.stack.Input(s.conn, data) return s.stack.Input(s.conn, data)
} }

View File

@@ -27,7 +27,7 @@ func (f *RtpStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo
bytes := counter.Get() bytes := counter.Get()
binary.BigEndian.PutUint16(bytes, size-2) binary.BigEndian.PutUint16(bytes, size-2)
copy(bytes[2:], packet.Data) copy(bytes[2:], packet.Data)
counter.ResetData(bytes) counter.ResetData(bytes[:2+len(bytes)])
// 每帧都当关键帧, 直接发给上级 // 每帧都当关键帧, 直接发给上级
return []*collections.ReferenceCounter[[]byte]{counter}, -1, true, nil return []*collections.ReferenceCounter[[]byte]{counter}, -1, true, nil

View File

@@ -27,7 +27,7 @@ type Source interface {
SetID(id string) SetID(id string)
// Input 输入推流数据 // Input 输入推流数据
Input(data []byte) error Input(data []byte) (int, error)
// GetType 返回推流类型 // GetType 返回推流类型
GetType() SourceType GetType() SourceType
@@ -47,7 +47,7 @@ type Source interface {
// IsCompleted 所有推流track是否解析完毕 // IsCompleted 所有推流track是否解析完毕
IsCompleted() bool IsCompleted() bool
Init(receiveQueueSize int) Init()
RemoteAddr() string RemoteAddr() string
@@ -61,11 +61,6 @@ type Source interface {
// SetUrlValues 设置推流url参数 // SetUrlValues 设置推流url参数
SetUrlValues(values url.Values) SetUrlValues(values url.Values)
// PostEvent 切换到主协程执行当前函数
postEvent(cb func())
executeSyncEvent(cb func())
// LastPacketTime 返回最近收流时间戳 // LastPacketTime 返回最近收流时间戳
LastPacketTime() time.Time LastPacketTime() time.Time
@@ -73,10 +68,6 @@ type Source interface {
IsClosed() bool IsClosed() bool
StreamPipe() chan []byte
MainContextEvents() chan func()
CreateTime() time.Time CreateTime() time.Time
SetCreateTime(time time.Time) SetCreateTime(time time.Time)
@@ -86,6 +77,12 @@ type Source interface {
ProbeTimeout() ProbeTimeout()
GetTransStreamPublisher() TransStreamPublisher GetTransStreamPublisher() TransStreamPublisher
StartTimers(source Source)
ExecuteSyncEvent(cb func())
UpdateReceiveStats(dataLen int)
} }
type PublishSource struct { type PublishSource struct {
@@ -94,9 +91,7 @@ type PublishSource struct {
state SessionState state SessionState
Conn net.Conn Conn net.Conn
streamPipe *NonBlockingChannel[[]byte] // 推流数据管道 streamPublisher TransStreamPublisher // 解析出来的AVStream和AVPacket, 交由streamPublisher处理
mainContextEvents chan func() // 切换到主协程执行函数的事件管道
streamPublisher TransStreamPublisher // 解析出来的AVStream和AVPacket, 交由streamPublisher处理
TransDemuxer avformat.Demuxer // 负责从推流协议中解析出AVStream和AVPacket TransDemuxer avformat.Demuxer // 负责从推流协议中解析出AVStream和AVPacket
originTracks TrackManager // 推流的音视频Streams originTracks TrackManager // 推流的音视频Streams
@@ -110,6 +105,14 @@ type PublishSource struct {
createTime time.Time // source创建时间 createTime time.Time // source创建时间
statistics *BitrateStatistics // 码流统计 statistics *BitrateStatistics // 码流统计
streamLogger avformat.OnUnpackStream2FileHandler streamLogger avformat.OnUnpackStream2FileHandler
// streamLock sync.RWMutex
streamLock sync.Mutex
timers struct {
receiveTimer *time.Timer // 收流超时计时器
idleTimer *time.Timer // 空闲超时计时器
probeTimer *time.Timer // tack探测超时计时器
}
} }
func (s *PublishSource) SetLastPacketTime(time2 time.Time) { func (s *PublishSource) SetLastPacketTime(time2 time.Time) {
@@ -120,14 +123,6 @@ func (s *PublishSource) IsClosed() bool {
return s.closed.Load() return s.closed.Load()
} }
func (s *PublishSource) StreamPipe() chan []byte {
return s.streamPipe.Channel
}
func (s *PublishSource) MainContextEvents() chan func() {
return s.mainContextEvents
}
func (s *PublishSource) LastPacketTime() time.Time { func (s *PublishSource) LastPacketTime() time.Time {
return s.lastPacketTime return s.lastPacketTime
} }
@@ -143,23 +138,35 @@ func (s *PublishSource) SetID(id string) {
} }
} }
func (s *PublishSource) Init(receiveQueueSize int) { func (s *PublishSource) Init() {
s.SetState(SessionStateHandshakeSuccess) s.SetState(SessionStateHandshakeSuccess)
// 初始化事件接收管道
// -2是为了保证从管道取到流, 到处理完流整个过程安全的, 不会被覆盖
s.streamPipe = NewNonBlockingChannel[[]byte](receiveQueueSize - 1)
s.mainContextEvents = make(chan func(), 128)
s.statistics = NewBitrateStatistics() s.statistics = NewBitrateStatistics()
s.streamPublisher = NewTransStreamPublisher(s.ID) s.streamPublisher = NewTransStreamPublisher(s.ID)
// 设置探测时长 // 设置探测时长
s.TransDemuxer.SetProbeDuration(AppConfig.ProbeTimeout) s.TransDemuxer.SetProbeDuration(AppConfig.ProbeTimeout)
} }
func (s *PublishSource) Input(data []byte) error { func (s *PublishSource) UpdateReceiveStats(dataLen int) {
s.streamPipe.Post(data) s.statistics.Input(dataLen)
s.statistics.Input(len(data)) if AppConfig.ReceiveTimeout > 0 {
return nil s.SetLastPacketTime(time.Now())
}
}
func (s *PublishSource) Input(data []byte) (int, error) {
s.UpdateReceiveStats(len(data))
var n int
var err error
s.ExecuteSyncEvent(func() {
if s.closed.Load() {
err = fmt.Errorf("source closed")
} else {
n, err = s.TransDemuxer.Input(data)
}
})
return n, err
} }
func (s *PublishSource) OriginTracks() []*Track { func (s *PublishSource) OriginTracks() []*Track {
@@ -177,12 +184,31 @@ func (s *PublishSource) DoClose() {
return return
} }
s.closed.Store(true) var closed bool
s.ExecuteSyncEvent(func() {
closed = s.closed.Swap(true)
})
if closed {
return
}
// 关闭各种超时计时器
if s.timers.receiveTimer != nil {
s.timers.receiveTimer.Stop()
}
if s.timers.idleTimer != nil {
s.timers.idleTimer.Stop()
}
if s.timers.probeTimer != nil {
s.timers.probeTimer.Stop()
}
// 关闭推流源的解复用器, 不再接收数据 // 关闭推流源的解复用器, 不再接收数据
if s.TransDemuxer != nil { if s.TransDemuxer != nil {
s.TransDemuxer.Close() s.TransDemuxer.Close()
s.TransDemuxer = nil
} }
// 等传输流发布器关闭结束 // 等传输流发布器关闭结束
@@ -210,14 +236,7 @@ func (s *PublishSource) DoClose() {
} }
func (s *PublishSource) Close() { func (s *PublishSource) Close() {
if s.closed.Load() { s.DoClose()
return
}
// 同步执行, 确保close后, 主协程已经退出, 不会再处理任何推拉流、查询等任何事情.
s.executeSyncEvent(func() {
s.DoClose()
})
} }
// 解析完所有track后, 创建各种输出流 // 解析完所有track后, 创建各种输出流
@@ -233,7 +252,8 @@ func (s *PublishSource) writeHeader() {
if len(s.originTracks.All()) == 0 { if len(s.originTracks.All()) == 0 {
log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID) log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID)
s.DoClose() // 异步执行ProbeTimeout函数中还没释放锁
go s.DoClose()
return return
} }
} }
@@ -356,20 +376,11 @@ func (s *PublishSource) SetUrlValues(values url.Values) {
s.urlValues = values s.urlValues = values
} }
func (s *PublishSource) postEvent(cb func()) { func (s *PublishSource) ExecuteSyncEvent(cb func()) {
s.mainContextEvents <- cb // 无竞争情况下, 接近原子操作
} s.streamLock.Lock()
defer s.streamLock.Unlock()
func (s *PublishSource) executeSyncEvent(cb func()) { cb()
group := sync.WaitGroup{}
group.Add(1)
s.postEvent(func() {
cb()
group.Done()
})
group.Wait()
} }
func (s *PublishSource) CreateTime() time.Time { func (s *PublishSource) CreateTime() time.Time {
@@ -386,10 +397,37 @@ func (s *PublishSource) GetBitrateStatistics() *BitrateStatistics {
func (s *PublishSource) ProbeTimeout() { func (s *PublishSource) ProbeTimeout() {
if s.TransDemuxer != nil { if s.TransDemuxer != nil {
s.TransDemuxer.ProbeComplete() s.ExecuteSyncEvent(func() {
if !s.closed.Load() {
s.TransDemuxer.ProbeComplete()
}
})
} }
} }
func (s *PublishSource) GetTransStreamPublisher() TransStreamPublisher { func (s *PublishSource) GetTransStreamPublisher() TransStreamPublisher {
return s.streamPublisher return s.streamPublisher
} }
func (s *PublishSource) StartTimers(source Source) {
// 开启收流超时计时器
if AppConfig.ReceiveTimeout > 0 {
s.timers.receiveTimer = StartReceiveDataTimer(source)
}
// 开启拉流空闲超时计时器
if AppConfig.Hooks.IsEnableOnIdleTimeout() && AppConfig.IdleTimeout > 0 {
s.timers.idleTimer = StartIdleTimer(source)
}
// 开启探测超时计时器
s.timers.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() {
if source.IsCompleted() {
return
}
source.ProbeTimeout()
})
}

View File

@@ -197,100 +197,6 @@ func StartIdleTimer(source Source) *time.Timer {
// LoopEvent 循环读取事件 // LoopEvent 循环读取事件
func LoopEvent(source Source) { func LoopEvent(source Source) {
// 将超时计时器放在此处开启, 方便在退出的时候关闭 source.StartTimers(source)
var receiveTimer *time.Timer
var idleTimer *time.Timer
var probeTimer *time.Timer
defer func() {
log.Sugar.Debugf("主协程执行结束 source: %s", source.GetID())
// 关闭计时器
if receiveTimer != nil {
receiveTimer.Stop()
}
if idleTimer != nil {
idleTimer.Stop()
}
if probeTimer != nil {
probeTimer.Stop()
}
// 未使用的数据, 放回池中
for len(source.StreamPipe()) > 0 {
data := <-source.StreamPipe()
if size := cap(data); size > UDPReceiveBufferSize {
TCPReceiveBufferPool.Put(data[:size])
} else {
UDPReceiveBufferPool.Put(data[:size])
}
}
}()
// 开启收流超时计时器
if AppConfig.ReceiveTimeout > 0 {
receiveTimer = StartReceiveDataTimer(source)
}
// 开启拉流空闲超时计时器
if AppConfig.Hooks.IsEnableOnIdleTimeout() && AppConfig.IdleTimeout > 0 {
idleTimer = StartIdleTimer(source)
}
// 开启探测超时计时器
probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() {
if source.IsCompleted() {
return
}
var ok bool
source.executeSyncEvent(func() {
source.ProbeTimeout()
ok = len(source.OriginTracks()) > 0
})
if !ok {
source.Close()
return
}
})
// 启动协程, 生成发布传输流
go source.GetTransStreamPublisher().run() go source.GetTransStreamPublisher().run()
for {
select {
// 读取推流数据
case data := <-source.StreamPipe():
if AppConfig.ReceiveTimeout > 0 {
source.SetLastPacketTime(time.Now())
}
if err := source.Input(data); err != nil {
log.Sugar.Errorf("解析推流数据发生err: %s 释放source: %s", err.Error(), source.GetID())
go source.Close()
return
}
// 使用后, 放回池中
if size := cap(data); size > UDPReceiveBufferSize {
TCPReceiveBufferPool.Put(data[:size])
} else {
UDPReceiveBufferPool.Put(data[:size])
}
break
// 切换到主协程,执行该函数. 目的是用于无锁化处理推拉流的连接与断开, 推流源断开, 查询推流源信息等事件. 不要做耗时操作, 否则会影响推拉流.
case event := <-source.MainContextEvents():
event()
if source.IsClosed() {
// 处理推流管道剩余的数据?
return
}
break
}
}
} }