From ca52588bae0738483fbf24c7cd6d33c76444d204 Mon Sep 17 00:00:00 2001 From: ydajiang Date: Fri, 8 Aug 2025 17:14:33 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20gb28181=E4=BB=85=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=A4=9A=E7=AB=AF=E5=8F=A3=E6=8E=A8=E6=B5=81,=20=E6=8F=90?= =?UTF-8?q?=E5=8D=87=E4=BB=A3=E7=A0=81=E5=81=A5=E5=A3=AE=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api_gb.go | 11 +-- gb28181/filter.go | 10 -- gb28181/filter_single.go | 21 ----- gb28181/filter_ssrc.go | 38 -------- gb28181/publish_test.go | 26 +++--- gb28181/source.go | 183 +++++++++++++++++-------------------- gb28181/source_active.go | 29 ++++-- gb28181/source_passive.go | 96 ++++++++++++++++++- gb28181/source_udp.go | 35 ++++++- gb28181/ssrc_manager.go | 20 +--- gb28181/tcp_client.go | 28 ------ gb28181/tcp_server.go | 96 ------------------- gb28181/tcp_session.go | 88 ------------------ gb28181/udp_server.go | 91 ------------------ jt1078/jt_server.go | 1 + jt1078/jt_session.go | 21 +---- jt1078/jt_test.go | 4 +- main.go | 32 +------ rtmp/rtmp_session.go | 17 +++- stream/config.go | 8 -- stream/hook.go | 7 +- stream/hook_source.go | 70 ++++++++++---- stream/mw_buffer.go | 13 ++- stream/mwb_pool.go | 36 +++++--- stream/sink.go | 11 ++- stream/source.go | 57 ++++++------ stream/source_manager.go | 2 +- stream/source_utils.go | 7 ++ stream/stream_publisher.go | 4 +- stream/stream_server.go | 8 +- stream/trans_utils.go | 29 +----- 31 files changed, 415 insertions(+), 684 deletions(-) delete mode 100644 gb28181/filter.go delete mode 100644 gb28181/filter_single.go delete mode 100644 gb28181/filter_ssrc.go delete mode 100644 gb28181/tcp_client.go delete mode 100644 gb28181/tcp_server.go delete mode 100644 gb28181/tcp_session.go delete mode 100644 gb28181/udp_server.go diff --git a/api_gb.go b/api_gb.go index 6501304..12f8bf5 100644 --- a/api_gb.go +++ b/api_gb.go @@ -3,7 +3,6 @@ package main import ( "fmt" "github.com/lkmio/avformat/bufio" - "github.com/lkmio/avformat/utils" "github.com/lkmio/lkm/gb28181" "github.com/lkmio/lkm/log" "github.com/lkmio/lkm/stream" @@ -74,9 +73,7 @@ func (api *ApiServer) OnGBSourceCreate(v *SourceSDP, w http.ResponseWriter, r *h } if tcp && active { - if !stream.AppConfig.GB28181.IsMultiPort() { - err = fmt.Errorf("单端口模式下不能主动拉流") - } else if !tcp { + if !tcp { err = fmt.Errorf("UDP不能主动拉流") } else if !stream.AppConfig.GB28181.IsEnableTCP() { err = fmt.Errorf("未开启TCP收流服务,UDP不能主动拉流") @@ -218,9 +215,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) { talkSource.Init() talkSource.SetUrlValues(r.Form) - _, state := stream.PreparePublishSource(talkSource, true) - if utils.HookStateOK != state { - log.Sugar.Errorf("对讲失败, source: %s", talkSource) + _, err = stream.PreparePublishSource(talkSource, true) + if err != nil { + log.Sugar.Errorf("对讲失败, err: %s source: %s", err, talkSource) conn.Close() return } diff --git a/gb28181/filter.go b/gb28181/filter.go deleted file mode 100644 index 0670014..0000000 --- a/gb28181/filter.go +++ /dev/null @@ -1,10 +0,0 @@ -package gb28181 - -// Filter 关联Source -type Filter interface { - AddSource(ssrc uint32, source GBSource) bool - - RemoveSource(ssrc uint32) - - FindSource(ssrc uint32) GBSource -} diff --git a/gb28181/filter_single.go b/gb28181/filter_single.go deleted file mode 100644 index dd23424..0000000 --- a/gb28181/filter_single.go +++ /dev/null @@ -1,21 +0,0 @@ -package gb28181 - -type singleFilter struct { - source GBSource -} - -func (s *singleFilter) AddSource(ssrc uint32, source GBSource) bool { - panic("implement me") -} - -func (s *singleFilter) RemoveSource(ssrc uint32) { - s.source = nil -} - -func (s *singleFilter) FindSource(ssrc uint32) GBSource { - return s.source -} - -func NewSingleFilter(source GBSource) Filter { - return &singleFilter{source: source} -} diff --git a/gb28181/filter_ssrc.go b/gb28181/filter_ssrc.go deleted file mode 100644 index 62fb485..0000000 --- a/gb28181/filter_ssrc.go +++ /dev/null @@ -1,38 +0,0 @@ -package gb28181 - -import ( - "sync" -) - -type ssrcFilter struct { - sources map[uint32]GBSource - mute sync.RWMutex -} - -func (r *ssrcFilter) AddSource(ssrc uint32, source GBSource) bool { - r.mute.Lock() - defer r.mute.Unlock() - - if _, ok := r.sources[ssrc]; !ok { - r.sources[ssrc] = source - return true - } - - return false -} - -func (r *ssrcFilter) RemoveSource(ssrc uint32) { - r.mute.Lock() - defer r.mute.Unlock() - delete(r.sources, ssrc) -} - -func (r *ssrcFilter) FindSource(ssrc uint32) GBSource { - r.mute.RLock() - defer r.mute.RUnlock() - return r.sources[ssrc] -} - -func NewSSRCFilter(guestCount int) Filter { - return &ssrcFilter{sources: make(map[uint32]GBSource, guestCount)} -} diff --git a/gb28181/publish_test.go b/gb28181/publish_test.go index d660c5b..de82cdf 100644 --- a/gb28181/publish_test.go +++ b/gb28181/publish_test.go @@ -24,7 +24,7 @@ import ( func connectSource(source string, addr string) { v := &struct { Source string `json:"source"` //GetSourceID - RemoteAddr string `json:"remote_addr"` + RemoteAddr string `json:"addr"` }{ Source: source, RemoteAddr: addr, @@ -35,7 +35,7 @@ func connectSource(source string, addr string) { panic(err) } - request, err := http.NewRequest("POST", "http://localhost:8080/v1/gb28181/source/connect", bytes.NewBuffer(marshal)) + request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/answer/set", bytes.NewBuffer(marshal)) if err != nil { panic(err) } @@ -209,11 +209,12 @@ func TestPublish(t *testing.T) { //path := "../../source_files/rtp_ps_h264_G7221_0xBEBC204.raw" //var rawSsrc uint32 = 0xBEBC204 - path := "../../source_files/rtp_ps_h264_G726_0xBEBC205.raw" + //path := "../../source_files/rtp_ps_h264_G726_0xBEBC205.raw" + path := "../../source_files/rtp_ps_err_parse.raw" var rawSsrc uint32 = 0xBEBC205 localAddr := "0.0.0.0:20001" - id := "hls_mystream" + id := "hls/mystream" data, err := os.ReadFile(path) if err != nil { @@ -305,7 +306,7 @@ func TestPublish(t *testing.T) { }) t.Run("active", func(t *testing.T) { - ip, port, ssrc := createSource(id, "active", rawSsrc) + _, _, ssrc := createSource(id, "active", rawSsrc) addr, _ := net.ResolveTCPAddr("tcp", localAddr) server := transport.TCPServer{} @@ -317,6 +318,7 @@ func TestPublish(t *testing.T) { ctrDelay(packet[2:]) } + server.Close() return nil }, nil, nil) @@ -325,7 +327,9 @@ func TestPublish(t *testing.T) { panic(err) } - connectSource(id, fmt.Sprintf("%s:%d", ip, port)) + server.Accept() + connectSource(id, localAddr) + select {} }) } @@ -336,10 +340,10 @@ func TestDecode(t *testing.T) { panic(err2) } - source := NewPassiveSource() - source.Init() - filter := NewSingleFilter(source) - session := NewTCPSession(nil, filter) + source := &PassiveSource{ + decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), + } + reader := bufio.NewBytesReader(file) for { @@ -353,7 +357,7 @@ func TestDecode(t *testing.T) { break } - err2 = session.DecodeGBRTPOverTCPPacket(bytes, filter, nil) + err2 = source.DecodeGBRTPOverTCPPacket(bytes) if err2 != nil { break } diff --git a/gb28181/source.go b/gb28181/source.go index 7ed889d..cce1842 100644 --- a/gb28181/source.go +++ b/gb28181/source.go @@ -11,7 +11,6 @@ import ( "github.com/lkmio/transport" "github.com/pion/rtp" "math" - "net" "strings" ) @@ -23,7 +22,6 @@ const ( SetupActive = SetupType(2) PsProbeBufferSize = 1024 * 1024 * 2 - JitterBufferSize = 1024 * 1024 ) func (s SetupType) TransportType() stream.TransportType { @@ -65,8 +63,6 @@ func SetupTypeFromString(setupType string) SetupType { var ( TransportManger transport.Manager - SharedUDPServer *UDPServer - SharedTCPServer *TCPServer ) // GBSource GB28181推流Source, 统一解析PS流、级联转发. @@ -75,23 +71,20 @@ type GBSource interface { SetupType() SetupType - // PreparePublish 收到流时, 做一些初始化工作. - PreparePublish(conn net.Conn, ssrc uint32, source GBSource) - - SetConn(conn net.Conn) - SetSSRC(ssrc uint32) SSRC() uint32 ProcessPacket(data []byte) error + + SetTransport(transport transport.Transport) } type BaseGBSource struct { stream.PublishSource - transport transport.Transport probeBuffer *mpeg.PSProbeBuffer + transport transport.Transport ssrc uint32 audioTimestamp int64 @@ -103,23 +96,16 @@ type BaseGBSource struct { sameTimePackets [][]byte } -func (source *BaseGBSource) Init() { - source.TransDemuxer = mpeg.NewPSDemuxer(false) - source.TransDemuxer.SetHandler(source) - source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) { - source.correctTimestamp(packet, packet.Dts, packet.Pts) - }) - source.SetType(stream.SourceType28181) - source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize) - source.PublishSource.Init() - source.lastRtpTimestamp = -1 -} - // ProcessPacket 输入rtp包, 处理PS流, 负责解析->封装->推流 func (source *BaseGBSource) ProcessPacket(data []byte) error { packet := rtp.Packet{} _ = packet.Unmarshal(data) + // 收到第一包, 初始化 + if source.probeBuffer == nil { + source.InitializePublish(packet.SSRC) + } + // 国标级联转发 if source.GetTransStreamPublisher().GetForwardTransStream() != nil { if source.lastRtpTimestamp == -1 { @@ -228,30 +214,17 @@ func (source *BaseGBSource) correctTimestamp(packet *avformat.AVPacket, dts, pts } func (source *BaseGBSource) Close() { - log.Sugar.Infof("GB28181推流结束 ssrc:%d %s", source.ssrc, source.PublishSource.String()) - - // 释放收流端口 - if source.transport != nil { - source.transport.Close() - source.transport = nil - } - - // 删除ssrc关联 - if !stream.AppConfig.GB28181.IsMultiPort() { - if SharedTCPServer != nil { - SharedTCPServer.filter.RemoveSource(source.ssrc) - } - - if SharedUDPServer != nil { - SharedUDPServer.filter.RemoveSource(source.ssrc) - } - } + log.Sugar.Infof("GB28181推流结束 ssrc: %d %s", source.ssrc, source.PublishSource.String()) source.PublishSource.Close() -} -func (source *BaseGBSource) SetConn(conn net.Conn) { - source.Conn = conn + // 加锁执行, 保证并发安全 + source.ExecuteWithDeleteLock(func() { + if source.transport != nil { + source.transport.Close() + source.transport = nil + } + }) } func (source *BaseGBSource) SetSSRC(ssrc uint32) { @@ -262,27 +235,43 @@ func (source *BaseGBSource) SSRC() uint32 { return source.ssrc } -func (source *BaseGBSource) PreparePublish(conn net.Conn, ssrc uint32, source_ GBSource) { - source.SetConn(conn) - source.SetSSRC(ssrc) - source.SetState(stream.SessionStateTransferring) +func (source *BaseGBSource) InitializePublish(ssrc uint32) { + if source.ssrc != ssrc { + log.Sugar.Warnf("创建source的ssrc与实际推流的ssrc不一致, 创建的ssrc: %x 实际推流的ssrc: %x source: %s", source.ssrc, ssrc, source.GetID()) + } + + // 初始化ps解复用器 + source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) { + source.correctTimestamp(packet, packet.Dts, packet.Pts) + }) + source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize) + source.lastRtpTimestamp = -1 + + source.ssrc = ssrc source.audioTimestamp = -1 source.videoTimestamp = -1 source.audioPacketCreatedTime = -1 source.videoPacketCreatedTime = -1 - if stream.AppConfig.Hooks.IsEnablePublishEvent() { - go func() { - if _, state := stream.HookPublishEvent(source_); utils.HookStateOK == state { - return - } - - log.Sugar.Errorf("GB28181 推流失败 source:%s", source.GetID()) - if conn != nil { - conn.Close() - } - }() + p := stream.SourceManager.Find(source.GetID()) + if p == nil { + log.Sugar.Errorf("GB28181推流失败, 未找到source: %s", source.GetID()) + source.Close() + return } + + stream.PreparePublishSourceWithAsync(p, false) +} + +func (source *BaseGBSource) Init() { + // 创建ps解复用器 + source.TransDemuxer = mpeg.NewPSDemuxer(false) + source.TransDemuxer.SetHandler(source) + source.PublishSource.Init() +} + +func (source *BaseGBSource) SetTransport(transport transport.Transport) { + source.transport = transport } // NewGBSource 创建国标推流源, 返回监听的收流端口 @@ -294,9 +283,10 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int, } if active { - utils.Assert(tcp && stream.AppConfig.GB28181.IsEnableTCP() && stream.AppConfig.GB28181.IsMultiPort()) + utils.Assert(tcp && stream.AppConfig.GB28181.IsEnableTCP()) } + var transportServer transport.Transport var source GBSource var port int var err error @@ -304,55 +294,46 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int, if active { source, port, err = NewActiveSource() } else if tcp { + transportServer, err = TransportManger.NewTCPServer() + if err != nil { + return nil, 0, err + } + source = NewPassiveSource() + transportServer.(*transport.TCPServer).SetHandler(source.(*PassiveSource)) + transportServer.(*transport.TCPServer).Accept() + port = transportServer.ListenPort() } else { + transportServer, err = TransportManger.NewUDPServer() + if err != nil { + return nil, 0, err + } + source = NewUDPSource() + transportServer.(*transport.UDPServer).SetHandler(source.(*UDPSource)) + transportServer.(*transport.UDPServer).Receive() + port = transportServer.ListenPort() } - if err != nil { - return nil, 0, err - } - - // 单端口模式,绑定ssrc - if !stream.AppConfig.GB28181.IsMultiPort() { - var success bool - if tcp { - success = SharedTCPServer.filter.AddSource(ssrc, source) - } else { - success = SharedUDPServer.filter.AddSource(ssrc, source) - } - - if !success { - return nil, 0, fmt.Errorf("ssrc conflict") - } - - port = stream.AppConfig.GB28181.Port[0] - } else if !active { - // 多端口模式, 创建收流Server - if tcp { - tcpServer, err := NewTCPServer(NewSingleFilter(source)) - if err != nil { - return nil, 0, err - } - - port = tcpServer.tcp.ListenPort() - source.(*PassiveSource).transport = tcpServer.tcp - } else { - server, err := NewUDPServer(NewSingleFilter(source)) - if err != nil { - return nil, 0, err - } - - port = server.udp.ListenPort() - source.(*UDPSource).transport = server.udp - } - } - + source.SetType(stream.SourceType28181) source.SetID(id) source.SetSSRC(ssrc) - source.Init() - if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state { - return nil, 0, fmt.Errorf("error code %d", state) + // 加锁保护一下, 防止初始化阶段, 调用关闭source接口, 发生并发安全问题 + source.ExecuteWithDeleteLock(func() { + if err = stream.AddSource(source); err != nil { + return + } + + source.SetTransport(transportServer) + source.Init() + }) + + // id冲突 + if err != nil { + if transportServer != nil { + transportServer.Close() + } + return nil, 0, err } stream.LoopEvent(source) diff --git a/gb28181/source_active.go b/gb28181/source_active.go index e6ef54f..811a1b5 100644 --- a/gb28181/source_active.go +++ b/gb28181/source_active.go @@ -1,24 +1,30 @@ package gb28181 import ( + "github.com/lkmio/lkm/stream" + "github.com/lkmio/transport" "net" ) type ActiveSource struct { - PassiveSource - + *PassiveSource port int remoteAddr net.TCPAddr - tcp *TCPClient } func (a *ActiveSource) Connect(remoteAddr *net.TCPAddr) error { - client, err := NewTCPClient(a.port, remoteAddr, a) + client := &transport.TCPClient{} + client.SetHandler(a.PassiveSource) + + addr, err := net.ResolveTCPAddr("tcp", stream.ListenAddr(a.port)) if err != nil { return err + } else if _, err = client.Connect(addr, remoteAddr); err != nil { + return err } - a.tcp = client + go client.Receive() + a.transport = client return nil } @@ -28,12 +34,23 @@ func (a *ActiveSource) SetupType() SetupType { func NewActiveSource() (*ActiveSource, int, error) { var port int - TransportManger.AllocPort(true, func(port_ uint16) error { + err := TransportManger.AllocPort(true, func(port_ uint16) error { port = int(port_) return nil }) + if err != nil { + return nil, 0, err + } + return &ActiveSource{ + PassiveSource: &PassiveSource{ + StreamServer: stream.StreamServer[GBSource]{ + SourceType: stream.SourceType28181, + }, + decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), + receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte), + }, port: port, }, port, nil } diff --git a/gb28181/source_passive.go b/gb28181/source_passive.go index 7faafe1..21dbf3a 100644 --- a/gb28181/source_passive.go +++ b/gb28181/source_passive.go @@ -1,13 +1,105 @@ package gb28181 +import ( + "encoding/hex" + "github.com/lkmio/lkm/log" + "github.com/lkmio/lkm/stream" + "github.com/lkmio/transport" + "net" +) + type PassiveSource struct { + stream.StreamServer[GBSource] BaseGBSource + decoder *transport.LengthFieldFrameDecoder + receiveBuffer []byte + remoteAddr string } func (p *PassiveSource) SetupType() SetupType { return SetupPassive } -func NewPassiveSource() *PassiveSource { - return &PassiveSource{} +func (p *PassiveSource) Close() { + p.BaseGBSource.Close() + stream.TCPReceiveBufferPool.Put(p.receiveBuffer[:cap(p.receiveBuffer)]) +} + +func (p *PassiveSource) DecodeGBRTPOverTCPPacket(data []byte) error { + length := len(data) + for i := 0; i < length; { + // 解析粘包数据 + n, bytes, err := p.decoder.Input(data[i:]) + if err != nil { + return err + } + + i += n + if bytes == nil { + break + } + + if err = p.ProcessPacket(bytes); err != nil { + return err + } + } + + return nil +} + +func (p *PassiveSource) OnConnected(conn net.Conn) []byte { + p.StreamServer.OnConnected(conn) + + var ok bool + p.ExecuteWithDeleteLock(func() { + if p.IsClosed() { + log.Sugar.Infof("source %s 已关闭, 拒绝新连接", p.GetID()) + } else if ok = p.PublishSource.Conn == nil; ok { + // 一个推流一个端口, 默认第一个连接为有效连接, 关闭其他连接 + p.PublishSource.Conn = conn + p.remoteAddr = conn.RemoteAddr().String() + } else { + log.Sugar.Infof("port %d 已连接, 关闭连接. source: %s", p.transport.ListenPort(), p.GetID()) + } + }) + + if !ok { + _ = conn.Close() + return nil + } + + return p.receiveBuffer +} + +func (p *PassiveSource) OnPacket(conn net.Conn, data []byte) []byte { + p.StreamServer.OnPacket(conn, data) + + err := p.DecodeGBRTPOverTCPPacket(data) + if err != nil { + log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data)) + _ = conn.Close() + return nil + } + + return p.receiveBuffer +} + +func (p *PassiveSource) OnDisConnected(conn net.Conn, err error) { + p.StreamServer.OnDisConnected(conn, err) + + if conn.RemoteAddr().String() == p.remoteAddr { + p.Close() + } +} + +func NewPassiveSource() *PassiveSource { + source := &PassiveSource{ + StreamServer: stream.StreamServer[GBSource]{ + SourceType: stream.SourceType28181, + }, + decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), + receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte), + } + + return source } diff --git a/gb28181/source_udp.go b/gb28181/source_udp.go index 46ae50d..9de0f1d 100644 --- a/gb28181/source_udp.go +++ b/gb28181/source_udp.go @@ -1,14 +1,16 @@ package gb28181 import ( + "github.com/lkmio/lkm/log" "github.com/lkmio/lkm/stream" "github.com/pion/rtp" + "net" ) // UDPSource 国标UDP推流源 type UDPSource struct { + stream.StreamServer[interface{}] BaseGBSource - jitterBuffer *stream.JitterBuffer[*rtp.Packet] } @@ -18,12 +20,12 @@ func (u *UDPSource) SetupType() SetupType { // OnOrderedRtp 有序RTP包回调 func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) { - // 此时还在网络收流携程, 交给Source的主协程处理 - u.ProcessPacket(packet.Raw) + _ = u.ProcessPacket(packet.Raw) + // 处理完后, 归还buffer stream.UDPReceiveBufferPool.Put(packet.Raw[:cap(packet.Raw)]) } -// InputRtpPacket 将RTP包排序后,交给Source的主协程处理 +// InputRtpPacket 将RTP包排序后,交给Source处理 func (u *UDPSource) InputRtpPacket(pkt *rtp.Packet) error { block := stream.UDPReceiveBufferPool.Get().([]byte) copy(block, pkt.Raw) @@ -45,8 +47,31 @@ func (u *UDPSource) Close() { u.BaseGBSource.Close() } +func (u *UDPSource) OnPacket(conn net.Conn, data []byte) []byte { + u.StreamServer.OnPacket(conn, data) + + packet := rtp.Packet{} + err := packet.Unmarshal(data) + if err != nil { + log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) + return nil + } else if u.Conn == nil { + u.Conn = conn + } + + packet.Raw = data + _ = u.InputRtpPacket(&packet) + return nil +} + func NewUDPSource() *UDPSource { - return &UDPSource{ + source := &UDPSource{ jitterBuffer: stream.NewJitterBuffer[*rtp.Packet](), } + + source.StreamServer = stream.StreamServer[interface{}]{ + SourceType: stream.SourceType28181, + } + + return source } diff --git a/gb28181/ssrc_manager.go b/gb28181/ssrc_manager.go index 280d238..71dfb9f 100644 --- a/gb28181/ssrc_manager.go +++ b/gb28181/ssrc_manager.go @@ -2,7 +2,6 @@ package gb28181 import ( "fmt" - "strconv" "sync" ) @@ -11,9 +10,8 @@ const ( ) var ( - ssrcCount uint32 - lock sync.Mutex - SSRCFilters []Filter + ssrcCount uint32 + lock sync.Mutex ) func NextSSRC() uint32 { @@ -23,19 +21,7 @@ func NextSSRC() uint32 { return ssrcCount } -func getUniqueSSRC(ssrc string, get func() string) string { - atoi, err := strconv.Atoi(ssrc) - if err != nil { - panic(err) - } - - v := uint32(atoi) - for _, filter := range SSRCFilters { - if filter.FindSource(v) != nil { - ssrc = get() - } - } - +func getUniqueSSRC(ssrc string, _ func() string) string { return ssrc } diff --git a/gb28181/tcp_client.go b/gb28181/tcp_client.go deleted file mode 100644 index b5bd15e..0000000 --- a/gb28181/tcp_client.go +++ /dev/null @@ -1,28 +0,0 @@ -package gb28181 - -import ( - "github.com/lkmio/lkm/stream" - "github.com/lkmio/transport" - "net" -) - -// TCPClient GB28181TCP主动收流 -type TCPClient struct { - TCPServer -} - -func NewTCPClient(listenPort int, remoteAddr *net.TCPAddr, source GBSource) (*TCPClient, error) { - client := &TCPClient{ - TCPServer{filter: NewSingleFilter(source)}, - } - tcp := transport.TCPClient{} - tcp.SetHandler(client) - - addr, err := net.ResolveTCPAddr("tcp", stream.ListenAddr(listenPort)) - if err != nil { - return client, err - } - - _, err = tcp.Connect(addr, remoteAddr) - return client, err -} diff --git a/gb28181/tcp_server.go b/gb28181/tcp_server.go deleted file mode 100644 index a8966d7..0000000 --- a/gb28181/tcp_server.go +++ /dev/null @@ -1,96 +0,0 @@ -package gb28181 - -import ( - "encoding/hex" - "github.com/lkmio/lkm/log" - "github.com/lkmio/lkm/stream" - "github.com/lkmio/transport" - "net" - "runtime" -) - -// TCPServer GB28181TCP被动收流 -type TCPServer struct { - stream.StreamServer[*TCPSession] - tcp *transport.TCPServer - filter Filter -} - -func (T *TCPServer) OnNewSession(conn net.Conn) *TCPSession { - return NewTCPSession(conn, T.filter) -} - -func (T *TCPServer) OnCloseSession(session *TCPSession) { - session.Close() - - if session.source != nil { - T.filter.RemoveSource(session.source.SSRC()) - } - - if stream.AppConfig.GB28181.IsMultiPort() { - T.tcp.Close() - T.Handler = nil - } -} - -func (T *TCPServer) OnConnected(conn net.Conn) []byte { - T.StreamServer.OnConnected(conn) - return conn.(*transport.Conn).Data.(*TCPSession).receiveBuffer -} - -func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte { - T.StreamServer.OnPacket(conn, data) - session := conn.(*transport.Conn).Data.(*TCPSession) - - err := session.DecodeGBRTPOverTCPPacket(data, T.filter, conn) - if err != nil { - log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data)) - _ = conn.Close() - return nil - } - - return session.receiveBuffer -} - -func NewTCPServer(filter Filter) (*TCPServer, error) { - server := &TCPServer{ - filter: filter, - } - - var tcp *transport.TCPServer - var err error - if stream.AppConfig.GB28181.IsMultiPort() { - tcp = &transport.TCPServer{} - tcp, err = TransportManger.NewTCPServer() - if err != nil { - return nil, err - } - - } else { - tcp = &transport.TCPServer{ - ReuseServer: transport.ReuseServer{ - EnableReuse: true, - ConcurrentNumber: runtime.NumCPU(), - }, - } - - var gbAddr *net.TCPAddr - gbAddr, err = net.ResolveTCPAddr("tcp", stream.ListenAddr(stream.AppConfig.GB28181.Port[0])) - if err != nil { - return nil, err - } - - if err = tcp.Bind(gbAddr); err != nil { - return server, err - } - } - - tcp.SetHandler(server) - tcp.Accept() - server.tcp = tcp - server.StreamServer = stream.StreamServer[*TCPSession]{ - SourceType: stream.SourceType28181, - Handler: server, - } - return server, nil -} diff --git a/gb28181/tcp_session.go b/gb28181/tcp_session.go deleted file mode 100644 index f8bb7e3..0000000 --- a/gb28181/tcp_session.go +++ /dev/null @@ -1,88 +0,0 @@ -package gb28181 - -import ( - "fmt" - "github.com/lkmio/lkm/stream" - "github.com/lkmio/transport" - "github.com/pion/rtp" - "net" -) - -// TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包. -type TCPSession struct { - conn net.Conn - source GBSource - decoder *transport.LengthFieldFrameDecoder - receiveBuffer []byte -} - -func (t *TCPSession) Init(source GBSource) { - t.source = source -} - -func (t *TCPSession) Close() { - t.conn = nil - if t.source != nil { - t.source.Close() - t.source = nil - } - - stream.TCPReceiveBufferPool.Put(t.receiveBuffer[:cap(t.receiveBuffer)]) -} - -func (t *TCPSession) DecodeGBRTPOverTCPPacket(data []byte, filter Filter, conn net.Conn) error { - length := len(data) - for i := 0; i < length; { - // 解析粘包数据 - n, bytes, err := t.decoder.Input(data[i:]) - if err != nil { - return err - } - - i += n - if bytes == nil { - break - } - - // 单端口模式,ssrc匹配source - if t.source == nil || stream.SessionStateHandshakeSuccess == t.source.State() { - packet := rtp.Packet{} - if err = packet.Unmarshal(bytes); err != nil { - return err - } else if t.source == nil { - t.source = filter.FindSource(packet.SSRC) - } - - if t.source == nil { - // ssrc 匹配不到Source - return fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC) - } - - if stream.SessionStateHandshakeSuccess == t.source.State() { - t.source.PreparePublish(conn, packet.SSRC, t.source) - } - } - - if err = t.source.ProcessPacket(bytes); err != nil { - return err - } - } - - return nil -} - -func NewTCPSession(conn net.Conn, filter Filter) *TCPSession { - session := &TCPSession{ - conn: conn, - // filter: filter, - decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2), - receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte), - } - - // 多端口收流, Source已知, 直接初始化Session - if stream.AppConfig.GB28181.IsMultiPort() { - session.Init(filter.(*singleFilter).source) - } - - return session -} diff --git a/gb28181/udp_server.go b/gb28181/udp_server.go deleted file mode 100644 index f99f4f0..0000000 --- a/gb28181/udp_server.go +++ /dev/null @@ -1,91 +0,0 @@ -package gb28181 - -import ( - "github.com/lkmio/lkm/log" - "github.com/lkmio/lkm/stream" - "github.com/lkmio/transport" - "github.com/pion/rtp" - "net" - "runtime" -) - -// UDPServer GB28181UDP收流 -type UDPServer struct { - stream.StreamServer[*UDPSource] - udp *transport.UDPServer - filter Filter -} - -func (U *UDPServer) OnNewSession(_ net.Conn) *UDPSource { - return nil -} - -func (U *UDPServer) OnCloseSession(_ *UDPSource) { -} - -func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte { - U.StreamServer.OnPacket(conn, data) - - packet := rtp.Packet{} - err := packet.Unmarshal(data) - if err != nil { - log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String()) - return nil - } - - source := U.filter.FindSource(packet.SSRC) - if source == nil { - log.Sugar.Errorf("ssrc匹配source失败 ssrc:%x conn:%s", packet.SSRC, conn.RemoteAddr().String()) - return nil - } - - if stream.SessionStateHandshakeSuccess == source.State() { - conn.(*transport.Conn).Data = source - source.PreparePublish(conn, packet.SSRC, source) - } - - packet.Raw = data - _ = source.(*UDPSource).InputRtpPacket(&packet) - return nil -} - -func NewUDPServer(filter Filter) (*UDPServer, error) { - server := &UDPServer{ - filter: filter, - } - - var udp *transport.UDPServer - var err error - if stream.AppConfig.GB28181.IsMultiPort() { - udp, err = TransportManger.NewUDPServer() - if err != nil { - return nil, err - } - } else { - udp = &transport.UDPServer{ - ReuseServer: transport.ReuseServer{ - EnableReuse: true, - ConcurrentNumber: runtime.NumCPU(), - }, - } - - var gbAddr *net.UDPAddr - gbAddr, err = net.ResolveUDPAddr("udp", stream.ListenAddr(stream.AppConfig.GB28181.Port[0])) - if err != nil { - return nil, err - } - - if err = udp.Bind(gbAddr); err != nil { - return server, err - } - } - - udp.SetHandler(server) - udp.Receive() - server.udp = udp - server.StreamServer = stream.StreamServer[*UDPSource]{ - SourceType: stream.SourceType28181, - Handler: server, - } - return server, nil -} diff --git a/jt1078/jt_server.go b/jt1078/jt_server.go index 671ced9..3334a25 100644 --- a/jt1078/jt_server.go +++ b/jt1078/jt_server.go @@ -26,6 +26,7 @@ func (s *jtServer) OnNewSession(conn net.Conn) *Session { func (s *jtServer) OnCloseSession(session *Session) { session.Close() + stream.TCPReceiveBufferPool.Put(session.receiveBuffer[:cap(session.receiveBuffer)]) } func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte { diff --git a/jt1078/jt_session.go b/jt1078/jt_session.go index 3e36709..303fa74 100644 --- a/jt1078/jt_session.go +++ b/jt1078/jt_session.go @@ -1,7 +1,6 @@ package jt1078 import ( - "github.com/lkmio/avformat/utils" "github.com/lkmio/lkm/log" "github.com/lkmio/lkm/stream" "github.com/lkmio/transport" @@ -33,20 +32,10 @@ func (s *Session) Input(data []byte) (int, error) { return -1, err } - // 首包处理, hook通知 + // 首包处理 if firstOfPacket && demuxer.prevPacket != nil { s.SetID(demuxer.sim + "/" + strconv.Itoa(demuxer.channel)) - - go func() { - _, state := stream.PreparePublishSource(s, true) - if utils.HookStateOK != state { - log.Sugar.Errorf("1078推流失败 source: %s", demuxer.sim) - - if s.Conn != nil { - s.Conn.Close() - } - } - }() + stream.PreparePublishSourceWithAsync(s, true) } } @@ -56,13 +45,7 @@ func (s *Session) Input(data []byte) (int, error) { func (s *Session) Close() { log.Sugar.Infof("1078推流结束 %s", s.String()) - if s.Conn != nil { - s.Conn.Close() - s.Conn = nil - } - s.PublishSource.Close() - stream.TCPReceiveBufferPool.Put(s.receiveBuffer[:cap(s.receiveBuffer)]) } func NewSession(conn net.Conn, version int) *Session { diff --git a/jt1078/jt_test.go b/jt1078/jt_test.go index 78794f8..be6d4a8 100644 --- a/jt1078/jt_test.go +++ b/jt1078/jt_test.go @@ -218,9 +218,9 @@ func TestPublish(t *testing.T) { }) t.Run("publish", func(t *testing.T) { - //path := "../../source_files/10352264314-2.bin" + path := "../../source_files/10352264314-2.bin" //path := "../../source_files/013800138000-1.bin" - path := "../../source_files/0714-1.bin" + //path := "../../source_files/0714-1.bin" publish(path, "1078") }) diff --git a/main.go b/main.go index 1417dda..f58221e 100644 --- a/main.go +++ b/main.go @@ -75,11 +75,11 @@ func init() { // 初始化日志 log.InitLogger(config.Log.FileLogging, zapcore.Level(stream.AppConfig.Log.Level), stream.AppConfig.Log.Name, stream.AppConfig.Log.MaxSize, stream.AppConfig.Log.MaxBackup, stream.AppConfig.Log.MaxAge, stream.AppConfig.Log.Compress) - if stream.AppConfig.GB28181.Enable && stream.AppConfig.GB28181.IsMultiPort() { + if stream.AppConfig.GB28181.Enable { gb28181.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.GB28181.Port[0]), uint16(stream.AppConfig.GB28181.Port[1])) } - if stream.AppConfig.Rtsp.Enable && stream.AppConfig.Rtsp.IsMultiPort() { + if stream.AppConfig.Rtsp.Enable { rtsp.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.Rtsp.Port[1]), uint16(stream.AppConfig.Rtsp.Port[2])) } @@ -134,33 +134,7 @@ func main() { log.Sugar.Info("启动http服务 addr:", stream.ListenAddr(stream.AppConfig.Http.Port)) go startApiServer(net.JoinHostPort(stream.AppConfig.ListenIP, strconv.Itoa(stream.AppConfig.Http.Port))) - // 单端口模式下, 启动时就创建收流端口 - // 多端口模式下, 创建GBSource时才创建收流端口 - if stream.AppConfig.GB28181.Enable && !stream.AppConfig.GB28181.IsMultiPort() { - if stream.AppConfig.GB28181.IsEnableUDP() { - filter := gb28181.NewSSRCFilter(128) - server, err := gb28181.NewUDPServer(filter) - if err != nil { - panic(err) - } - - gb28181.SharedUDPServer = server - log.Sugar.Info("启动GB28181 udp收流端口成功:" + stream.ListenAddr(stream.AppConfig.GB28181.Port[0])) - gb28181.SSRCFilters = append(gb28181.SSRCFilters, filter) - } - - if stream.AppConfig.GB28181.IsEnableTCP() { - filter := gb28181.NewSSRCFilter(128) - server, err := gb28181.NewTCPServer(filter) - if err != nil { - panic(err) - } - - gb28181.SharedTCPServer = server - log.Sugar.Info("启动GB28181 tcp收流端口成功:" + stream.ListenAddr(stream.AppConfig.GB28181.Port[0])) - gb28181.SSRCFilters = append(gb28181.SSRCFilters, filter) - } - } + // GB28181收流时调用api创建收流端口 if stream.AppConfig.JT1078.Enable { // 无法通过包头区分2016和2019, 每个版本创建一个Server diff --git a/rtmp/rtmp_session.go b/rtmp/rtmp_session.go index cd2885c..1893118 100644 --- a/rtmp/rtmp_session.go +++ b/rtmp/rtmp_session.go @@ -6,6 +6,7 @@ import ( "github.com/lkmio/lkm/stream" "github.com/lkmio/rtmp" "net" + "strings" ) // Session RTMP会话, 解析处理Message @@ -40,9 +41,17 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState { source.SetUrlValues(values) // 统一处理source推流事件, source是否已经存在, hook回调.... - _, state := stream.PreparePublishSource(source, true) - if utils.HookStateOK != state { - log.Sugar.Errorf("rtmp推流失败 source: %s", sourceId) + state := utils.HookStateOK + _, err := stream.PreparePublishSource(source, true) + if err != nil { + str := err.Error() + log.Sugar.Errorf("rtmp推流失败 source: %s err: %s", sourceId, str) + + if strings.HasSuffix(str, "exist") { + state = utils.HookStateOccupy + } else { + state = utils.HookStateFailure + } } else { s.handle = source s.isPublisher = true @@ -77,7 +86,7 @@ func (s *Session) Input(data []byte) error { s.handle.(*Publisher).UpdateReceiveStats(len(data)) var err error - s.handle.(*Publisher).ExecuteSyncEvent(func() { + s.handle.(*Publisher).ExecuteWithStreamLock(func() { err = s.stack.Input(s.conn, data) }) diff --git a/stream/config.go b/stream/config.go index 2f6eb89..f13a15d 100644 --- a/stream/config.go +++ b/stream/config.go @@ -124,14 +124,6 @@ func (g TransportConfig) IsEnableUDP() bool { return strings.Contains(g.Transport, "UDP") } -func (g GB28181Config) IsMultiPort() bool { - return len(g.Port) > 1 -} - -func (g RtspConfig) IsMultiPort() bool { - return len(g.Port) == 3 -} - // M3U8Path 根据sourceId返回m3u8的磁盘路径 // 切片及目录生成规则, 以SourceId为34020000001320000001/34020000001320000001为例: // 创建文件夹34020000001320000001, 34020000001320000001.m3u8文件, 文件列表中切片url为34020000001320000001_seq.ts diff --git a/stream/hook.go b/stream/hook.go index 79a07e4..e2b2c57 100644 --- a/stream/hook.go +++ b/stream/hook.go @@ -59,15 +59,16 @@ func Hook(event HookEvent, params string, body interface{}) (*http.Response, err response, err := SendHookEvent(url, bytes) if err != nil { log.Sugar.Errorf("failed to %s the hook event. err: %s", event.ToString(), err.Error()) + return response, err } else { log.Sugar.Infof("received response for hook %s event: status='%s', response body='%s'", event.ToString(), response.Status, responseBodyToString(response)) } - if err == nil && http.StatusOK != response.StatusCode { - return response, fmt.Errorf("unexpected response status: %s for request %s", response.Status, url) + if http.StatusOK != response.StatusCode { + return response, fmt.Errorf("unexpected response status: %s", response.Status) } - return response, err + return response, nil } func NewHookPlayEventInfo(sink Sink) eventInfo { diff --git a/stream/hook_source.go b/stream/hook_source.go index 59b5c85..191bb69 100644 --- a/stream/hook_source.go +++ b/stream/hook_source.go @@ -2,53 +2,83 @@ package stream import ( "encoding/json" + "fmt" "github.com/lkmio/avformat/utils" "github.com/lkmio/lkm/log" "net/http" "time" ) -func PreparePublishSource(source Source, hook bool) (*http.Response, utils.HookState) { - var response *http.Response - - if err := SourceManager.Add(source); err != nil { - return nil, utils.HookStateOccupy +func AddSource(source Source) error { + err := SourceManager.add(source) + if err == nil { + source.SetState(SessionStateHandshakeSuccess) } - if hook && AppConfig.Hooks.IsEnablePublishEvent() { - rep, state := HookPublishEvent(source) - if utils.HookStateOK != state { + return err +} + +func PreparePublishSource(source Source, add bool) (*http.Response, error) { + var response *http.Response + + if add { + if err := AddSource(source); err != nil { + return nil, err + } + } else if SourceManager.Find(source.GetID()) == nil { + return nil, fmt.Errorf("not found") + } + + if AppConfig.Hooks.IsEnablePublishEvent() { + rep, err := HookPublishEvent(source) + if err != nil { _, _ = SourceManager.Remove(source.GetID()) - return rep, state + return rep, err } response = rep } + // 此时才认为source推流成功 + source.SetState(SessionStateTransferring) source.SetCreateTime(time.Now()) urls := GetStreamPlayUrls(source.GetID()) indent, _ := json.MarshalIndent(urls, "", "\t") - log.Sugar.Infof("%s准备推流 source:%s 拉流地址:\r\n%s", source.GetType().String(), source.GetID(), indent) + log.Sugar.Infof("%s推流 source: %s 拉流地址:\r\n%s", source.GetType().String(), source.GetID(), indent) - source.SetState(SessionStateTransferring) - return response, utils.HookStateOK + return response, nil } -func HookPublishEvent(source Source) (*http.Response, utils.HookState) { - var response *http.Response +func PreparePublishSourceWithAsync(source Source, add bool) { + go func() { + var err error + // 加锁执行, 保证并发安全 + source.ExecuteWithDeleteLock(func() { + if source.IsClosed() { + err = fmt.Errorf("source is closed") + } else if _, err = PreparePublishSource(source, add); err == nil { + } + }) - if AppConfig.Hooks.IsEnablePublishEvent() { - hook, err := Hook(HookEventPublish, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) if err != nil { - return hook, utils.HookStateFailure - } + log.Sugar.Errorf("GB28181推流失败 err: %s source: %s", err.Error(), source.GetID()) - response = hook + if !source.IsClosed() { + source.Close() + } + } + }() + +} + +func HookPublishEvent(source Source) (*http.Response, error) { + if AppConfig.Hooks.IsEnablePublishEvent() { + return Hook(HookEventPublish, source.UrlValues().Encode(), NewHookPublishEventInfo(source)) } - return response, utils.HookStateOK + return nil, nil } func HookPublishDoneEvent(source Source) { diff --git a/stream/mw_buffer.go b/stream/mw_buffer.go index a10169b..1e8923e 100644 --- a/stream/mw_buffer.go +++ b/stream/mw_buffer.go @@ -33,8 +33,8 @@ type MergeWritingBuffer interface { } type mbBuffer struct { - buffer collections.BlockBuffer - segments *collections.Queue[*collections.ReferenceCounter[[]byte]] + buffer collections.BlockBuffer // 合并写内存缓冲区 + segments *collections.Queue[*collections.ReferenceCounter[[]byte]] // 包含多个合并写切片 } type mergeWritingBuffer struct { @@ -56,13 +56,15 @@ func (m *mergeWritingBuffer) TryAlloc(size int, ts int64, videoPkt, videoKey boo buffer := m.buffers.Peek(m.buffers.Size() - 1).buffer bytes := buffer.AvailableBytes() + // 内存不足, 分配新的内存缓冲区 if bytes < size { - // 非完整切片,先保存切片再分配新的内存 + // 让外部先flush, 再分配新的内存 if buffer.PendingBlockSize() > 0 { return nil, false } - // -1, 当前内存池不释放 + // 释放未使用的内存缓冲区 + // -1, 最新的内存缓冲区不释放 release(m.buffers, m.buffers.Size()-1) m.buffers.Push(MWBufferPool.Get().(*mbBuffer)) } @@ -116,6 +118,7 @@ func (m *mergeWritingBuffer) FlushSegment() (*collections.ReferenceCounter[[]byt } if AppConfig.GOPCache { + // +1=2 counter.Refer() m.lastKeyVideoDataSegments.Push(counter) } @@ -172,11 +175,13 @@ func (m *mergeWritingBuffer) HasVideoDataInCurrentSegment() bool { } func (m *mergeWritingBuffer) Close() *collections.Queue[*mbBuffer] { + // 减少关键帧切片的引用计数 for m.lastKeyVideoDataSegments.Size() > 0 { m.lastKeyVideoDataSegments.Pop().Release() } if m.buffers.Size() > 0 && !release(m.buffers, m.buffers.Size()) { + // 还有sink在使用, 返回未释放的内存缓冲区 return m.buffers } diff --git a/stream/mwb_pool.go b/stream/mwb_pool.go index 18897bf..e391100 100644 --- a/stream/mwb_pool.go +++ b/stream/mwb_pool.go @@ -8,6 +8,8 @@ import ( ) const ( + // BlockBufferSize 合并写缓冲区的内存块大小 + // 一块缓冲区可以包含多个合并写切片 BlockBufferSize = 1024 * 1024 * 2 ) @@ -23,37 +25,40 @@ var ( }, } - pendingReleaseBuffers = make(map[string]*collections.Queue[*mbBuffer]) + pendingReleaseBuffers = make(map[string]*collections.Queue[*mbBuffer]) // 等待释放的合并写缓冲区 lock sync.Mutex ) +// AddMWBuffersToPending 添加合并写缓冲区到等待释放队列 func AddMWBuffersToPending(sourceId string, transStreamId TransStreamID, buffers *collections.Queue[*mbBuffer]) { key := fmt.Sprintf("%s-%d", sourceId, transStreamId) lock.Lock() defer lock.Unlock() - for buffers.Size() > 0 { - v, ok := pendingReleaseBuffers[key] - if ok { - // 第二次都推流结束了,第一次的内存还被占用 - // 强制释放上次推流的内存池 - log.Sugar.Warnf("force release last pending buffers of %s", key) + v, ok := pendingReleaseBuffers[key] + if ok { + // 第二次都推流结束了,第一次的内存还被占用 + // 强制释放上次推流的内存池 + log.Sugar.Warnf("force release last pending buffers of %s", key) - for v.Size() > 0 { - pop := v.Pop() - pop.buffer.Clear() - pop.segments.Clear() - MWBufferPool.Put(pop) - } - - delete(pendingReleaseBuffers, key) + for v.Size() > 0 { + pop := v.Pop() + pop.buffer.Clear() + pop.segments.Clear() + MWBufferPool.Put(pop) } + delete(pendingReleaseBuffers, key) + } + + if buffers.Size() > 0 { pendingReleaseBuffers[key] = buffers } } +// ReleasePendingBuffers 释放等待释放的合并写缓冲区 +// 拉流结束后主动调用一次, 创建传输流的时候也调用一次 func ReleasePendingBuffers(sourceId string, transStreamId TransStreamID) { key := fmt.Sprintf("%s-%d", sourceId, transStreamId) @@ -68,6 +73,7 @@ func ReleasePendingBuffers(sourceId string, transStreamId TransStreamID) { delete(pendingReleaseBuffers, key) } +// release 释放合并写缓冲区 func release(buffers *collections.Queue[*mbBuffer], length int) bool { for i := 0; i < length; i++ { buffer := buffers.Peek(0) diff --git a/stream/sink.go b/stream/sink.go index f5bd952..fb73566 100644 --- a/stream/sink.go +++ b/stream/sink.go @@ -175,7 +175,7 @@ func (s *BaseSink) fastForward(firstSegment *collections.ReferenceCounter[[]byte func (s *BaseSink) doAsyncWrite() { defer func() { - // 释放未发送的数据 + // 释放未发送的合并写切片 for buffer := s.pendingSendQueue.Pop(); buffer != nil; buffer = s.pendingSendQueue.Pop() { buffer.Release() } @@ -241,9 +241,12 @@ func (s *BaseSink) doAsyncWrite() { func (s *BaseSink) EnableAsyncWriteMode(queueSize int) { utils.Assert(s.Conn != nil) - s.pendingSendQueue = NewNonBlockingChannel[*collections.ReferenceCounter[[]byte]](queueSize) - s.cancelCtx, s.cancelFunc = context.WithCancel(context.Background()) - go s.doAsyncWrite() + // 只初始化一次 + if s.pendingSendQueue == nil { + s.pendingSendQueue = NewNonBlockingChannel[*collections.ReferenceCounter[[]byte]](queueSize) + s.cancelCtx, s.cancelFunc = context.WithCancel(context.Background()) + go s.doAsyncWrite() + } } func (s *BaseSink) Write(index int, data []*collections.ReferenceCounter[[]byte], ts int64, keyVideo bool) error { diff --git a/stream/source.go b/stream/source.go index 825d750..522fbc2 100644 --- a/stream/source.go +++ b/stream/source.go @@ -80,7 +80,9 @@ type Source interface { StartTimers(source Source) - ExecuteSyncEvent(cb func()) + ExecuteWithStreamLock(cb func()) + + ExecuteWithDeleteLock(cb func()) UpdateReceiveStats(dataLen int) } @@ -105,7 +107,8 @@ type PublishSource struct { createTime time.Time // source创建时间 statistics *BitrateStatistics // 码流统计 streamLogger avformat.OnUnpackStream2FileHandler - streamLock sync.Mutex // 收流、探测超时、关闭等操作互斥锁 + streamLock sync.Mutex // 收流、探测超时等操作互斥锁 + deleteLock sync.Mutex // 双重锁, 防止在关闭source时, 其他操作同时进行 timers struct { receiveTimer *time.Timer // 收流超时计时器 @@ -157,10 +160,8 @@ func (s *PublishSource) Input(data []byte) (int, error) { s.UpdateReceiveStats(len(data)) var n int var err error - s.ExecuteSyncEvent(func() { - if s.closed.Load() { - err = fmt.Errorf("source closed") - } else { + s.ExecuteWithStreamLock(func() { + if !s.closed.Load() { n, err = s.TransDemuxer.Input(data) } }) @@ -176,7 +177,7 @@ func (s *PublishSource) SetState(state SessionState) { s.state = state } -func (s *PublishSource) DoClose() { +func (s *PublishSource) doClose() { log.Sugar.Debugf("closing the %s source. id: %s. closed flag: %t", s.Type, s.ID, s.closed.Load()) // 已关闭, 直接返回 @@ -185,7 +186,7 @@ func (s *PublishSource) DoClose() { } var closed bool - s.ExecuteSyncEvent(func() { + s.ExecuteWithStreamLock(func() { closed = s.closed.Swap(true) }) @@ -221,20 +222,16 @@ func (s *PublishSource) DoClose() { // 同步执行 s.streamPublisher.close() - // 只释放prepare成功的source, 否则在关闭失败的source时, 造成id相同的source被错误释放 - if s.state < SessionStateTransferring { - return - } - s.state = SessionStateClosed - // 释放解复用器 - // 释放转码器 - // 释放每路转协议流, 将所有sink添加到等待队列 - _, err := SourceManager.Remove(s.ID) - if err != nil { - // source不存在, 在创建source时, 未添加到manager中, 目前只有1078流会出现这种情况(tcp连接到端口, 没有推流或推流数据无效, 无法定位到手机号, 以至于无法执行PreparePublishSource函数), 将不再处理后续事情. - log.Sugar.Errorf("删除源失败 source: %s err: %s", s.ID, err.Error()) - return + + // 只删除被添加的source, 否则会造成id相同的source被误删 + if s.state >= SessionStateHandshakeSuccess { + _, err := SourceManager.Remove(s.ID) + if err != nil { + // source不存在, 在创建source时, 未添加到manager中, 目前只有1078流会出现这种情况(tcp连接到端口, 没有推流或推流数据无效, 无法定位到手机号, 以至于无法执行PreparePublishSource函数), 将不再处理后续事情. + log.Sugar.Errorf("删除源失败 source: %s err: %s", s.ID, err.Error()) + return + } } // 异步hook @@ -249,7 +246,9 @@ func (s *PublishSource) DoClose() { } func (s *PublishSource) Close() { - s.DoClose() + s.ExecuteWithDeleteLock(func() { + s.doClose() + }) } // 解析完所有track后, 创建各种输出流 @@ -265,8 +264,8 @@ func (s *PublishSource) writeHeader() { if len(s.originTracks.All()) == 0 { log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID) - // 异步执行ProbeTimeout函数中还没释放锁 - go s.DoClose() + // 此时还持有stream lock, 异步关闭source + go CloseSource(s.ID) return } } @@ -399,7 +398,7 @@ func (s *PublishSource) SetUrlValues(values url.Values) { s.urlValues = values } -func (s *PublishSource) ExecuteSyncEvent(cb func()) { +func (s *PublishSource) ExecuteWithStreamLock(cb func()) { // 无竞争情况下, 接近原子操作 s.streamLock.Lock() defer s.streamLock.Unlock() @@ -420,7 +419,7 @@ func (s *PublishSource) GetBitrateStatistics() *BitrateStatistics { func (s *PublishSource) ProbeTimeout() { if s.TransDemuxer != nil { - s.ExecuteSyncEvent(func() { + s.ExecuteWithStreamLock(func() { if !s.closed.Load() { s.TransDemuxer.ProbeComplete() } @@ -454,3 +453,9 @@ func (s *PublishSource) StartTimers(source Source) { }) } + +func (s *PublishSource) ExecuteWithDeleteLock(cb func()) { + s.deleteLock.Lock() + defer s.deleteLock.Unlock() + cb() +} diff --git a/stream/source_manager.go b/stream/source_manager.go index 8f240e4..5532f09 100644 --- a/stream/source_manager.go +++ b/stream/source_manager.go @@ -16,7 +16,7 @@ type sourceManger struct { m sync.Map } -func (s *sourceManger) Add(source Source) error { +func (s *sourceManger) add(source Source) error { _, ok := s.m.LoadOrStore(source.GetID(), source) if ok { return fmt.Errorf("the source %s has been exist", source.GetID()) diff --git a/stream/source_utils.go b/stream/source_utils.go index 08fa62d..ec5f44b 100644 --- a/stream/source_utils.go +++ b/stream/source_utils.go @@ -195,6 +195,13 @@ func StartIdleTimer(source Source) *time.Timer { return idleTimer } +func CloseSource(id string) { + source := SourceManager.Find(id) + if source != nil { + source.Close() + } +} + // LoopEvent 循环读取事件 func LoopEvent(source Source) { source.StartTimers(source) diff --git a/stream/stream_publisher.go b/stream/stream_publisher.go index 154ba93..93a8cf5 100644 --- a/stream/stream_publisher.go +++ b/stream/stream_publisher.go @@ -341,6 +341,9 @@ func (t *transStreamPublisher) CreateTransStream(protocol TransStreamProtocol, t } } + // 尝试清空等待释放的合并写缓冲区 + ReleasePendingBuffers(t.source, id) + t.transStreams[id] = transStream // 创建输出流对应的拉流队列 t.transStreamSinks[id] = make(map[SinkID]Sink, 128) @@ -700,7 +703,6 @@ func (t *transStreamPublisher) doClose() { // 将所有sink添加到等待队列 for _, sink := range t.sinks { transStreamID := sink.GetTransStreamID() - sink.SetTransStreamID(0) if t.recordSink == sink { continue } diff --git a/stream/stream_server.go b/stream/stream_server.go index 6d74f5b..9fe6901 100644 --- a/stream/stream_server.go +++ b/stream/stream_server.go @@ -18,8 +18,10 @@ type StreamServer[T any] struct { } func (s *StreamServer[T]) OnConnected(conn net.Conn) []byte { - log.Sugar.Debugf("%s连接 conn:%s", s.SourceType.String(), conn.RemoteAddr().String()) - conn.(*transport.Conn).Data = s.Handler.OnNewSession(conn) + log.Sugar.Debugf("%s连接 conn: %s", s.SourceType.String(), conn.RemoteAddr().String()) + if s.Handler != nil { + conn.(*transport.Conn).Data = s.Handler.OnNewSession(conn) + } return nil } @@ -35,7 +37,7 @@ func (s *StreamServer[T]) OnDisConnected(conn net.Conn, err error) { log.Sugar.Debugf("%s断开连接 conn:%s", s.SourceType.String(), conn.RemoteAddr().String()) t := conn.(*transport.Conn) - if t.Data != nil { + if s.Handler != nil && t.Data != nil { s.Handler.OnCloseSession(t.Data.(T)) t.Data = nil } diff --git a/stream/trans_utils.go b/stream/trans_utils.go index c11e313..9f75fb0 100644 --- a/stream/trans_utils.go +++ b/stream/trans_utils.go @@ -2,14 +2,13 @@ package stream import "github.com/lkmio/avformat/utils" -// TransStreamID 每个传输流的唯一Id,根据输出流协议ID+track index生成 -// 输出流协议占低8位 -// 每个音视频编译器ID占用8位. 意味着每个输出流至多7路流. +// TransStreamID 每个传输流的唯一Id, 根据输出流协议ID+track index生成 +// 输出流协议占低8位, track index占用8位, 最多支持7路流. type TransStreamID uint64 func (id TransStreamID) HasTrack(index int) bool { for i := 1; i < 8; i++ { - if int(id>>(i*8))&0xFF == index { + if (int(id>>(i*8))&0xFF)-1 == index { return true } } @@ -21,25 +20,6 @@ func (id TransStreamID) Protocol() TransStreamProtocol { return TransStreamProtocol(id & 0xFF) } -// GenerateTransStreamID 根据传入的推拉流协议和编码器ID生成StreamId -// 请确保ids根据值升序排序传参 -/*func GenerateTransStreamID(protocol GetProtocol, ids ...utils.AVCodecID) GetTransStreamID { - len_ := len(ids) - utils.Assert(len_ > 0 && len_ < 8) - - var streamId uint64 - streamId = uint64(protocol) << 56 - - for i, GetID := range ids { - bId, ok := narrowCodecIds[int(GetID)] - utils.Assert(ok) - - streamId |= uint64(bId) << (48 - i*8) - } - - return GetTransStreamID(streamId) -}*/ - // GenerateTransStreamID 根据输出流协议和输出流包含的音视频编码器ID生成流ID func GenerateTransStreamID(protocol TransStreamProtocol, tracks ...*Track) TransStreamID { len_ := len(tracks) @@ -47,7 +27,8 @@ func GenerateTransStreamID(protocol TransStreamProtocol, tracks ...*Track) Trans var streamId = uint64(protocol) & 0xFF for i, track := range tracks { - streamId |= uint64(track.Stream.Index) << ((i + 1) * 8) + // +1是为了避免0值 + streamId |= uint64(track.Stream.Index+1) << ((i + 1) * 8) } return TransStreamID(streamId)