refactor: gb28181仅支持多端口推流, 提升代码健壮性

This commit is contained in:
ydajiang
2025-08-08 17:14:33 +08:00
parent cac5e91471
commit ca52588bae
31 changed files with 415 additions and 684 deletions

View File

@@ -3,7 +3,6 @@ package main
import ( import (
"fmt" "fmt"
"github.com/lkmio/avformat/bufio" "github.com/lkmio/avformat/bufio"
"github.com/lkmio/avformat/utils"
"github.com/lkmio/lkm/gb28181" "github.com/lkmio/lkm/gb28181"
"github.com/lkmio/lkm/log" "github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream" "github.com/lkmio/lkm/stream"
@@ -74,9 +73,7 @@ func (api *ApiServer) OnGBSourceCreate(v *SourceSDP, w http.ResponseWriter, r *h
} }
if tcp && active { if tcp && active {
if !stream.AppConfig.GB28181.IsMultiPort() { if !tcp {
err = fmt.Errorf("单端口模式下不能主动拉流")
} else if !tcp {
err = fmt.Errorf("UDP不能主动拉流") err = fmt.Errorf("UDP不能主动拉流")
} else if !stream.AppConfig.GB28181.IsEnableTCP() { } else if !stream.AppConfig.GB28181.IsEnableTCP() {
err = fmt.Errorf("未开启TCP收流服务,UDP不能主动拉流") err = fmt.Errorf("未开启TCP收流服务,UDP不能主动拉流")
@@ -218,9 +215,9 @@ func (api *ApiServer) OnGBTalk(w http.ResponseWriter, r *http.Request) {
talkSource.Init() talkSource.Init()
talkSource.SetUrlValues(r.Form) talkSource.SetUrlValues(r.Form)
_, state := stream.PreparePublishSource(talkSource, true) _, err = stream.PreparePublishSource(talkSource, true)
if utils.HookStateOK != state { if err != nil {
log.Sugar.Errorf("对讲失败, source: %s", talkSource) log.Sugar.Errorf("对讲失败, err: %s source: %s", err, talkSource)
conn.Close() conn.Close()
return return
} }

View File

@@ -1,10 +0,0 @@
package gb28181
// Filter 关联Source
type Filter interface {
AddSource(ssrc uint32, source GBSource) bool
RemoveSource(ssrc uint32)
FindSource(ssrc uint32) GBSource
}

View File

@@ -1,21 +0,0 @@
package gb28181
type singleFilter struct {
source GBSource
}
func (s *singleFilter) AddSource(ssrc uint32, source GBSource) bool {
panic("implement me")
}
func (s *singleFilter) RemoveSource(ssrc uint32) {
s.source = nil
}
func (s *singleFilter) FindSource(ssrc uint32) GBSource {
return s.source
}
func NewSingleFilter(source GBSource) Filter {
return &singleFilter{source: source}
}

View File

@@ -1,38 +0,0 @@
package gb28181
import (
"sync"
)
type ssrcFilter struct {
sources map[uint32]GBSource
mute sync.RWMutex
}
func (r *ssrcFilter) AddSource(ssrc uint32, source GBSource) bool {
r.mute.Lock()
defer r.mute.Unlock()
if _, ok := r.sources[ssrc]; !ok {
r.sources[ssrc] = source
return true
}
return false
}
func (r *ssrcFilter) RemoveSource(ssrc uint32) {
r.mute.Lock()
defer r.mute.Unlock()
delete(r.sources, ssrc)
}
func (r *ssrcFilter) FindSource(ssrc uint32) GBSource {
r.mute.RLock()
defer r.mute.RUnlock()
return r.sources[ssrc]
}
func NewSSRCFilter(guestCount int) Filter {
return &ssrcFilter{sources: make(map[uint32]GBSource, guestCount)}
}

View File

@@ -24,7 +24,7 @@ import (
func connectSource(source string, addr string) { func connectSource(source string, addr string) {
v := &struct { v := &struct {
Source string `json:"source"` //GetSourceID Source string `json:"source"` //GetSourceID
RemoteAddr string `json:"remote_addr"` RemoteAddr string `json:"addr"`
}{ }{
Source: source, Source: source,
RemoteAddr: addr, RemoteAddr: addr,
@@ -35,7 +35,7 @@ func connectSource(source string, addr string) {
panic(err) panic(err)
} }
request, err := http.NewRequest("POST", "http://localhost:8080/v1/gb28181/source/connect", bytes.NewBuffer(marshal)) request, err := http.NewRequest("POST", "http://localhost:8080/api/v1/gb28181/answer/set", bytes.NewBuffer(marshal))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -209,11 +209,12 @@ func TestPublish(t *testing.T) {
//path := "../../source_files/rtp_ps_h264_G7221_0xBEBC204.raw" //path := "../../source_files/rtp_ps_h264_G7221_0xBEBC204.raw"
//var rawSsrc uint32 = 0xBEBC204 //var rawSsrc uint32 = 0xBEBC204
path := "../../source_files/rtp_ps_h264_G726_0xBEBC205.raw" //path := "../../source_files/rtp_ps_h264_G726_0xBEBC205.raw"
path := "../../source_files/rtp_ps_err_parse.raw"
var rawSsrc uint32 = 0xBEBC205 var rawSsrc uint32 = 0xBEBC205
localAddr := "0.0.0.0:20001" localAddr := "0.0.0.0:20001"
id := "hls_mystream" id := "hls/mystream"
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
@@ -305,7 +306,7 @@ func TestPublish(t *testing.T) {
}) })
t.Run("active", func(t *testing.T) { t.Run("active", func(t *testing.T) {
ip, port, ssrc := createSource(id, "active", rawSsrc) _, _, ssrc := createSource(id, "active", rawSsrc)
addr, _ := net.ResolveTCPAddr("tcp", localAddr) addr, _ := net.ResolveTCPAddr("tcp", localAddr)
server := transport.TCPServer{} server := transport.TCPServer{}
@@ -317,6 +318,7 @@ func TestPublish(t *testing.T) {
ctrDelay(packet[2:]) ctrDelay(packet[2:])
} }
server.Close()
return nil return nil
}, nil, nil) }, nil, nil)
@@ -325,7 +327,9 @@ func TestPublish(t *testing.T) {
panic(err) panic(err)
} }
connectSource(id, fmt.Sprintf("%s:%d", ip, port)) server.Accept()
connectSource(id, localAddr)
select {}
}) })
} }
@@ -336,10 +340,10 @@ func TestDecode(t *testing.T) {
panic(err2) panic(err2)
} }
source := NewPassiveSource() source := &PassiveSource{
source.Init() decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
filter := NewSingleFilter(source) }
session := NewTCPSession(nil, filter)
reader := bufio.NewBytesReader(file) reader := bufio.NewBytesReader(file)
for { for {
@@ -353,7 +357,7 @@ func TestDecode(t *testing.T) {
break break
} }
err2 = session.DecodeGBRTPOverTCPPacket(bytes, filter, nil) err2 = source.DecodeGBRTPOverTCPPacket(bytes)
if err2 != nil { if err2 != nil {
break break
} }

View File

@@ -11,7 +11,6 @@ import (
"github.com/lkmio/transport" "github.com/lkmio/transport"
"github.com/pion/rtp" "github.com/pion/rtp"
"math" "math"
"net"
"strings" "strings"
) )
@@ -23,7 +22,6 @@ const (
SetupActive = SetupType(2) SetupActive = SetupType(2)
PsProbeBufferSize = 1024 * 1024 * 2 PsProbeBufferSize = 1024 * 1024 * 2
JitterBufferSize = 1024 * 1024
) )
func (s SetupType) TransportType() stream.TransportType { func (s SetupType) TransportType() stream.TransportType {
@@ -65,8 +63,6 @@ func SetupTypeFromString(setupType string) SetupType {
var ( var (
TransportManger transport.Manager TransportManger transport.Manager
SharedUDPServer *UDPServer
SharedTCPServer *TCPServer
) )
// GBSource GB28181推流Source, 统一解析PS流、级联转发. // GBSource GB28181推流Source, 统一解析PS流、级联转发.
@@ -75,23 +71,20 @@ type GBSource interface {
SetupType() SetupType SetupType() SetupType
// PreparePublish 收到流时, 做一些初始化工作.
PreparePublish(conn net.Conn, ssrc uint32, source GBSource)
SetConn(conn net.Conn)
SetSSRC(ssrc uint32) SetSSRC(ssrc uint32)
SSRC() uint32 SSRC() uint32
ProcessPacket(data []byte) error ProcessPacket(data []byte) error
SetTransport(transport transport.Transport)
} }
type BaseGBSource struct { type BaseGBSource struct {
stream.PublishSource stream.PublishSource
transport transport.Transport
probeBuffer *mpeg.PSProbeBuffer probeBuffer *mpeg.PSProbeBuffer
transport transport.Transport
ssrc uint32 ssrc uint32
audioTimestamp int64 audioTimestamp int64
@@ -103,23 +96,16 @@ type BaseGBSource struct {
sameTimePackets [][]byte sameTimePackets [][]byte
} }
func (source *BaseGBSource) Init() {
source.TransDemuxer = mpeg.NewPSDemuxer(false)
source.TransDemuxer.SetHandler(source)
source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) {
source.correctTimestamp(packet, packet.Dts, packet.Pts)
})
source.SetType(stream.SourceType28181)
source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize)
source.PublishSource.Init()
source.lastRtpTimestamp = -1
}
// ProcessPacket 输入rtp包, 处理PS流, 负责解析->封装->推流 // ProcessPacket 输入rtp包, 处理PS流, 负责解析->封装->推流
func (source *BaseGBSource) ProcessPacket(data []byte) error { func (source *BaseGBSource) ProcessPacket(data []byte) error {
packet := rtp.Packet{} packet := rtp.Packet{}
_ = packet.Unmarshal(data) _ = packet.Unmarshal(data)
// 收到第一包, 初始化
if source.probeBuffer == nil {
source.InitializePublish(packet.SSRC)
}
// 国标级联转发 // 国标级联转发
if source.GetTransStreamPublisher().GetForwardTransStream() != nil { if source.GetTransStreamPublisher().GetForwardTransStream() != nil {
if source.lastRtpTimestamp == -1 { if source.lastRtpTimestamp == -1 {
@@ -228,30 +214,17 @@ func (source *BaseGBSource) correctTimestamp(packet *avformat.AVPacket, dts, pts
} }
func (source *BaseGBSource) Close() { func (source *BaseGBSource) Close() {
log.Sugar.Infof("GB28181推流结束 ssrc:%d %s", source.ssrc, source.PublishSource.String()) log.Sugar.Infof("GB28181推流结束 ssrc: %d %s", source.ssrc, source.PublishSource.String())
// 释放收流端口 source.PublishSource.Close()
// 加锁执行, 保证并发安全
source.ExecuteWithDeleteLock(func() {
if source.transport != nil { if source.transport != nil {
source.transport.Close() source.transport.Close()
source.transport = nil source.transport = nil
} }
})
// 删除ssrc关联
if !stream.AppConfig.GB28181.IsMultiPort() {
if SharedTCPServer != nil {
SharedTCPServer.filter.RemoveSource(source.ssrc)
}
if SharedUDPServer != nil {
SharedUDPServer.filter.RemoveSource(source.ssrc)
}
}
source.PublishSource.Close()
}
func (source *BaseGBSource) SetConn(conn net.Conn) {
source.Conn = conn
} }
func (source *BaseGBSource) SetSSRC(ssrc uint32) { func (source *BaseGBSource) SetSSRC(ssrc uint32) {
@@ -262,27 +235,43 @@ func (source *BaseGBSource) SSRC() uint32 {
return source.ssrc return source.ssrc
} }
func (source *BaseGBSource) PreparePublish(conn net.Conn, ssrc uint32, source_ GBSource) { func (source *BaseGBSource) InitializePublish(ssrc uint32) {
source.SetConn(conn) if source.ssrc != ssrc {
source.SetSSRC(ssrc) log.Sugar.Warnf("创建source的ssrc与实际推流的ssrc不一致, 创建的ssrc: %x 实际推流的ssrc: %x source: %s", source.ssrc, ssrc, source.GetID())
source.SetState(stream.SessionStateTransferring) }
// 初始化ps解复用器
source.TransDemuxer.SetOnPreprocessPacketHandler(func(packet *avformat.AVPacket) {
source.correctTimestamp(packet, packet.Dts, packet.Pts)
})
source.probeBuffer = mpeg.NewProbeBuffer(PsProbeBufferSize)
source.lastRtpTimestamp = -1
source.ssrc = ssrc
source.audioTimestamp = -1 source.audioTimestamp = -1
source.videoTimestamp = -1 source.videoTimestamp = -1
source.audioPacketCreatedTime = -1 source.audioPacketCreatedTime = -1
source.videoPacketCreatedTime = -1 source.videoPacketCreatedTime = -1
if stream.AppConfig.Hooks.IsEnablePublishEvent() { p := stream.SourceManager.Find(source.GetID())
go func() { if p == nil {
if _, state := stream.HookPublishEvent(source_); utils.HookStateOK == state { log.Sugar.Errorf("GB28181推流失败, 未找到source: %s", source.GetID())
source.Close()
return return
} }
log.Sugar.Errorf("GB28181 推流失败 source:%s", source.GetID()) stream.PreparePublishSourceWithAsync(p, false)
if conn != nil { }
conn.Close()
} func (source *BaseGBSource) Init() {
}() // 创建ps解复用器
} source.TransDemuxer = mpeg.NewPSDemuxer(false)
source.TransDemuxer.SetHandler(source)
source.PublishSource.Init()
}
func (source *BaseGBSource) SetTransport(transport transport.Transport) {
source.transport = transport
} }
// NewGBSource 创建国标推流源, 返回监听的收流端口 // NewGBSource 创建国标推流源, 返回监听的收流端口
@@ -294,9 +283,10 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int,
} }
if active { if active {
utils.Assert(tcp && stream.AppConfig.GB28181.IsEnableTCP() && stream.AppConfig.GB28181.IsMultiPort()) utils.Assert(tcp && stream.AppConfig.GB28181.IsEnableTCP())
} }
var transportServer transport.Transport
var source GBSource var source GBSource
var port int var port int
var err error var err error
@@ -304,55 +294,46 @@ func NewGBSource(id string, ssrc uint32, tcp bool, active bool) (GBSource, int,
if active { if active {
source, port, err = NewActiveSource() source, port, err = NewActiveSource()
} else if tcp { } else if tcp {
transportServer, err = TransportManger.NewTCPServer()
if err != nil {
return nil, 0, err
}
source = NewPassiveSource() source = NewPassiveSource()
transportServer.(*transport.TCPServer).SetHandler(source.(*PassiveSource))
transportServer.(*transport.TCPServer).Accept()
port = transportServer.ListenPort()
} else { } else {
transportServer, err = TransportManger.NewUDPServer()
if err != nil {
return nil, 0, err
}
source = NewUDPSource() source = NewUDPSource()
transportServer.(*transport.UDPServer).SetHandler(source.(*UDPSource))
transportServer.(*transport.UDPServer).Receive()
port = transportServer.ListenPort()
} }
if err != nil { source.SetType(stream.SourceType28181)
return nil, 0, err
}
// 单端口模式绑定ssrc
if !stream.AppConfig.GB28181.IsMultiPort() {
var success bool
if tcp {
success = SharedTCPServer.filter.AddSource(ssrc, source)
} else {
success = SharedUDPServer.filter.AddSource(ssrc, source)
}
if !success {
return nil, 0, fmt.Errorf("ssrc conflict")
}
port = stream.AppConfig.GB28181.Port[0]
} else if !active {
// 多端口模式, 创建收流Server
if tcp {
tcpServer, err := NewTCPServer(NewSingleFilter(source))
if err != nil {
return nil, 0, err
}
port = tcpServer.tcp.ListenPort()
source.(*PassiveSource).transport = tcpServer.tcp
} else {
server, err := NewUDPServer(NewSingleFilter(source))
if err != nil {
return nil, 0, err
}
port = server.udp.ListenPort()
source.(*UDPSource).transport = server.udp
}
}
source.SetID(id) source.SetID(id)
source.SetSSRC(ssrc) source.SetSSRC(ssrc)
// 加锁保护一下, 防止初始化阶段, 调用关闭source接口, 发生并发安全问题
source.ExecuteWithDeleteLock(func() {
if err = stream.AddSource(source); err != nil {
return
}
source.SetTransport(transportServer)
source.Init() source.Init()
if _, state := stream.PreparePublishSource(source, false); utils.HookStateOK != state { })
return nil, 0, fmt.Errorf("error code %d", state)
// id冲突
if err != nil {
if transportServer != nil {
transportServer.Close()
}
return nil, 0, err
} }
stream.LoopEvent(source) stream.LoopEvent(source)

View File

@@ -1,24 +1,30 @@
package gb28181 package gb28181
import ( import (
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"net" "net"
) )
type ActiveSource struct { type ActiveSource struct {
PassiveSource *PassiveSource
port int port int
remoteAddr net.TCPAddr remoteAddr net.TCPAddr
tcp *TCPClient
} }
func (a *ActiveSource) Connect(remoteAddr *net.TCPAddr) error { func (a *ActiveSource) Connect(remoteAddr *net.TCPAddr) error {
client, err := NewTCPClient(a.port, remoteAddr, a) client := &transport.TCPClient{}
client.SetHandler(a.PassiveSource)
addr, err := net.ResolveTCPAddr("tcp", stream.ListenAddr(a.port))
if err != nil { if err != nil {
return err return err
} else if _, err = client.Connect(addr, remoteAddr); err != nil {
return err
} }
a.tcp = client go client.Receive()
a.transport = client
return nil return nil
} }
@@ -28,12 +34,23 @@ func (a *ActiveSource) SetupType() SetupType {
func NewActiveSource() (*ActiveSource, int, error) { func NewActiveSource() (*ActiveSource, int, error) {
var port int var port int
TransportManger.AllocPort(true, func(port_ uint16) error { err := TransportManger.AllocPort(true, func(port_ uint16) error {
port = int(port_) port = int(port_)
return nil return nil
}) })
if err != nil {
return nil, 0, err
}
return &ActiveSource{ return &ActiveSource{
PassiveSource: &PassiveSource{
StreamServer: stream.StreamServer[GBSource]{
SourceType: stream.SourceType28181,
},
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte),
},
port: port, port: port,
}, port, nil }, port, nil
} }

View File

@@ -1,13 +1,105 @@
package gb28181 package gb28181
import (
"encoding/hex"
"github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"net"
)
type PassiveSource struct { type PassiveSource struct {
stream.StreamServer[GBSource]
BaseGBSource BaseGBSource
decoder *transport.LengthFieldFrameDecoder
receiveBuffer []byte
remoteAddr string
} }
func (p *PassiveSource) SetupType() SetupType { func (p *PassiveSource) SetupType() SetupType {
return SetupPassive return SetupPassive
} }
func NewPassiveSource() *PassiveSource { func (p *PassiveSource) Close() {
return &PassiveSource{} p.BaseGBSource.Close()
stream.TCPReceiveBufferPool.Put(p.receiveBuffer[:cap(p.receiveBuffer)])
}
func (p *PassiveSource) DecodeGBRTPOverTCPPacket(data []byte) error {
length := len(data)
for i := 0; i < length; {
// 解析粘包数据
n, bytes, err := p.decoder.Input(data[i:])
if err != nil {
return err
}
i += n
if bytes == nil {
break
}
if err = p.ProcessPacket(bytes); err != nil {
return err
}
}
return nil
}
func (p *PassiveSource) OnConnected(conn net.Conn) []byte {
p.StreamServer.OnConnected(conn)
var ok bool
p.ExecuteWithDeleteLock(func() {
if p.IsClosed() {
log.Sugar.Infof("source %s 已关闭, 拒绝新连接", p.GetID())
} else if ok = p.PublishSource.Conn == nil; ok {
// 一个推流一个端口, 默认第一个连接为有效连接, 关闭其他连接
p.PublishSource.Conn = conn
p.remoteAddr = conn.RemoteAddr().String()
} else {
log.Sugar.Infof("port %d 已连接, 关闭连接. source: %s", p.transport.ListenPort(), p.GetID())
}
})
if !ok {
_ = conn.Close()
return nil
}
return p.receiveBuffer
}
func (p *PassiveSource) OnPacket(conn net.Conn, data []byte) []byte {
p.StreamServer.OnPacket(conn, data)
err := p.DecodeGBRTPOverTCPPacket(data)
if err != nil {
log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data))
_ = conn.Close()
return nil
}
return p.receiveBuffer
}
func (p *PassiveSource) OnDisConnected(conn net.Conn, err error) {
p.StreamServer.OnDisConnected(conn, err)
if conn.RemoteAddr().String() == p.remoteAddr {
p.Close()
}
}
func NewPassiveSource() *PassiveSource {
source := &PassiveSource{
StreamServer: stream.StreamServer[GBSource]{
SourceType: stream.SourceType28181,
},
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte),
}
return source
} }

View File

@@ -1,14 +1,16 @@
package gb28181 package gb28181
import ( import (
"github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream" "github.com/lkmio/lkm/stream"
"github.com/pion/rtp" "github.com/pion/rtp"
"net"
) )
// UDPSource 国标UDP推流源 // UDPSource 国标UDP推流源
type UDPSource struct { type UDPSource struct {
stream.StreamServer[interface{}]
BaseGBSource BaseGBSource
jitterBuffer *stream.JitterBuffer[*rtp.Packet] jitterBuffer *stream.JitterBuffer[*rtp.Packet]
} }
@@ -18,12 +20,12 @@ func (u *UDPSource) SetupType() SetupType {
// OnOrderedRtp 有序RTP包回调 // OnOrderedRtp 有序RTP包回调
func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) { func (u *UDPSource) OnOrderedRtp(packet *rtp.Packet) {
// 此时还在网络收流携程, 交给Source的主协程处理 _ = u.ProcessPacket(packet.Raw)
u.ProcessPacket(packet.Raw) // 处理完后, 归还buffer
stream.UDPReceiveBufferPool.Put(packet.Raw[:cap(packet.Raw)]) stream.UDPReceiveBufferPool.Put(packet.Raw[:cap(packet.Raw)])
} }
// InputRtpPacket 将RTP包排序后交给Source的主协程处理 // InputRtpPacket 将RTP包排序后交给Source处理
func (u *UDPSource) InputRtpPacket(pkt *rtp.Packet) error { func (u *UDPSource) InputRtpPacket(pkt *rtp.Packet) error {
block := stream.UDPReceiveBufferPool.Get().([]byte) block := stream.UDPReceiveBufferPool.Get().([]byte)
copy(block, pkt.Raw) copy(block, pkt.Raw)
@@ -45,8 +47,31 @@ func (u *UDPSource) Close() {
u.BaseGBSource.Close() u.BaseGBSource.Close()
} }
func (u *UDPSource) OnPacket(conn net.Conn, data []byte) []byte {
u.StreamServer.OnPacket(conn, data)
packet := rtp.Packet{}
err := packet.Unmarshal(data)
if err != nil {
log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String())
return nil
} else if u.Conn == nil {
u.Conn = conn
}
packet.Raw = data
_ = u.InputRtpPacket(&packet)
return nil
}
func NewUDPSource() *UDPSource { func NewUDPSource() *UDPSource {
return &UDPSource{ source := &UDPSource{
jitterBuffer: stream.NewJitterBuffer[*rtp.Packet](), jitterBuffer: stream.NewJitterBuffer[*rtp.Packet](),
} }
source.StreamServer = stream.StreamServer[interface{}]{
SourceType: stream.SourceType28181,
}
return source
} }

View File

@@ -2,7 +2,6 @@ package gb28181
import ( import (
"fmt" "fmt"
"strconv"
"sync" "sync"
) )
@@ -13,7 +12,6 @@ const (
var ( var (
ssrcCount uint32 ssrcCount uint32
lock sync.Mutex lock sync.Mutex
SSRCFilters []Filter
) )
func NextSSRC() uint32 { func NextSSRC() uint32 {
@@ -23,19 +21,7 @@ func NextSSRC() uint32 {
return ssrcCount return ssrcCount
} }
func getUniqueSSRC(ssrc string, get func() string) string { func getUniqueSSRC(ssrc string, _ func() string) string {
atoi, err := strconv.Atoi(ssrc)
if err != nil {
panic(err)
}
v := uint32(atoi)
for _, filter := range SSRCFilters {
if filter.FindSource(v) != nil {
ssrc = get()
}
}
return ssrc return ssrc
} }

View File

@@ -1,28 +0,0 @@
package gb28181
import (
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"net"
)
// TCPClient GB28181TCP主动收流
type TCPClient struct {
TCPServer
}
func NewTCPClient(listenPort int, remoteAddr *net.TCPAddr, source GBSource) (*TCPClient, error) {
client := &TCPClient{
TCPServer{filter: NewSingleFilter(source)},
}
tcp := transport.TCPClient{}
tcp.SetHandler(client)
addr, err := net.ResolveTCPAddr("tcp", stream.ListenAddr(listenPort))
if err != nil {
return client, err
}
_, err = tcp.Connect(addr, remoteAddr)
return client, err
}

View File

@@ -1,96 +0,0 @@
package gb28181
import (
"encoding/hex"
"github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"net"
"runtime"
)
// TCPServer GB28181TCP被动收流
type TCPServer struct {
stream.StreamServer[*TCPSession]
tcp *transport.TCPServer
filter Filter
}
func (T *TCPServer) OnNewSession(conn net.Conn) *TCPSession {
return NewTCPSession(conn, T.filter)
}
func (T *TCPServer) OnCloseSession(session *TCPSession) {
session.Close()
if session.source != nil {
T.filter.RemoveSource(session.source.SSRC())
}
if stream.AppConfig.GB28181.IsMultiPort() {
T.tcp.Close()
T.Handler = nil
}
}
func (T *TCPServer) OnConnected(conn net.Conn) []byte {
T.StreamServer.OnConnected(conn)
return conn.(*transport.Conn).Data.(*TCPSession).receiveBuffer
}
func (T *TCPServer) OnPacket(conn net.Conn, data []byte) []byte {
T.StreamServer.OnPacket(conn, data)
session := conn.(*transport.Conn).Data.(*TCPSession)
err := session.DecodeGBRTPOverTCPPacket(data, T.filter, conn)
if err != nil {
log.Sugar.Errorf("解析rtp失败 err: %s conn: %s data: %s", err.Error(), conn.RemoteAddr().String(), hex.EncodeToString(data))
_ = conn.Close()
return nil
}
return session.receiveBuffer
}
func NewTCPServer(filter Filter) (*TCPServer, error) {
server := &TCPServer{
filter: filter,
}
var tcp *transport.TCPServer
var err error
if stream.AppConfig.GB28181.IsMultiPort() {
tcp = &transport.TCPServer{}
tcp, err = TransportManger.NewTCPServer()
if err != nil {
return nil, err
}
} else {
tcp = &transport.TCPServer{
ReuseServer: transport.ReuseServer{
EnableReuse: true,
ConcurrentNumber: runtime.NumCPU(),
},
}
var gbAddr *net.TCPAddr
gbAddr, err = net.ResolveTCPAddr("tcp", stream.ListenAddr(stream.AppConfig.GB28181.Port[0]))
if err != nil {
return nil, err
}
if err = tcp.Bind(gbAddr); err != nil {
return server, err
}
}
tcp.SetHandler(server)
tcp.Accept()
server.tcp = tcp
server.StreamServer = stream.StreamServer[*TCPSession]{
SourceType: stream.SourceType28181,
Handler: server,
}
return server, nil
}

View File

@@ -1,88 +0,0 @@
package gb28181
import (
"fmt"
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"github.com/pion/rtp"
"net"
)
// TCPSession 国标TCP主被动推流Session, 统一处理TCP粘包.
type TCPSession struct {
conn net.Conn
source GBSource
decoder *transport.LengthFieldFrameDecoder
receiveBuffer []byte
}
func (t *TCPSession) Init(source GBSource) {
t.source = source
}
func (t *TCPSession) Close() {
t.conn = nil
if t.source != nil {
t.source.Close()
t.source = nil
}
stream.TCPReceiveBufferPool.Put(t.receiveBuffer[:cap(t.receiveBuffer)])
}
func (t *TCPSession) DecodeGBRTPOverTCPPacket(data []byte, filter Filter, conn net.Conn) error {
length := len(data)
for i := 0; i < length; {
// 解析粘包数据
n, bytes, err := t.decoder.Input(data[i:])
if err != nil {
return err
}
i += n
if bytes == nil {
break
}
// 单端口模式,ssrc匹配source
if t.source == nil || stream.SessionStateHandshakeSuccess == t.source.State() {
packet := rtp.Packet{}
if err = packet.Unmarshal(bytes); err != nil {
return err
} else if t.source == nil {
t.source = filter.FindSource(packet.SSRC)
}
if t.source == nil {
// ssrc 匹配不到Source
return fmt.Errorf("gb28181推流失败 ssrc: %x 匹配不到source", packet.SSRC)
}
if stream.SessionStateHandshakeSuccess == t.source.State() {
t.source.PreparePublish(conn, packet.SSRC, t.source)
}
}
if err = t.source.ProcessPacket(bytes); err != nil {
return err
}
}
return nil
}
func NewTCPSession(conn net.Conn, filter Filter) *TCPSession {
session := &TCPSession{
conn: conn,
// filter: filter,
decoder: transport.NewLengthFieldFrameDecoder(0xFFFF, 2),
receiveBuffer: stream.TCPReceiveBufferPool.Get().([]byte),
}
// 多端口收流, Source已知, 直接初始化Session
if stream.AppConfig.GB28181.IsMultiPort() {
session.Init(filter.(*singleFilter).source)
}
return session
}

View File

@@ -1,91 +0,0 @@
package gb28181
import (
"github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream"
"github.com/lkmio/transport"
"github.com/pion/rtp"
"net"
"runtime"
)
// UDPServer GB28181UDP收流
type UDPServer struct {
stream.StreamServer[*UDPSource]
udp *transport.UDPServer
filter Filter
}
func (U *UDPServer) OnNewSession(_ net.Conn) *UDPSource {
return nil
}
func (U *UDPServer) OnCloseSession(_ *UDPSource) {
}
func (U *UDPServer) OnPacket(conn net.Conn, data []byte) []byte {
U.StreamServer.OnPacket(conn, data)
packet := rtp.Packet{}
err := packet.Unmarshal(data)
if err != nil {
log.Sugar.Errorf("解析rtp失败 err:%s conn:%s", err.Error(), conn.RemoteAddr().String())
return nil
}
source := U.filter.FindSource(packet.SSRC)
if source == nil {
log.Sugar.Errorf("ssrc匹配source失败 ssrc:%x conn:%s", packet.SSRC, conn.RemoteAddr().String())
return nil
}
if stream.SessionStateHandshakeSuccess == source.State() {
conn.(*transport.Conn).Data = source
source.PreparePublish(conn, packet.SSRC, source)
}
packet.Raw = data
_ = source.(*UDPSource).InputRtpPacket(&packet)
return nil
}
func NewUDPServer(filter Filter) (*UDPServer, error) {
server := &UDPServer{
filter: filter,
}
var udp *transport.UDPServer
var err error
if stream.AppConfig.GB28181.IsMultiPort() {
udp, err = TransportManger.NewUDPServer()
if err != nil {
return nil, err
}
} else {
udp = &transport.UDPServer{
ReuseServer: transport.ReuseServer{
EnableReuse: true,
ConcurrentNumber: runtime.NumCPU(),
},
}
var gbAddr *net.UDPAddr
gbAddr, err = net.ResolveUDPAddr("udp", stream.ListenAddr(stream.AppConfig.GB28181.Port[0]))
if err != nil {
return nil, err
}
if err = udp.Bind(gbAddr); err != nil {
return server, err
}
}
udp.SetHandler(server)
udp.Receive()
server.udp = udp
server.StreamServer = stream.StreamServer[*UDPSource]{
SourceType: stream.SourceType28181,
Handler: server,
}
return server, nil
}

View File

@@ -26,6 +26,7 @@ func (s *jtServer) OnNewSession(conn net.Conn) *Session {
func (s *jtServer) OnCloseSession(session *Session) { func (s *jtServer) OnCloseSession(session *Session) {
session.Close() session.Close()
stream.TCPReceiveBufferPool.Put(session.receiveBuffer[:cap(session.receiveBuffer)])
} }
func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte { func (s *jtServer) OnPacket(conn net.Conn, data []byte) []byte {

View File

@@ -1,7 +1,6 @@
package jt1078 package jt1078
import ( import (
"github.com/lkmio/avformat/utils"
"github.com/lkmio/lkm/log" "github.com/lkmio/lkm/log"
"github.com/lkmio/lkm/stream" "github.com/lkmio/lkm/stream"
"github.com/lkmio/transport" "github.com/lkmio/transport"
@@ -33,20 +32,10 @@ func (s *Session) Input(data []byte) (int, error) {
return -1, err return -1, err
} }
// 首包处理, hook通知 // 首包处理
if firstOfPacket && demuxer.prevPacket != nil { if firstOfPacket && demuxer.prevPacket != nil {
s.SetID(demuxer.sim + "/" + strconv.Itoa(demuxer.channel)) s.SetID(demuxer.sim + "/" + strconv.Itoa(demuxer.channel))
stream.PreparePublishSourceWithAsync(s, true)
go func() {
_, state := stream.PreparePublishSource(s, true)
if utils.HookStateOK != state {
log.Sugar.Errorf("1078推流失败 source: %s", demuxer.sim)
if s.Conn != nil {
s.Conn.Close()
}
}
}()
} }
} }
@@ -56,13 +45,7 @@ func (s *Session) Input(data []byte) (int, error) {
func (s *Session) Close() { func (s *Session) Close() {
log.Sugar.Infof("1078推流结束 %s", s.String()) log.Sugar.Infof("1078推流结束 %s", s.String())
if s.Conn != nil {
s.Conn.Close()
s.Conn = nil
}
s.PublishSource.Close() s.PublishSource.Close()
stream.TCPReceiveBufferPool.Put(s.receiveBuffer[:cap(s.receiveBuffer)])
} }
func NewSession(conn net.Conn, version int) *Session { func NewSession(conn net.Conn, version int) *Session {

View File

@@ -218,9 +218,9 @@ func TestPublish(t *testing.T) {
}) })
t.Run("publish", func(t *testing.T) { t.Run("publish", func(t *testing.T) {
//path := "../../source_files/10352264314-2.bin" path := "../../source_files/10352264314-2.bin"
//path := "../../source_files/013800138000-1.bin" //path := "../../source_files/013800138000-1.bin"
path := "../../source_files/0714-1.bin" //path := "../../source_files/0714-1.bin"
publish(path, "1078") publish(path, "1078")
}) })

32
main.go
View File

@@ -75,11 +75,11 @@ func init() {
// 初始化日志 // 初始化日志
log.InitLogger(config.Log.FileLogging, zapcore.Level(stream.AppConfig.Log.Level), stream.AppConfig.Log.Name, stream.AppConfig.Log.MaxSize, stream.AppConfig.Log.MaxBackup, stream.AppConfig.Log.MaxAge, stream.AppConfig.Log.Compress) log.InitLogger(config.Log.FileLogging, zapcore.Level(stream.AppConfig.Log.Level), stream.AppConfig.Log.Name, stream.AppConfig.Log.MaxSize, stream.AppConfig.Log.MaxBackup, stream.AppConfig.Log.MaxAge, stream.AppConfig.Log.Compress)
if stream.AppConfig.GB28181.Enable && stream.AppConfig.GB28181.IsMultiPort() { if stream.AppConfig.GB28181.Enable {
gb28181.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.GB28181.Port[0]), uint16(stream.AppConfig.GB28181.Port[1])) gb28181.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.GB28181.Port[0]), uint16(stream.AppConfig.GB28181.Port[1]))
} }
if stream.AppConfig.Rtsp.Enable && stream.AppConfig.Rtsp.IsMultiPort() { if stream.AppConfig.Rtsp.Enable {
rtsp.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.Rtsp.Port[1]), uint16(stream.AppConfig.Rtsp.Port[2])) rtsp.TransportManger = transport.NewTransportManager(config.ListenIP, uint16(stream.AppConfig.Rtsp.Port[1]), uint16(stream.AppConfig.Rtsp.Port[2]))
} }
@@ -134,33 +134,7 @@ func main() {
log.Sugar.Info("启动http服务 addr:", stream.ListenAddr(stream.AppConfig.Http.Port)) log.Sugar.Info("启动http服务 addr:", stream.ListenAddr(stream.AppConfig.Http.Port))
go startApiServer(net.JoinHostPort(stream.AppConfig.ListenIP, strconv.Itoa(stream.AppConfig.Http.Port))) go startApiServer(net.JoinHostPort(stream.AppConfig.ListenIP, strconv.Itoa(stream.AppConfig.Http.Port)))
// 单端口模式下, 启动时就创建收流端口 // GB28181收流时调用api创建收流端口
// 多端口模式下, 创建GBSource时才创建收流端口
if stream.AppConfig.GB28181.Enable && !stream.AppConfig.GB28181.IsMultiPort() {
if stream.AppConfig.GB28181.IsEnableUDP() {
filter := gb28181.NewSSRCFilter(128)
server, err := gb28181.NewUDPServer(filter)
if err != nil {
panic(err)
}
gb28181.SharedUDPServer = server
log.Sugar.Info("启动GB28181 udp收流端口成功:" + stream.ListenAddr(stream.AppConfig.GB28181.Port[0]))
gb28181.SSRCFilters = append(gb28181.SSRCFilters, filter)
}
if stream.AppConfig.GB28181.IsEnableTCP() {
filter := gb28181.NewSSRCFilter(128)
server, err := gb28181.NewTCPServer(filter)
if err != nil {
panic(err)
}
gb28181.SharedTCPServer = server
log.Sugar.Info("启动GB28181 tcp收流端口成功:" + stream.ListenAddr(stream.AppConfig.GB28181.Port[0]))
gb28181.SSRCFilters = append(gb28181.SSRCFilters, filter)
}
}
if stream.AppConfig.JT1078.Enable { if stream.AppConfig.JT1078.Enable {
// 无法通过包头区分2016和2019, 每个版本创建一个Server // 无法通过包头区分2016和2019, 每个版本创建一个Server

View File

@@ -6,6 +6,7 @@ import (
"github.com/lkmio/lkm/stream" "github.com/lkmio/lkm/stream"
"github.com/lkmio/rtmp" "github.com/lkmio/rtmp"
"net" "net"
"strings"
) )
// Session RTMP会话, 解析处理Message // Session RTMP会话, 解析处理Message
@@ -40,9 +41,17 @@ func (s *Session) OnPublish(app, stream_ string) utils.HookState {
source.SetUrlValues(values) source.SetUrlValues(values)
// 统一处理source推流事件, source是否已经存在, hook回调.... // 统一处理source推流事件, source是否已经存在, hook回调....
_, state := stream.PreparePublishSource(source, true) state := utils.HookStateOK
if utils.HookStateOK != state { _, err := stream.PreparePublishSource(source, true)
log.Sugar.Errorf("rtmp推流失败 source: %s", sourceId) if err != nil {
str := err.Error()
log.Sugar.Errorf("rtmp推流失败 source: %s err: %s", sourceId, str)
if strings.HasSuffix(str, "exist") {
state = utils.HookStateOccupy
} else {
state = utils.HookStateFailure
}
} else { } else {
s.handle = source s.handle = source
s.isPublisher = true s.isPublisher = true
@@ -77,7 +86,7 @@ func (s *Session) Input(data []byte) error {
s.handle.(*Publisher).UpdateReceiveStats(len(data)) s.handle.(*Publisher).UpdateReceiveStats(len(data))
var err error var err error
s.handle.(*Publisher).ExecuteSyncEvent(func() { s.handle.(*Publisher).ExecuteWithStreamLock(func() {
err = s.stack.Input(s.conn, data) err = s.stack.Input(s.conn, data)
}) })

View File

@@ -124,14 +124,6 @@ func (g TransportConfig) IsEnableUDP() bool {
return strings.Contains(g.Transport, "UDP") return strings.Contains(g.Transport, "UDP")
} }
func (g GB28181Config) IsMultiPort() bool {
return len(g.Port) > 1
}
func (g RtspConfig) IsMultiPort() bool {
return len(g.Port) == 3
}
// M3U8Path 根据sourceId返回m3u8的磁盘路径 // M3U8Path 根据sourceId返回m3u8的磁盘路径
// 切片及目录生成规则, 以SourceId为34020000001320000001/34020000001320000001为例: // 切片及目录生成规则, 以SourceId为34020000001320000001/34020000001320000001为例:
// 创建文件夹34020000001320000001, 34020000001320000001.m3u8文件, 文件列表中切片url为34020000001320000001_seq.ts // 创建文件夹34020000001320000001, 34020000001320000001.m3u8文件, 文件列表中切片url为34020000001320000001_seq.ts

View File

@@ -59,15 +59,16 @@ func Hook(event HookEvent, params string, body interface{}) (*http.Response, err
response, err := SendHookEvent(url, bytes) response, err := SendHookEvent(url, bytes)
if err != nil { if err != nil {
log.Sugar.Errorf("failed to %s the hook event. err: %s", event.ToString(), err.Error()) log.Sugar.Errorf("failed to %s the hook event. err: %s", event.ToString(), err.Error())
return response, err
} else { } else {
log.Sugar.Infof("received response for hook %s event: status='%s', response body='%s'", event.ToString(), response.Status, responseBodyToString(response)) log.Sugar.Infof("received response for hook %s event: status='%s', response body='%s'", event.ToString(), response.Status, responseBodyToString(response))
} }
if err == nil && http.StatusOK != response.StatusCode { if http.StatusOK != response.StatusCode {
return response, fmt.Errorf("unexpected response status: %s for request %s", response.Status, url) return response, fmt.Errorf("unexpected response status: %s", response.Status)
} }
return response, err return response, nil
} }
func NewHookPlayEventInfo(sink Sink) eventInfo { func NewHookPlayEventInfo(sink Sink) eventInfo {

View File

@@ -2,53 +2,83 @@ package stream
import ( import (
"encoding/json" "encoding/json"
"fmt"
"github.com/lkmio/avformat/utils" "github.com/lkmio/avformat/utils"
"github.com/lkmio/lkm/log" "github.com/lkmio/lkm/log"
"net/http" "net/http"
"time" "time"
) )
func PreparePublishSource(source Source, hook bool) (*http.Response, utils.HookState) { func AddSource(source Source) error {
var response *http.Response err := SourceManager.add(source)
if err == nil {
if err := SourceManager.Add(source); err != nil { source.SetState(SessionStateHandshakeSuccess)
return nil, utils.HookStateOccupy
} }
if hook && AppConfig.Hooks.IsEnablePublishEvent() { return err
rep, state := HookPublishEvent(source) }
if utils.HookStateOK != state {
func PreparePublishSource(source Source, add bool) (*http.Response, error) {
var response *http.Response
if add {
if err := AddSource(source); err != nil {
return nil, err
}
} else if SourceManager.Find(source.GetID()) == nil {
return nil, fmt.Errorf("not found")
}
if AppConfig.Hooks.IsEnablePublishEvent() {
rep, err := HookPublishEvent(source)
if err != nil {
_, _ = SourceManager.Remove(source.GetID()) _, _ = SourceManager.Remove(source.GetID())
return rep, state return rep, err
} }
response = rep response = rep
} }
// 此时才认为source推流成功
source.SetState(SessionStateTransferring)
source.SetCreateTime(time.Now()) source.SetCreateTime(time.Now())
urls := GetStreamPlayUrls(source.GetID()) urls := GetStreamPlayUrls(source.GetID())
indent, _ := json.MarshalIndent(urls, "", "\t") indent, _ := json.MarshalIndent(urls, "", "\t")
log.Sugar.Infof("%s准备推流 source:%s 拉流地址:\r\n%s", source.GetType().String(), source.GetID(), indent) log.Sugar.Infof("%s推流 source: %s 拉流地址:\r\n%s", source.GetType().String(), source.GetID(), indent)
source.SetState(SessionStateTransferring) return response, nil
return response, utils.HookStateOK
} }
func HookPublishEvent(source Source) (*http.Response, utils.HookState) { func PreparePublishSourceWithAsync(source Source, add bool) {
var response *http.Response go func() {
var err error
// 加锁执行, 保证并发安全
source.ExecuteWithDeleteLock(func() {
if source.IsClosed() {
err = fmt.Errorf("source is closed")
} else if _, err = PreparePublishSource(source, add); err == nil {
}
})
if AppConfig.Hooks.IsEnablePublishEvent() {
hook, err := Hook(HookEventPublish, source.UrlValues().Encode(), NewHookPublishEventInfo(source))
if err != nil { if err != nil {
return hook, utils.HookStateFailure log.Sugar.Errorf("GB28181推流失败 err: %s source: %s", err.Error(), source.GetID())
if !source.IsClosed() {
source.Close()
}
}
}()
}
func HookPublishEvent(source Source) (*http.Response, error) {
if AppConfig.Hooks.IsEnablePublishEvent() {
return Hook(HookEventPublish, source.UrlValues().Encode(), NewHookPublishEventInfo(source))
} }
response = hook return nil, nil
}
return response, utils.HookStateOK
} }
func HookPublishDoneEvent(source Source) { func HookPublishDoneEvent(source Source) {

View File

@@ -33,8 +33,8 @@ type MergeWritingBuffer interface {
} }
type mbBuffer struct { type mbBuffer struct {
buffer collections.BlockBuffer buffer collections.BlockBuffer // 合并写内存缓冲区
segments *collections.Queue[*collections.ReferenceCounter[[]byte]] segments *collections.Queue[*collections.ReferenceCounter[[]byte]] // 包含多个合并写切片
} }
type mergeWritingBuffer struct { type mergeWritingBuffer struct {
@@ -56,13 +56,15 @@ func (m *mergeWritingBuffer) TryAlloc(size int, ts int64, videoPkt, videoKey boo
buffer := m.buffers.Peek(m.buffers.Size() - 1).buffer buffer := m.buffers.Peek(m.buffers.Size() - 1).buffer
bytes := buffer.AvailableBytes() bytes := buffer.AvailableBytes()
// 内存不足, 分配新的内存缓冲区
if bytes < size { if bytes < size {
// 非完整切片,先保存切片再分配新的内存 // 让外部先flush, 再分配新的内存
if buffer.PendingBlockSize() > 0 { if buffer.PendingBlockSize() > 0 {
return nil, false return nil, false
} }
// -1, 当前内存池不释放 // 释放未使用的内存缓冲区
// -1, 最新的内存缓冲区不释放
release(m.buffers, m.buffers.Size()-1) release(m.buffers, m.buffers.Size()-1)
m.buffers.Push(MWBufferPool.Get().(*mbBuffer)) m.buffers.Push(MWBufferPool.Get().(*mbBuffer))
} }
@@ -116,6 +118,7 @@ func (m *mergeWritingBuffer) FlushSegment() (*collections.ReferenceCounter[[]byt
} }
if AppConfig.GOPCache { if AppConfig.GOPCache {
// +1=2
counter.Refer() counter.Refer()
m.lastKeyVideoDataSegments.Push(counter) m.lastKeyVideoDataSegments.Push(counter)
} }
@@ -172,11 +175,13 @@ func (m *mergeWritingBuffer) HasVideoDataInCurrentSegment() bool {
} }
func (m *mergeWritingBuffer) Close() *collections.Queue[*mbBuffer] { func (m *mergeWritingBuffer) Close() *collections.Queue[*mbBuffer] {
// 减少关键帧切片的引用计数
for m.lastKeyVideoDataSegments.Size() > 0 { for m.lastKeyVideoDataSegments.Size() > 0 {
m.lastKeyVideoDataSegments.Pop().Release() m.lastKeyVideoDataSegments.Pop().Release()
} }
if m.buffers.Size() > 0 && !release(m.buffers, m.buffers.Size()) { if m.buffers.Size() > 0 && !release(m.buffers, m.buffers.Size()) {
// 还有sink在使用, 返回未释放的内存缓冲区
return m.buffers return m.buffers
} }

View File

@@ -8,6 +8,8 @@ import (
) )
const ( const (
// BlockBufferSize 合并写缓冲区的内存块大小
// 一块缓冲区可以包含多个合并写切片
BlockBufferSize = 1024 * 1024 * 2 BlockBufferSize = 1024 * 1024 * 2
) )
@@ -23,17 +25,17 @@ var (
}, },
} }
pendingReleaseBuffers = make(map[string]*collections.Queue[*mbBuffer]) pendingReleaseBuffers = make(map[string]*collections.Queue[*mbBuffer]) // 等待释放的合并写缓冲区
lock sync.Mutex lock sync.Mutex
) )
// AddMWBuffersToPending 添加合并写缓冲区到等待释放队列
func AddMWBuffersToPending(sourceId string, transStreamId TransStreamID, buffers *collections.Queue[*mbBuffer]) { func AddMWBuffersToPending(sourceId string, transStreamId TransStreamID, buffers *collections.Queue[*mbBuffer]) {
key := fmt.Sprintf("%s-%d", sourceId, transStreamId) key := fmt.Sprintf("%s-%d", sourceId, transStreamId)
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
for buffers.Size() > 0 {
v, ok := pendingReleaseBuffers[key] v, ok := pendingReleaseBuffers[key]
if ok { if ok {
// 第二次都推流结束了,第一次的内存还被占用 // 第二次都推流结束了,第一次的内存还被占用
@@ -50,10 +52,13 @@ func AddMWBuffersToPending(sourceId string, transStreamId TransStreamID, buffers
delete(pendingReleaseBuffers, key) delete(pendingReleaseBuffers, key)
} }
if buffers.Size() > 0 {
pendingReleaseBuffers[key] = buffers pendingReleaseBuffers[key] = buffers
} }
} }
// ReleasePendingBuffers 释放等待释放的合并写缓冲区
// 拉流结束后主动调用一次, 创建传输流的时候也调用一次
func ReleasePendingBuffers(sourceId string, transStreamId TransStreamID) { func ReleasePendingBuffers(sourceId string, transStreamId TransStreamID) {
key := fmt.Sprintf("%s-%d", sourceId, transStreamId) key := fmt.Sprintf("%s-%d", sourceId, transStreamId)
@@ -68,6 +73,7 @@ func ReleasePendingBuffers(sourceId string, transStreamId TransStreamID) {
delete(pendingReleaseBuffers, key) delete(pendingReleaseBuffers, key)
} }
// release 释放合并写缓冲区
func release(buffers *collections.Queue[*mbBuffer], length int) bool { func release(buffers *collections.Queue[*mbBuffer], length int) bool {
for i := 0; i < length; i++ { for i := 0; i < length; i++ {
buffer := buffers.Peek(0) buffer := buffers.Peek(0)

View File

@@ -175,7 +175,7 @@ func (s *BaseSink) fastForward(firstSegment *collections.ReferenceCounter[[]byte
func (s *BaseSink) doAsyncWrite() { func (s *BaseSink) doAsyncWrite() {
defer func() { defer func() {
// 释放未发送的数据 // 释放未发送的合并写切片
for buffer := s.pendingSendQueue.Pop(); buffer != nil; buffer = s.pendingSendQueue.Pop() { for buffer := s.pendingSendQueue.Pop(); buffer != nil; buffer = s.pendingSendQueue.Pop() {
buffer.Release() buffer.Release()
} }
@@ -241,9 +241,12 @@ func (s *BaseSink) doAsyncWrite() {
func (s *BaseSink) EnableAsyncWriteMode(queueSize int) { func (s *BaseSink) EnableAsyncWriteMode(queueSize int) {
utils.Assert(s.Conn != nil) utils.Assert(s.Conn != nil)
// 只初始化一次
if s.pendingSendQueue == nil {
s.pendingSendQueue = NewNonBlockingChannel[*collections.ReferenceCounter[[]byte]](queueSize) s.pendingSendQueue = NewNonBlockingChannel[*collections.ReferenceCounter[[]byte]](queueSize)
s.cancelCtx, s.cancelFunc = context.WithCancel(context.Background()) s.cancelCtx, s.cancelFunc = context.WithCancel(context.Background())
go s.doAsyncWrite() go s.doAsyncWrite()
}
} }
func (s *BaseSink) Write(index int, data []*collections.ReferenceCounter[[]byte], ts int64, keyVideo bool) error { func (s *BaseSink) Write(index int, data []*collections.ReferenceCounter[[]byte], ts int64, keyVideo bool) error {

View File

@@ -80,7 +80,9 @@ type Source interface {
StartTimers(source Source) StartTimers(source Source)
ExecuteSyncEvent(cb func()) ExecuteWithStreamLock(cb func())
ExecuteWithDeleteLock(cb func())
UpdateReceiveStats(dataLen int) UpdateReceiveStats(dataLen int)
} }
@@ -105,7 +107,8 @@ type PublishSource struct {
createTime time.Time // source创建时间 createTime time.Time // source创建时间
statistics *BitrateStatistics // 码流统计 statistics *BitrateStatistics // 码流统计
streamLogger avformat.OnUnpackStream2FileHandler streamLogger avformat.OnUnpackStream2FileHandler
streamLock sync.Mutex // 收流、探测超时、关闭等操作互斥锁 streamLock sync.Mutex // 收流、探测超时等操作互斥锁
deleteLock sync.Mutex // 双重锁, 防止在关闭source时, 其他操作同时进行
timers struct { timers struct {
receiveTimer *time.Timer // 收流超时计时器 receiveTimer *time.Timer // 收流超时计时器
@@ -157,10 +160,8 @@ func (s *PublishSource) Input(data []byte) (int, error) {
s.UpdateReceiveStats(len(data)) s.UpdateReceiveStats(len(data))
var n int var n int
var err error var err error
s.ExecuteSyncEvent(func() { s.ExecuteWithStreamLock(func() {
if s.closed.Load() { if !s.closed.Load() {
err = fmt.Errorf("source closed")
} else {
n, err = s.TransDemuxer.Input(data) n, err = s.TransDemuxer.Input(data)
} }
}) })
@@ -176,7 +177,7 @@ func (s *PublishSource) SetState(state SessionState) {
s.state = state s.state = state
} }
func (s *PublishSource) DoClose() { func (s *PublishSource) doClose() {
log.Sugar.Debugf("closing the %s source. id: %s. closed flag: %t", s.Type, s.ID, s.closed.Load()) log.Sugar.Debugf("closing the %s source. id: %s. closed flag: %t", s.Type, s.ID, s.closed.Load())
// 已关闭, 直接返回 // 已关闭, 直接返回
@@ -185,7 +186,7 @@ func (s *PublishSource) DoClose() {
} }
var closed bool var closed bool
s.ExecuteSyncEvent(func() { s.ExecuteWithStreamLock(func() {
closed = s.closed.Swap(true) closed = s.closed.Swap(true)
}) })
@@ -221,21 +222,17 @@ func (s *PublishSource) DoClose() {
// 同步执行 // 同步执行
s.streamPublisher.close() s.streamPublisher.close()
// 只释放prepare成功的source, 否则在关闭失败的source时, 造成id相同的source被错误释放
if s.state < SessionStateTransferring {
return
}
s.state = SessionStateClosed s.state = SessionStateClosed
// 释放解复用器
// 释放转码器 // 只删除被添加的source, 否则会造成id相同的source被误删
// 释放每路转协议流, 将所有sink添加到等待队列 if s.state >= SessionStateHandshakeSuccess {
_, err := SourceManager.Remove(s.ID) _, err := SourceManager.Remove(s.ID)
if err != nil { if err != nil {
// source不存在, 在创建source时, 未添加到manager中, 目前只有1078流会出现这种情况(tcp连接到端口, 没有推流或推流数据无效, 无法定位到手机号, 以至于无法执行PreparePublishSource函数), 将不再处理后续事情. // source不存在, 在创建source时, 未添加到manager中, 目前只有1078流会出现这种情况(tcp连接到端口, 没有推流或推流数据无效, 无法定位到手机号, 以至于无法执行PreparePublishSource函数), 将不再处理后续事情.
log.Sugar.Errorf("删除源失败 source: %s err: %s", s.ID, err.Error()) log.Sugar.Errorf("删除源失败 source: %s err: %s", s.ID, err.Error())
return return
} }
}
// 异步hook // 异步hook
go func() { go func() {
@@ -249,7 +246,9 @@ func (s *PublishSource) DoClose() {
} }
func (s *PublishSource) Close() { func (s *PublishSource) Close() {
s.DoClose() s.ExecuteWithDeleteLock(func() {
s.doClose()
})
} }
// 解析完所有track后, 创建各种输出流 // 解析完所有track后, 创建各种输出流
@@ -265,8 +264,8 @@ func (s *PublishSource) writeHeader() {
if len(s.originTracks.All()) == 0 { if len(s.originTracks.All()) == 0 {
log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID) log.Sugar.Errorf("没有一路track, 删除source: %s", s.ID)
// 异步执行ProbeTimeout函数中还没释放锁 // 此时还持有stream lock, 异步关闭source
go s.DoClose() go CloseSource(s.ID)
return return
} }
} }
@@ -399,7 +398,7 @@ func (s *PublishSource) SetUrlValues(values url.Values) {
s.urlValues = values s.urlValues = values
} }
func (s *PublishSource) ExecuteSyncEvent(cb func()) { func (s *PublishSource) ExecuteWithStreamLock(cb func()) {
// 无竞争情况下, 接近原子操作 // 无竞争情况下, 接近原子操作
s.streamLock.Lock() s.streamLock.Lock()
defer s.streamLock.Unlock() defer s.streamLock.Unlock()
@@ -420,7 +419,7 @@ func (s *PublishSource) GetBitrateStatistics() *BitrateStatistics {
func (s *PublishSource) ProbeTimeout() { func (s *PublishSource) ProbeTimeout() {
if s.TransDemuxer != nil { if s.TransDemuxer != nil {
s.ExecuteSyncEvent(func() { s.ExecuteWithStreamLock(func() {
if !s.closed.Load() { if !s.closed.Load() {
s.TransDemuxer.ProbeComplete() s.TransDemuxer.ProbeComplete()
} }
@@ -454,3 +453,9 @@ func (s *PublishSource) StartTimers(source Source) {
}) })
} }
func (s *PublishSource) ExecuteWithDeleteLock(cb func()) {
s.deleteLock.Lock()
defer s.deleteLock.Unlock()
cb()
}

View File

@@ -16,7 +16,7 @@ type sourceManger struct {
m sync.Map m sync.Map
} }
func (s *sourceManger) Add(source Source) error { func (s *sourceManger) add(source Source) error {
_, ok := s.m.LoadOrStore(source.GetID(), source) _, ok := s.m.LoadOrStore(source.GetID(), source)
if ok { if ok {
return fmt.Errorf("the source %s has been exist", source.GetID()) return fmt.Errorf("the source %s has been exist", source.GetID())

View File

@@ -195,6 +195,13 @@ func StartIdleTimer(source Source) *time.Timer {
return idleTimer return idleTimer
} }
func CloseSource(id string) {
source := SourceManager.Find(id)
if source != nil {
source.Close()
}
}
// LoopEvent 循环读取事件 // LoopEvent 循环读取事件
func LoopEvent(source Source) { func LoopEvent(source Source) {
source.StartTimers(source) source.StartTimers(source)

View File

@@ -341,6 +341,9 @@ func (t *transStreamPublisher) CreateTransStream(protocol TransStreamProtocol, t
} }
} }
// 尝试清空等待释放的合并写缓冲区
ReleasePendingBuffers(t.source, id)
t.transStreams[id] = transStream t.transStreams[id] = transStream
// 创建输出流对应的拉流队列 // 创建输出流对应的拉流队列
t.transStreamSinks[id] = make(map[SinkID]Sink, 128) t.transStreamSinks[id] = make(map[SinkID]Sink, 128)
@@ -700,7 +703,6 @@ func (t *transStreamPublisher) doClose() {
// 将所有sink添加到等待队列 // 将所有sink添加到等待队列
for _, sink := range t.sinks { for _, sink := range t.sinks {
transStreamID := sink.GetTransStreamID() transStreamID := sink.GetTransStreamID()
sink.SetTransStreamID(0)
if t.recordSink == sink { if t.recordSink == sink {
continue continue
} }

View File

@@ -18,8 +18,10 @@ type StreamServer[T any] struct {
} }
func (s *StreamServer[T]) OnConnected(conn net.Conn) []byte { func (s *StreamServer[T]) OnConnected(conn net.Conn) []byte {
log.Sugar.Debugf("%s连接 conn:%s", s.SourceType.String(), conn.RemoteAddr().String()) log.Sugar.Debugf("%s连接 conn: %s", s.SourceType.String(), conn.RemoteAddr().String())
if s.Handler != nil {
conn.(*transport.Conn).Data = s.Handler.OnNewSession(conn) conn.(*transport.Conn).Data = s.Handler.OnNewSession(conn)
}
return nil return nil
} }
@@ -35,7 +37,7 @@ func (s *StreamServer[T]) OnDisConnected(conn net.Conn, err error) {
log.Sugar.Debugf("%s断开连接 conn:%s", s.SourceType.String(), conn.RemoteAddr().String()) log.Sugar.Debugf("%s断开连接 conn:%s", s.SourceType.String(), conn.RemoteAddr().String())
t := conn.(*transport.Conn) t := conn.(*transport.Conn)
if t.Data != nil { if s.Handler != nil && t.Data != nil {
s.Handler.OnCloseSession(t.Data.(T)) s.Handler.OnCloseSession(t.Data.(T))
t.Data = nil t.Data = nil
} }

View File

@@ -2,14 +2,13 @@ package stream
import "github.com/lkmio/avformat/utils" import "github.com/lkmio/avformat/utils"
// TransStreamID 每个传输流的唯一Id根据输出流协议ID+track index生成 // TransStreamID 每个传输流的唯一Id, 根据输出流协议ID+track index生成
// 输出流协议占低8位 // 输出流协议占低8位, track index占用8位, 最多支持7路流.
// 每个音视频编译器ID占用8位. 意味着每个输出流至多7路流.
type TransStreamID uint64 type TransStreamID uint64
func (id TransStreamID) HasTrack(index int) bool { func (id TransStreamID) HasTrack(index int) bool {
for i := 1; i < 8; i++ { for i := 1; i < 8; i++ {
if int(id>>(i*8))&0xFF == index { if (int(id>>(i*8))&0xFF)-1 == index {
return true return true
} }
} }
@@ -21,25 +20,6 @@ func (id TransStreamID) Protocol() TransStreamProtocol {
return TransStreamProtocol(id & 0xFF) return TransStreamProtocol(id & 0xFF)
} }
// GenerateTransStreamID 根据传入的推拉流协议和编码器ID生成StreamId
// 请确保ids根据值升序排序传参
/*func GenerateTransStreamID(protocol GetProtocol, ids ...utils.AVCodecID) GetTransStreamID {
len_ := len(ids)
utils.Assert(len_ > 0 && len_ < 8)
var streamId uint64
streamId = uint64(protocol) << 56
for i, GetID := range ids {
bId, ok := narrowCodecIds[int(GetID)]
utils.Assert(ok)
streamId |= uint64(bId) << (48 - i*8)
}
return GetTransStreamID(streamId)
}*/
// GenerateTransStreamID 根据输出流协议和输出流包含的音视频编码器ID生成流ID // GenerateTransStreamID 根据输出流协议和输出流包含的音视频编码器ID生成流ID
func GenerateTransStreamID(protocol TransStreamProtocol, tracks ...*Track) TransStreamID { func GenerateTransStreamID(protocol TransStreamProtocol, tracks ...*Track) TransStreamID {
len_ := len(tracks) len_ := len(tracks)
@@ -47,7 +27,8 @@ func GenerateTransStreamID(protocol TransStreamProtocol, tracks ...*Track) Trans
var streamId = uint64(protocol) & 0xFF var streamId = uint64(protocol) & 0xFF
for i, track := range tracks { for i, track := range tracks {
streamId |= uint64(track.Stream.Index) << ((i + 1) * 8) // +1是为了避免0值
streamId |= uint64(track.Stream.Index+1) << ((i + 1) * 8)
} }
return TransStreamID(streamId) return TransStreamID(streamId)