From 3553a1b582a59ece303108a15acbe6af3d53fb97 Mon Sep 17 00:00:00 2001 From: ydajiang Date: Sat, 7 Jun 2025 17:32:59 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E8=A7=A3=E6=9E=90=E9=9F=B3?= =?UTF-8?q?=E8=A7=86=E9=A2=91=E5=B8=A7=E4=B8=8D=E5=86=8D=E5=8D=95=E7=8B=AC?= =?UTF-8?q?=E5=8D=A0=E7=94=A8=E4=B8=80=E4=B8=AA=E5=8D=8F=E7=A8=8B=EF=BC=8C?= =?UTF-8?q?=E7=9B=B4=E6=8E=A5=E5=9C=A8=E7=BD=91=E7=BB=9C=E6=94=B6=E6=B5=81?= =?UTF-8?q?=E5=8D=8F=E7=A8=8B=E5=AE=8C=E6=88=90;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_gb.go | 9 ++- gb28181/publish_test.go | 34 ++++++++- gb28181/source.go | 28 +++---- gb28181/source_active.go | 4 - gb28181/source_passive.go | 13 +--- gb28181/source_udp.go | 3 +- gb28181/talk_source.go | 5 -- gb28181/tcp_server.go | 28 ++----- gb28181/tcp_session.go | 46 ++++++------ gb28181/udp_server.go | 13 +--- jt1078/jt_server.go | 4 +- jt1078/jt_session.go | 21 +++--- rtmp/rtmp_publisher.go | 4 - rtmp/rtmp_session.go | 13 +++- stream/rtp_stream.go | 2 +- stream/source.go | 152 ++++++++++++++++++++++++-------------- stream/source_utils.go | 96 +----------------------- 17 files changed, 206 insertions(+), 269 deletions(-) diff --git a/api_gb.go b/api_gb.go index 8e7dbc1..6501304 100644 --- a/api_gb.go +++ b/api_gb.go @@ -215,7 +215,7 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) { id := r.FormValue("source") talkSource := gb28181.NewTalkSource(id, conn) - talkSource.Init(stream.TCPReceiveBufferQueueSize) + talkSource.Init() talkSource.SetUrlValues(r.Form) _, state := stream.PreparePublishSource(talkSource, true) @@ -227,7 +227,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) { log.Sugar.Infof("ws对讲连接成功, source: %s", talkSource) - go stream.LoopEvent(talkSource) + stream.LoopEvent(talkSource) + + data := stream.UDPReceiveBufferPool.Get().([]byte) for { _, bytes, err := conn.ReadMessage() @@ -240,10 +242,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) { } for i := 0; i < length; { - data := stream.UDPReceiveBufferPool.Get().([]byte) n := bufio.MinInt(stream.UDPReceiveBufferSize, length-i) copy(data, bytes[:n]) - _ = talkSource.PublishSource.Input(data[:n]) + _, _ = talkSource.PublishSource.Input(data[:n]) i += n } } diff --git a/gb28181/publish_test.go b/gb28181/publish_test.go index b4f200f..6196c80 100644 --- a/gb28181/publish_test.go +++ b/gb28181/publish_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "github.com/lkmio/avformat" + "github.com/lkmio/avformat/bufio" "github.com/lkmio/avformat/utils" "github.com/lkmio/mpeg" "github.com/lkmio/transport" @@ -69,8 +70,7 @@ func createSource(source, setup string, ssrc uint32) (string, uint16, uint32) { panic(err) } - //request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/source/create", bytes.NewBuffer(marshal)) - request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/offer/create", bytes.NewBuffer(marshal)) + request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/source/create", bytes.NewBuffer(marshal)) if err != nil { panic(err) } @@ -316,3 +316,33 @@ func TestPublish(t *testing.T) { connectSource(id, fmt.Sprintf("%s:%d", ip, port)) }) } + +func TestDecode(t *testing.T) { + t.Run("decode_raw", func(t *testing.T) { + file, err2 := os.ReadFile("../dump/gb28181-192.168.2.103.37841") + if err2 != nil { + panic(err2) + } + + filter := NewSingleFilter(NewPassiveSource()) + session := NewTCPSession(nil, filter) + reader := bufio.NewBytesReader(file) + + for { + size, err2 := reader.ReadUint32() + if err2 != nil { + break + } + + bytes, err2 := reader.ReadBytes(int(size)) + if err2 != nil { + break + } + + err2 = session.DecodeGBRTPOverTCPPacket(bytes, filter, nil) + if err2 != nil { + break + } + } + }) +} diff --git a/gb28181/source.go b/gb28181/source.go index 551e00c..4a9d191 100644 --- a/gb28181/source.go +++ b/gb28181/source.go @@ -83,15 +83,16 @@ type GBSource interface { SetSSRC(ssrc uint32) SSRC() uint32 + + ProcessPacket(data []byte) error } type BaseGBSource struct { stream.PublishSource + transport transport.Transport probeBuffer *mpeg.PSProbeBuffer - - ssrc uint32 - transport transport.Transport + ssrc uint32 audioTimestamp int64 videoTimestamp int64 @@ -102,7 +103,7 @@ type BaseGBSource struct { sameTimePackets [][]byte } -func (source *BaseGBSource) Init(receiveQueueSize int) { +func (source *BaseGBSource) Init() { source.TransDemuxer = mpeg.NewPSDemuxer(false) source.TransDemuxer.SetHandler(source) source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) { @@ -110,12 +111,12 @@ func (source *BaseGBSource) Init(receiveQueueSize int) { }) source.SetType(stream.SourceType28181) source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize) - source.PublishSource.Init(receiveQueueSize) + source.PublishSource.Init() source.lastRtpTimestamp = -1 } -// Input 输入rtp包, 处理PS流, 负责解析->封装->推流 -func (source *BaseGBSource) Input(data []byte) error { +// ProcessPacket 输入rtp包, 处理PS流, 负责解析->封装->推流 +func (source *BaseGBSource) ProcessPacket(data []byte) error { packet := rtp.Packet{} _ = packet.Unmarshal(data) @@ -150,7 +151,7 @@ func (source *BaseGBSource) Input(data []byte) error { var err error bytes, err = source.probeBuffer.Input(packet.Payload) if err == nil { - n, err = source.TransDemuxer.Input(bytes) + n, err = source.PublishSource.Input(bytes) } // 非解析缓冲区满的错误, 继续解析 @@ -347,20 +348,13 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int, } } - var queueSize int - if active || tcp { - queueSize = stream.TCPReceiveBufferQueueSize - } else { - queueSize = stream.UDPReceiveBufferQueueSize - } - source.SetID(id) source.SetSSRC(ssrc) - source.Init(queueSize) + source.Init() if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state { return nil, 0, fmt.Errorf("error code %d", state) } - go stream.LoopEvent(source) + stream.LoopEvent(source) return source, port, err } diff --git a/gb28181/source_active.go b/gb28181/source_active.go index 7c161c9..e6ef54f 100644 --- a/gb28181/source_active.go +++ b/gb28181/source_active.go @@ -1,7 +1,6 @@ package gb28181 import ( - "github.com/lkmio/transport" "net" ) @@ -35,9 +34,6 @@ func NewActiveSource() (*ActiveSource, int, error) { }) return &ActiveSource{ - PassiveSource: PassiveSource{ - decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), - }, port: port, }, port, nil } diff --git a/gb28181/source_passive.go b/gb28181/source_passive.go index 69af0ac..7faafe1 100644 --- a/gb28181/source_passive.go +++ b/gb28181/source_passive.go @@ -1,16 +1,7 @@ package gb28181 -import "github.com/lkmio/transport" - type PassiveSource struct { BaseGBSource - decoder *transport.LengthFieldFrameDecoder -} - -// Input 重写stream.Source的Input函数, 主协程把推流数据交给PassiveSource处理 -func (p *PassiveSource) Input(data []byte) error { - _, err := DecodeGBRTPOverTCPPacket(data, p, p.decoder, nil, p.Conn) - return err } func (p *PassiveSource) SetupType() SetupType { @@ -18,7 +9,5 @@ func (p *PassiveSource) SetupType() SetupType { } func NewPassiveSource() *PassiveSource { - return &PassiveSource{ - decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), - } + return &PassiveSource{} } diff --git a/gb28181/source_udp.go b/gb28181/source_udp.go index 1c6c294..46ae50d 100644 --- a/gb28181/source_udp.go +++ b/gb28181/source_udp.go @@ -19,7 +19,8 @@ func (u *UDPSource) SetupType() SetupType { // OnOrderedRtp 有序RTP包回调 func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) { // 此时还在网络收流携程, 交给Source的主协程处理 - u.PublishSource.Input(packet.Raw) + u.ProcessPacket(packet.Raw) + stream.UDPReceiveBufferPool.Put(packet.Raw[:cap(packet.Raw)]) } // InputRtpPacket 将RTP包排序后,交给Source的主协程处理 diff --git a/gb28181/talk_source.go b/gb28181/talk_source.go index 3afa958..c06651b 100644 --- a/gb28181/talk_source.go +++ b/gb28181/talk_source.go @@ -46,11 +46,6 @@ type TalkSource struct { stream.PublishSource } -func (s *TalkSource) Input(data []byte) error { - _, err := s.PublishSource.TransDemuxer.Input(data) - return err -} - func (s *TalkSource) Close() { s.PublishSource.Close() // 关闭所有对讲设备的会话 diff --git a/gb28181/tcp_server.go b/gb28181/tcp_server.go index 93d843e..a8966d7 100644 --- a/gb28181/tcp_server.go +++ b/gb28181/tcp_server.go @@ -35,35 +35,21 @@ func (T *TCPServer) OnCloseSession(session *TCPSession) { func (T *TCPServer) OnConnected(conn net.Conn) []byte { T.StreamServer.OnConnected(conn) - return stream.TCPReceiveBufferPool.Get().([]byte) + return conn.(*transport.Conn).Data.(*TCPSession).receiveBuffer } func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { T.StreamServer.OnPacket(conn, data) session := conn.(*transport.Conn).Data.(*TCPSession) - // 单端口推流时, 先解析出SSRC找到GBSource. 后序将推流数据交给stream.Source处理 - if session.source == nil { - source, err := DecodeGBRTPOverTCPPacket(data, nil, session.decoder, T.filter, conn) - if err != nil { - log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data)) - _ = conn.Close() - return nil - } - - if source != nil { - session.Init(source) - } - } else { - // 将流交给Source的主协程处理,主协程最终会调用PassiveSource的Input函数处理 - if session.source.SetupType() == SetupPassive { - session.source.(*PassiveSource).PublishSource.Input(data) - } else { - session.source.(*ActiveSource).PublishSource.Input(data) - } + err := session.DecodeGBRTPOverTCPPacket(data, T.filter, conn) + if err != nil { + log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data)) + _ = conn.Close() + return nil } - return stream.TCPReceiveBufferPool.Get().([]byte) + return session.receiveBuffer } func NewTCPServer(filter Filter) (*TCPServer, error) { diff --git a/gb28181/tcp_session.go b/gb28181/tcp_session.go index 3a1b80c..f8bb7e3 100644 --- a/gb28181/tcp_session.go +++ b/gb28181/tcp_session.go @@ -10,9 +10,10 @@ import ( // TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包. type TCPSession struct { - conn net.Conn - source GBSource - decoder *transport.LengthFieldFrameDecoder + conn net.Conn + source GBSource + decoder *transport.LengthFieldFrameDecoder + receiveBuffer []byte } func (t *TCPSession) Init(source GBSource) { @@ -25,14 +26,17 @@ func (t *TCPSession) Close() { t.source.Close() t.source = nil } + + stream.TCPReceiveBufferPool.Put(t.receiveBuffer[:cap(t.receiveBuffer)]) } -func DecodeGBRTPOverTCPPacket(data []byte, source GBSource, decoder *transport.LengthFieldFrameDecoder, filter Filter, conn net.Conn) (GBSource, error) { +func (t *TCPSession) DecodeGBRTPOverTCPPacket(data []byte, filter Filter, conn net.Conn) error { length := len(data) for i := 0; i < length; { - n, bytes, err := decoder.Input(data[i:]) + // 解析粘包数据 + n, bytes, err := t.decoder.Input(data[i:]) if err != nil { - return source, err + return err } i += n @@ -41,40 +45,38 @@ func DecodeGBRTPOverTCPPacket(data []byte, source GBSource, decoder *transport.L } // 单端口模式,ssrc匹配source - if source == nil || stream.SessionStateHandshakeSuccess == source.State() { + if t.source == nil || stream.SessionStateHandshakeSuccess == t.source.State() { packet := rtp.Packet{} - if err := packet.Unmarshal(bytes); err != nil { - return nil, err - } else if source == nil { - source = filter.FindSource(packet.SSRC) + if err = packet.Unmarshal(bytes); err != nil { + return err + } else if t.source == nil { + t.source = filter.FindSource(packet.SSRC) } - if source == nil { + if t.source == nil { // ssrc 匹配不到Source - return nil, fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC) + return fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC) } - if stream.SessionStateHandshakeSuccess == source.State() { - source.PreparePublish(conn, packet.SSRC, source) + if stream.SessionStateHandshakeSuccess == t.source.State() { + t.source.PreparePublish(conn, packet.SSRC, t.source) } } - // 如果是单端口推流, 并且刚才与source绑定, 此时正位于网络收流协程, 否则都位于主协程 - if source.SetupType() == SetupPassive { - source.(*PassiveSource).BaseGBSource.Input(bytes) - } else { - source.(*ActiveSource).BaseGBSource.Input(bytes) + if err = t.source.ProcessPacket(bytes); err != nil { + return err } } - return source, nil + return nil } func NewTCPSession(conn net.Conn, filter Filter) *TCPSession { session := &TCPSession{ conn: conn, // filter: filter, - decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), + decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), + receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte), } // 多端口收流, Source已知, 直接初始化Session diff --git a/gb28181/udp_server.go b/gb28181/udp_server.go index 8490a36..f99f4f0 100644 --- a/gb28181/udp_server.go +++ b/gb28181/udp_server.go @@ -16,18 +16,11 @@ type UDPServer struct { filter Filter } -func (U *UDPServer) OnNewSession(conn net.Conn) *UDPSource { +func (U *UDPServer) OnNewSession(_ net.Conn) *UDPSource { return nil } -func (U *UDPServer) OnCloseSession(session *UDPSource) { - U.filter.RemoveSource(session.SSRC()) - session.Close() - - if stream.AppConfig.GB28181.IsMultiPort() { - U.udp.Close() - U.Handler = nil - } +func (U *UDPServer) OnCloseSession(_ *UDPSource) { } func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte { @@ -52,7 +45,7 @@ func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte { } packet.Raw = data - source.(*UDPSource).InputRtpPacket(&packet) + _ = source.(*UDPSource).InputRtpPacket(&packet) return nil } diff --git a/jt1078/jt_server.go b/jt1078/jt_server.go index 9095931..43ff7d2 100644 --- a/jt1078/jt_server.go +++ b/jt1078/jt_server.go @@ -30,8 +30,8 @@ func (s *jtServer) OnCloseSession(session *Session) { func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte { s.StreamServer.OnPacket(conn, data) session := conn.(*transport.Conn).Data.(*Session) - session.PublishSource.Input(data) - return stream.TCPReceiveBufferPool.Get().([]byte) + _, _ = session.Input(data) + return session.receiveBuffer } func (s *jtServer) Start(addr net.Addr) error { diff --git a/jt1078/jt_session.go b/jt1078/jt_session.go index 6fa14e5..fefcaaf 100644 --- a/jt1078/jt_session.go +++ b/jt1078/jt_session.go @@ -11,15 +11,16 @@ import ( type Session struct { stream.PublishSource - decoder *transport.DelimiterFrameDecoder + decoder *transport.DelimiterFrameDecoder + receiveBuffer []byte } -func (s *Session) Input(data []byte) error { +func (s *Session) Input(data []byte) (int, error) { var n int for length := len(data); n < length; { i, bytes, err := s.decoder.Input(data[n:]) if err != nil { - return err + return -1, err } else if len(bytes) < 1 { break } @@ -27,9 +28,9 @@ func (s *Session) Input(data []byte) error { n += i demuxer := s.TransDemuxer.(*Demuxer) firstOfPacket := demuxer.prevPacket == nil - _, err = demuxer.Input(bytes) + _, err = s.PublishSource.Input(bytes) if err != nil { - return err + return -1, err } // 首包处理, hook通知 @@ -49,7 +50,7 @@ func (s *Session) Input(data []byte) error { } } - return nil + return 0, nil } func (s *Session) Close() { @@ -61,6 +62,7 @@ func (s *Session) Close() { } s.PublishSource.Close() + stream.TCPReceiveBufferPool.Put(s.receiveBuffer[:cap(s.receiveBuffer)]) } func NewSession(conn net.Conn) *Session { @@ -72,11 +74,12 @@ func NewSession(conn net.Conn) *Session { TransDemuxer: NewDemuxer(), }, - decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]), + decoder: transport.NewDelimiterFrameDecoder(1024*1024*2, delimiter[:]), + receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte), } session.TransDemuxer.SetHandler(&session) - session.Init(stream.TCPReceiveBufferQueueSize) - go stream.LoopEvent(&session) + session.Init() + stream.LoopEvent(&session) return &session } diff --git a/rtmp/rtmp_publisher.go b/rtmp/rtmp_publisher.go index 8235631..bb67019 100644 --- a/rtmp/rtmp_publisher.go +++ b/rtmp/rtmp_publisher.go @@ -12,10 +12,6 @@ type Publisher struct { Stack *rtmp.ServerStack } -func (p *Publisher) Input(data []byte) error { - return p.Stack.Input(p.Conn, data) -} - func (p *Publisher) Close() { p.PublishSource.Close() p.Stack = nil diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index 60afc89..42ecfd1 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -35,7 +35,7 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState { source := NewPublisher(sourceId, s.stack, s.conn) // 初始化放在add source前面, 以防add后再init, 空窗期拉流队列空指针. - source.Init(stream.TCPReceiveBufferQueueSize) + source.Init() source.SetUrlValues(values) // 统一处理source推流事件, source是否已经存在, hook回调.... @@ -46,7 +46,7 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState { s.handle = source s.isPublisher = true - go stream.LoopEvent(source) + stream.LoopEvent(source) } return state @@ -73,7 +73,14 @@ func (s *Session) OnPlay(app, stream_ string) utils.HookState { func (s *Session) Input(data []byte) error { // 推流会话, 收到的包都将交由主协程处理 if s.isPublisher { - return s.handle.(*Publisher).PublishSource.Input(data) + s.handle.(*Publisher).UpdateReceiveStats(len(data)) + + var err error + s.handle.(*Publisher).ExecuteSyncEvent(func() { + err = s.stack.Input(s.conn, data) + }) + + return err } else { return s.stack.Input(s.conn, data) } diff --git a/stream/rtp_stream.go b/stream/rtp_stream.go index 3ec8eda..c35bca5 100644 --- a/stream/rtp_stream.go +++ b/stream/rtp_stream.go @@ -27,7 +27,7 @@ func (f *RtpStream) Input(packet *avformat.AVPacket) ([]*collections.ReferenceCo bytes := counter.Get() binary.BigEndian.PutUint16(bytes, size-2) copy(bytes[2:], packet.Data) - counter.ResetData(bytes) + counter.ResetData(bytes[:2+len(bytes)]) // 每帧都当关键帧, 直接发给上级 return []*collections.ReferenceCounter[[]byte]{counter}, -1, true, nil diff --git a/stream/source.go b/stream/source.go index 422c923..d4325dd 100644 --- a/stream/source.go +++ b/stream/source.go @@ -27,7 +27,7 @@ type Source interface { SetID(id string) // Input 输入推流数据 - Input(data []byte) error + Input(data []byte) (int, error) // GetType 返回推流类型 GetType() SourceType @@ -47,7 +47,7 @@ type Source interface { // IsCompleted 所有推流track是否解析完毕 IsCompleted() bool - Init(receiveQueueSize int) + Init() RemoteAddr() string @@ -61,11 +61,6 @@ type Source interface { // SetUrlValues 设置推流url参数 SetUrlValues(values url.Values) - // PostEvent 切换到主协程执行当前函数 - postEvent(cb func()) - - executeSyncEvent(cb func()) - // LastPacketTime 返回最近收流时间戳 LastPacketTime() time.Time @@ -73,10 +68,6 @@ type Source interface { IsClosed() bool - StreamPipe() chan []byte - - MainContextEvents() chan func() - CreateTime() time.Time SetCreateTime(time time.Time) @@ -86,6 +77,12 @@ type Source interface { ProbeTimeout() GetTransStreamPublisher() TransStreamPublisher + + StartTimers(source Source) + + ExecuteSyncEvent(cb func()) + + UpdateReceiveStats(dataLen int) } type PublishSource struct { @@ -94,9 +91,7 @@ type PublishSource struct { state SessionState Conn net.Conn - streamPipe *NonBlockingChannel[[]byte] // 推流数据管道 - mainContextEvents chan func() // 切换到主协程执行函数的事件管道 - streamPublisher TransStreamPublisher // 解析出来的AVStream和AVPacket, 交由streamPublisher处理 + streamPublisher TransStreamPublisher // 解析出来的AVStream和AVPacket, 交由streamPublisher处理 TransDemuxer avformat.Demuxer // 负责从推流协议中解析出AVStream和AVPacket originTracks TrackManager // 推流的音视频Streams @@ -110,6 +105,14 @@ type PublishSource struct { createTime time.Time // source创建时间 statistics *BitrateStatistics // 码流统计 streamLogger avformat.OnUnpackStream2FileHandler + // streamLock sync.RWMutex + streamLock sync.Mutex + + timers struct { + receiveTimer *time.Timer // 收流超时计时器 + idleTimer *time.Timer // 空闲超时计时器 + probeTimer *time.Timer // tack探测超时计时器 + } } func (s *PublishSource) SetLastPacketTime(time2 time.Time) { @@ -120,14 +123,6 @@ func (s *PublishSource) IsClosed() bool { return s.closed.Load() } -func (s *PublishSource) StreamPipe() chan []byte { - return s.streamPipe.Channel -} - -func (s *PublishSource) MainContextEvents() chan func() { - return s.mainContextEvents -} - func (s *PublishSource) LastPacketTime() time.Time { return s.lastPacketTime } @@ -143,23 +138,35 @@ func (s *PublishSource) SetID(id string) { } } -func (s *PublishSource) Init(receiveQueueSize int) { +func (s *PublishSource) Init() { s.SetState(SessionStateHandshakeSuccess) - // 初始化事件接收管道 - // -2是为了保证从管道取到流, 到处理完流整个过程安全的, 不会被覆盖 - s.streamPipe = NewNonBlockingChannel[[]byte](receiveQueueSize - 1) - s.mainContextEvents = make(chan func(), 128) s.statistics = NewBitrateStatistics() s.streamPublisher = NewTransStreamPublisher(s.ID) // 设置探测时长 s.TransDemuxer.SetProbeDuration(AppConfig.ProbeTimeout) } -func (s *PublishSource) Input(data []byte) error { - s.streamPipe.Post(data) - s.statistics.Input(len(data)) - return nil +func (s *PublishSource) UpdateReceiveStats(dataLen int) { + s.statistics.Input(dataLen) + if AppConfig.ReceiveTimeout > 0 { + s.SetLastPacketTime(time.Now()) + } +} + +func (s *PublishSource) Input(data []byte) (int, error) { + s.UpdateReceiveStats(len(data)) + var n int + var err error + s.ExecuteSyncEvent(func() { + if s.closed.Load() { + err = fmt.Errorf("source closed") + } else { + n, err = s.TransDemuxer.Input(data) + } + }) + + return n, err } func (s *PublishSource) OriginTracks() []*Track { @@ -177,12 +184,31 @@ func (s *PublishSource) DoClose() { return } - s.closed.Store(true) + var closed bool + s.ExecuteSyncEvent(func() { + closed = s.closed.Swap(true) + }) + + if closed { + return + } + + // 关闭各种超时计时器 + if s.timers.receiveTimer != nil { + s.timers.receiveTimer.Stop() + } + + if s.timers.idleTimer != nil { + s.timers.idleTimer.Stop() + } + + if s.timers.probeTimer != nil { + s.timers.probeTimer.Stop() + } // 关闭推流源的解复用器, 不再接收数据 if s.TransDemuxer != nil { s.TransDemuxer.Close() - s.TransDemuxer = nil } // 等传输流发布器关闭结束 @@ -210,14 +236,7 @@ func (s *PublishSource) DoClose() { } func (s *PublishSource) Close() { - if s.closed.Load() { - return - } - - // 同步执行, 确保close后, 主协程已经退出, 不会再处理任何推拉流、查询等任何事情. - s.executeSyncEvent(func() { - s.DoClose() - }) + s.DoClose() } // 解析完所有track后, 创建各种输出流 @@ -233,7 +252,8 @@ func (s *PublishSource) writeHeader() { if len(s.originTracks.All()) == 0 { log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID) - s.DoClose() + // 异步执行ProbeTimeout函数中还没释放锁 + go s.DoClose() return } } @@ -356,20 +376,11 @@ func (s *PublishSource) SetUrlValues(values url.Values) { s.urlValues = values } -func (s *PublishSource) postEvent(cb func()) { - s.mainContextEvents <- cb -} - -func (s *PublishSource) executeSyncEvent(cb func()) { - group := sync.WaitGroup{} - group.Add(1) - - s.postEvent(func() { - cb() - group.Done() - }) - - group.Wait() +func (s *PublishSource) ExecuteSyncEvent(cb func()) { + // 无竞争情况下, 接近原子操作 + s.streamLock.Lock() + defer s.streamLock.Unlock() + cb() } func (s *PublishSource) CreateTime() time.Time { @@ -386,10 +397,37 @@ func (s *PublishSource) GetBitrateStatistics() *BitrateStatistics { func (s *PublishSource) ProbeTimeout() { if s.TransDemuxer != nil { - s.TransDemuxer.ProbeComplete() + s.ExecuteSyncEvent(func() { + if !s.closed.Load() { + s.TransDemuxer.ProbeComplete() + } + }) } } func (s *PublishSource) GetTransStreamPublisher() TransStreamPublisher { return s.streamPublisher } + +func (s *PublishSource) StartTimers(source Source) { + + // 开启收流超时计时器 + if AppConfig.ReceiveTimeout > 0 { + s.timers.receiveTimer = StartReceiveDataTimer(source) + } + + // 开启拉流空闲超时计时器 + if AppConfig.Hooks.IsEnableOnIdleTimeout() && AppConfig.IdleTimeout > 0 { + s.timers.idleTimer = StartIdleTimer(source) + } + + // 开启探测超时计时器 + s.timers.probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() { + if source.IsCompleted() { + return + } + + source.ProbeTimeout() + }) + +} diff --git a/stream/source_utils.go b/stream/source_utils.go index c59eb89..08fa62d 100644 --- a/stream/source_utils.go +++ b/stream/source_utils.go @@ -197,100 +197,6 @@ func StartIdleTimer(source Source) *time.Timer { // LoopEvent 循环读取事件 func LoopEvent(source Source) { - // 将超时计时器放在此处开启, 方便在退出的时候关闭 - var receiveTimer *time.Timer - var idleTimer *time.Timer - var probeTimer *time.Timer - - defer func() { - log.Sugar.Debugf("主协程执行结束 source: %s", source.GetID()) - - // 关闭计时器 - if receiveTimer != nil { - receiveTimer.Stop() - } - - if idleTimer != nil { - idleTimer.Stop() - } - - if probeTimer != nil { - probeTimer.Stop() - } - - // 未使用的数据, 放回池中 - for len(source.StreamPipe()) > 0 { - data := <-source.StreamPipe() - if size := cap(data); size > UDPReceiveBufferSize { - TCPReceiveBufferPool.Put(data[:size]) - } else { - UDPReceiveBufferPool.Put(data[:size]) - } - } - }() - - // 开启收流超时计时器 - if AppConfig.ReceiveTimeout > 0 { - receiveTimer = StartReceiveDataTimer(source) - } - - // 开启拉流空闲超时计时器 - if AppConfig.Hooks.IsEnableOnIdleTimeout() && AppConfig.IdleTimeout > 0 { - idleTimer = StartIdleTimer(source) - } - - // 开启探测超时计时器 - probeTimer = time.AfterFunc(time.Duration(AppConfig.ProbeTimeout)*time.Millisecond, func() { - if source.IsCompleted() { - return - } - - var ok bool - source.executeSyncEvent(func() { - source.ProbeTimeout() - ok = len(source.OriginTracks()) > 0 - }) - - if !ok { - source.Close() - return - } - }) - - // 启动协程, 生成发布传输流 + source.StartTimers(source) go source.GetTransStreamPublisher().run() - - for { - select { - // 读取推流数据 - case data := <-source.StreamPipe(): - if AppConfig.ReceiveTimeout > 0 { - source.SetLastPacketTime(time.Now()) - } - - if err := source.Input(data); err != nil { - log.Sugar.Errorf("解析推流数据发生err: %s 释放source: %s", err.Error(), source.GetID()) - go source.Close() - return - } - - // 使用后, 放回池中 - if size := cap(data); size > UDPReceiveBufferSize { - TCPReceiveBufferPool.Put(data[:size]) - } else { - UDPReceiveBufferPool.Put(data[:size]) - } - break - // 切换到主协程,执行该函数. 目的是用于无锁化处理推拉流的连接与断开, 推流源断开, 查询推流源信息等事件. 不要做耗时操作, 否则会影响推拉流. - case event := <-source.MainContextEvents(): - event() - - if source.IsClosed() { - // 处理推流管道剩余的数据? - return - } - - break - } - } }