feat: 持久化保存推流会话,支持恢复推流会话

This commit is contained in:
ydajiang
2024-12-25 19:19:56 +08:00
parent ea518d71a4
commit 70b3e68d6d
17 changed files with 621 additions and 162 deletions

62
api.go
View File

@@ -206,25 +206,25 @@ func (api *ApiServer) OnPlay(params *StreamParams, w http.ResponseWriter, r *htt
//ffplay -i rtmp://127.0.0.1/34020000001320000001/34020000001310000001.session_id_0?setup=passive&stream_type=playback&start_time=2024-06-18T15:20:56&end_time=2024-06-18T15:25:56
// 跳过非国标拉流
split := strings.Split(string(params.Stream), "/")
if len(split) != 2 || len(split[0]) != 20 || len(split[1]) < 20 {
Sugar.Infof("跳过非国标流的播放事件 stream: %s", params.Stream)
sourceStream := strings.Split(string(params.Stream), "/")
if len(sourceStream) != 2 || len(sourceStream[0]) != 20 || len(sourceStream[1]) < 20 {
Sugar.Infof("跳过非国标流 stream: %s", params.Stream)
return
}
// 已经存在,累加计数
if stream := StreamManager.Find(params.Stream); stream != nil {
count := stream.IncreaseSinkCount()
Sugar.Infof("拉流计数: %d stream: %s ", count, params.Stream)
stream.IncreaseSinkCount()
return
}
deviceId := split[0]
channelId := split[1]
deviceId := sourceStream[0]
channelId := sourceStream[1]
if len(channelId) > 20 {
channelId = channelId[:20]
}
// 发起invite的参数
query := r.URL.Query()
inviteParams := &InviteParams{
DeviceID: deviceId,
@@ -250,11 +250,9 @@ func (api *ApiServer) OnPlay(params *StreamParams, w http.ResponseWriter, r *htt
if err != nil {
Sugar.Errorf("请求流失败 err: %s", err.Error())
}
if http.StatusOK == code {
count := stream.IncreaseSinkCount()
Sugar.Infof("拉流计数: %d stream: %s ", count, params.Stream)
utils.Assert(http.StatusOK != code)
} else if http.StatusOK == code {
stream.IncreaseSinkCount()
}
w.WriteHeader(code)
@@ -269,18 +267,17 @@ func (api *ApiServer) OnPlayDone(params *PlayDoneParams, w http.ResponseWriter,
return
}
count := stream.DecreaseSinkCount()
Sugar.Infof("拉流计数: %d stream: %s ", count, params.Stream)
stream.DecreaseSinkCount()
// 媒体链路与上级断开连接, 向上级发送Bye请求
// 级断开连接, 向上级发送Bye请求
if params.Protocol == "gb_stream_forward" {
sink := stream.RemoveForwardSink(params.Sink)
if sink == nil || sink.dialog == nil {
sink := stream.RemoveForwardStreamSink(params.Sink)
if sink == nil || sink.Dialog == nil {
return
}
if platform := PlatformManager.FindPlatform(sink.platformID); platform != nil {
callID, _ := sink.dialog.CallID()
if platform := PlatformManager.FindPlatform(sink.ServerID); platform != nil {
callID, _ := sink.Dialog.CallID()
platform.CloseStream(callID.String(), true, false)
}
}
@@ -298,7 +295,7 @@ func (api *ApiServer) OnPublish(params *StreamParams, w http.ResponseWriter, r *
func (api *ApiServer) OnPublishDone(params *StreamParams, w http.ResponseWriter, r *http.Request) {
Sugar.Infof("推流结束事件. protocol: %s stream: %s", params.Protocol, params.Stream)
CloseStream(params.Stream)
CloseStream(params.Stream, false)
}
func (api *ApiServer) OnIdleTimeout(params *StreamParams, w http.ResponseWriter, req *http.Request) {
@@ -307,7 +304,7 @@ func (api *ApiServer) OnIdleTimeout(params *StreamParams, w http.ResponseWriter,
// 非rtmp空闲超时, 返回非200应答, 删除会话
if params.Protocol != "rtmp" {
w.WriteHeader(http.StatusForbidden)
CloseStream(params.Stream)
CloseStream(params.Stream, false)
}
}
@@ -317,7 +314,7 @@ func (api *ApiServer) OnReceiveTimeout(params *StreamParams, w http.ResponseWrit
// 非rtmp推流超时, 返回非200应答, 删除会话
if params.Protocol != "rtmp" {
w.WriteHeader(http.StatusForbidden)
CloseStream(params.Stream)
CloseStream(params.Stream, false)
}
}
@@ -359,7 +356,7 @@ func (api *ApiServer) OnInvite(v *InviteParams, w http.ResponseWriter, r *http.R
}
}
// DoInvite 处理Invite请求
// DoInvite 发起Invite请求
// @params sync 是否异步等待流媒体的publish事件(确认收到流), 目前请求流分两种方式流媒体hook和http接口, hook方式同步等待确认收到流再应答, http接口直接应答成功。
func (api *ApiServer) DoInvite(inviteType InviteType, params *InviteParams, sync bool, w http.ResponseWriter, r *http.Request) (int, *Stream, error) {
device := DeviceManager.Find(params.DeviceID)
@@ -385,15 +382,14 @@ func (api *ApiServer) DoInvite(inviteType InviteType, params *InviteParams, sync
endTimeSeconds = strconv.FormatInt(endTime.Unix(), 10)
}
streamId := params.streamId
if streamId == "" {
streamId = GenerateStreamId(inviteType, device.GetID(), params.ChannelID, params.StartTime, params.EndTime)
if params.streamId == "" {
params.streamId = GenerateStreamId(inviteType, device.GetID(), params.ChannelID, params.StartTime, params.EndTime)
}
// 解析回放或下载速度参数
speed, _ := strconv.Atoi(params.Speed)
speed = int(math.Min(4, float64(speed)))
stream, err := device.(*Device).StartStream(inviteType, streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync)
stream, err := device.(*Device).StartStream(inviteType, params.streamId, params.ChannelID, startTimeSeconds, endTimeSeconds, params.Setup, speed, sync)
if err != nil {
return http.StatusInternalServerError, nil, err
}
@@ -405,17 +401,17 @@ func (api *ApiServer) OnCloseStream(v *StreamIDParams, w http.ResponseWriter, r
stream := StreamManager.Find(v.StreamID)
// 等空闲或收流超时会自动关闭
if stream != nil && stream.SinkCount() < 1 {
CloseStream(v.StreamID)
if stream != nil && stream.GetSinkCount() < 1 {
CloseStream(v.StreamID, true)
}
httpResponseOK(w, nil)
}
func CloseStream(streamId StreamID) {
func CloseStream(streamId StreamID, ms bool) {
stream := StreamManager.Remove(streamId)
if stream != nil {
stream.Close(true)
stream.Close(true, ms)
}
}
@@ -536,7 +532,7 @@ func (api *ApiServer) OnSubscribePosition(v *DeviceChannelID, w http.ResponseWri
func (api *ApiServer) OnSeekPlayback(v *SeekParams, w http.ResponseWriter, r *http.Request) {
stream := StreamManager.Find(v.StreamId)
if stream == nil || stream.DialogRequest == nil {
if stream == nil || stream.Dialog == nil {
httpResponseError(w, "会话不存在")
return
}
@@ -715,7 +711,7 @@ func (api *ApiServer) OnStarted(w http.ResponseWriter, req *http.Request) {
streams := StreamManager.PopAll()
for _, stream := range streams {
stream.Close(true)
stream.Close(true, false)
}
}

View File

@@ -60,7 +60,7 @@ func (g *Client) SendMessage(msg interface{}) {
panic(err)
}
request, err := BuildMessageRequest(g.sipClient.Username, g.sipClient.ListenAddr, g.sipClient.SeverId, g.sipClient.Domain, g.sipClient.Transport, string(marshal))
request, err := BuildMessageRequest(g.sipClient.Username, g.sipClient.ListenAddr, g.sipClient.SeverID, g.sipClient.Domain, g.sipClient.Transport, string(marshal))
if err != nil {
panic(err)
}
@@ -140,7 +140,7 @@ func NewGBClient(username, serverId, serverAddr, transport, password string, reg
Password: password,
RegisterExpires: registerExpires,
KeeAliveInterval: keepalive,
SeverId: serverId,
SeverID: serverId,
ListenAddr: ua.ListenAddr(),
ua: ua,
}

View File

@@ -219,7 +219,7 @@ func (v VirtualDevice) OnInvite(request sip.Request, user string) sip.Response {
// 绑定到StreamManager, bye请求才会找到设备回调
streamId := GenerateStreamId(InviteTypeLive, v.sipClient.Username, user, "", "")
s := Stream{ID: streamId, DialogRequest: stream.dialog}
s := Stream{ID: streamId, Dialog: stream.dialog}
StreamManager.Add(&s)
callID, _ := request.CallID()

8
db.go
View File

@@ -41,4 +41,12 @@ type GB28181DB interface {
// QueryPlatformChannel 查询级联设备的某个通道, 返回通道所属设备ID、通道.
QueryPlatformChannel(platformId string, channelId string) (string, *Channel, error)
LoadStreams() (map[string]*Stream, error)
SaveStream(stream *Stream) error
DeleteStream(time int64) error
//QueryStream(pate int, size int)
}

View File

@@ -1,6 +1,7 @@
package main
import (
"encoding/hex"
"encoding/json"
"fmt"
"strconv"
@@ -15,8 +16,12 @@ const (
RedisKeyChannels = "channels" // 使用map保存所有通道信息
RedisKeyDeviceChannels = "%s_channels" // 使用zset保存设备下的所有通道ID
RedisKeyPlatforms = "platforms" // 使用zset有序保存所有级联设备
RedisKeyStreams = "streams" // 保存所有推流信息, 以便在崩溃后恢复
RedisUniqueChannelID = "%s_%s" // 通道号的唯一ID, 设备_通道号
// RedisKeyStreams 保存推拉流信息, 主要目的是程序崩溃重启后恢复国标流的invite会话. 如果需要统计所有详细的推拉流信息,需要自行实现.
RedisKeyStreams = "streams" //// 保存所有推流端信息
RedisKeySinks = "sinks" //// 保存所有拉流端信息
RedisKeyStreamSinks = "%s_sinks" //// 某路流下所有的拉流端
)
type RedisDB struct {
@@ -514,6 +519,54 @@ func (r *RedisDB) QueryPlatformChannel(platformId string, channelId string) (str
return deviceId, channel, nil
}
func (r *RedisDB) LoadStreams() (map[string]*Stream, error) {
executor, err := r.utils.CreateExecutor()
if err != nil {
return nil, err
}
all, err := executor.Key(RedisKeyStreams).ZRange()
if err != nil {
return nil, err
}
streams := make(map[string]*Stream, len(all))
for _, v := range all {
stream := &Stream{}
if err := json.Unmarshal([]byte(v[0]), stream); err != nil {
Sugar.Errorf("解析stream失败, err: %s value: %s", err.Error(), hex.EncodeToString([]byte(v[0])))
continue
}
streams[string(stream.ID)] = stream
}
return streams, nil
}
func (r *RedisDB) SaveStream(stream *Stream) error {
data, err := json.Marshal(stream)
if err != nil {
return err
}
executor, err := r.utils.CreateExecutor()
if err != nil {
return err
}
return executor.Key(RedisKeyStreams).ZAddWithNotExists(stream.CreateTime, data)
}
func (r *RedisDB) DeleteStream(time int64) error {
executor, err := r.utils.CreateExecutor()
if err != nil {
return err
}
return executor.Key(RedisKeyStreams).ZDelWithScore(time)
}
// OnExpires Redis设备ID到期回调
func (r *RedisDB) OnExpires(db int, id string) {
Sugar.Infof("设备心跳过期 device: %s", id)
@@ -524,10 +577,7 @@ func (r *RedisDB) OnExpires(db int, id string) {
return
}
device.(*Device).Status = OFF
if err := DB.SaveDevice(device.(*Device)); err != nil {
Sugar.Errorf("更新设备在线状态失败 err: %s device: %s ", err.Error(), id)
}
device.Close()
}
func NewRedisDB(addr, password string) *RedisDB {

View File

@@ -85,6 +85,8 @@ type GBDevice interface {
// 附录P.4.2.2
// @Params event ON-上线/OFF-离线/VLOST-视频丢失/DEFECT-故障/ADD-增加/DEL-删除/UPDATE-更新
UpdateChannel(id string, event string)
Close()
}
type Device struct {
@@ -96,6 +98,7 @@ type Device struct {
Manufacturer string `json:"manufacturer"`
Model string `json:"model"`
Firmware string `json:"firmware"`
RegisterTime int64 `json:"register_time"`
ChannelsTotal int `json:"total_channels"` // 通道总数
ChannelsOnline int `json:"online_channels"` // 通道在线数量
@@ -262,6 +265,27 @@ func (d *Device) BuildDownloadRequest(channelId, ip string, port uint16, startTi
return d.BuildInviteRequest("Download", channelId, ip, port, startTime, stopTime, setup, speed, ssrc)
}
func (d *Device) Close() {
// 更新在数据库中的状态
d.Status = OFF
if err := DB.SaveDevice(d); err != nil {
Sugar.Errorf("更新设备在线状态失败 err: %s device: %s ", err.Error(), d.ID)
}
// 释放所有推流
all := StreamManager.All()
var streams []*Stream
for _, stream := range all {
if d.ID == stream.ID.DeviceID() {
streams = append(streams, stream)
}
}
for _, stream := range streams {
stream.Close(true, true)
}
}
// CreateDialogRequestFromAnswer 根据invite的应答创建Dialog请求
// 应答的to头域需携带tag
func CreateDialogRequestFromAnswer(message sip.Response, uas bool, remoteAddr string) sip.Request {

13
live.go
View File

@@ -38,7 +38,8 @@ func (i *InviteType) SessionName2Type(name string) {
func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId, startTime, stopTime, setup string, speed int, sync bool) (*Stream, error) {
stream := &Stream{
ID: streamId,
forwardSinks: map[string]*Sink{},
ForwardStreamSinks: map[string]*Sink{},
CreateTime: time.Now().UnixMilli(),
}
// 先添加占位置, 防止重复请求
@@ -52,7 +53,7 @@ func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId
return nil, err
}
stream.DialogRequest = dialog
stream.Dialog = dialog
callID, _ := dialog.CallID()
StreamManager.AddWithCallId(callID.Value(), stream)
@@ -61,7 +62,7 @@ func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId
ok := stream.WaitForPublishEvent(10)
if !ok {
Sugar.Infof("收流超时 发送bye请求...")
CloseStream(streamId)
CloseStream(streamId, true)
}
return ok
}
@@ -73,6 +74,10 @@ func (d *Device) StartStream(inviteType InviteType, streamId StreamID, channelId
}
stream.urls = urls
// 保存到数据库
go DB.SaveStream(stream)
return stream, nil
}
@@ -83,7 +88,7 @@ func (d *Device) Invite(inviteType InviteType, streamId StreamID, channelId, sta
defer func() {
// 如果失败, 告知流媒体服务释放国标源
if err != nil {
go CloseGBSource(string(streamId))
go CloseSource(string(streamId))
}
}()

179
main.go
View File

@@ -29,19 +29,140 @@ func init() {
InitLogger(zapcore.Level(logConfig.Level), logConfig.Name, logConfig.MaxSize, logConfig.MaxBackup, logConfig.MaxAge, logConfig.Compress)
}
func main() {
config, err := ParseConfig("./config.json")
func startPlatformDevices() {
platforms, err := DB.LoadPlatforms()
if err != nil {
panic(err)
Sugar.Errorf("查询级联设备失败 err: %s", err.Error())
return
}
Config = config
indent, _ := json.MarshalIndent(Config, "", "\t")
Sugar.Infof("server config:\r\n%s", indent)
streams := StreamManager.All()
for _, record := range platforms {
platform, err := NewGBPlatform(record, SipUA)
// 都入库了不允许失败, 程序有BUG, 及时修复
utils.Assert(err == nil)
utils.Assert(PlatformManager.AddPlatform(platform))
DB = NewRedisDB(Config.Redis.Addr, Config.Redis.Password)
if err := DB.UpdatePlatformStatus(record.SeverID, OFF); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), record.SeverID)
}
// 查询在线设备, 更新设备在线状态
// 恢复级联会话
for _, stream := range streams {
sinks := stream.GetForwardStreamSinks()
for _, sink := range sinks {
if sink.ID != record.SeverID {
continue
}
callId, _ := sink.Dialog.CallID()
channelCallId, _ := stream.Dialog.CallID()
platform.AddStream(callId.Value(), channelCallId.Value())
}
}
platform.Start()
}
}
func recoverStreams() ([]*Stream, []*Sink) {
// 查询数据库中的流记录
// 查询流媒体服务器中的记录
// 合并两份记录, 以流媒体服务器中的为准。如果流记录数量不一致(只会时数据库中的记录数大于或等于流媒体中的记录数), 释放过期的会话.
// source id和stream id目前都是同一个id
streams, err := DB.LoadStreams()
if err != nil {
Sugar.Errorf("恢复推流失败, 查询数据库发生错误. err: %s", err.Error())
return nil, nil
} else if len(streams) < 1 {
return nil, nil
}
sources, err := QuerySourceList()
if err != nil {
// 流媒体服务器崩了, 存在的所有流都无效, 删除全部记录
Sugar.Warnf("恢复推流失败, 查询推流源列表发生错误, 删除数据库中的推拉流会话记录. err: %s", err.Error())
for _, stream := range streams {
DB.DeleteStream(stream.CreateTime)
}
return nil, nil
}
sourceSinks := make(map[string][]string, len(sources))
for _, source := range sources {
// 跳过非国标流
if "28181" != source.Protocol {
continue
}
// 查询级联转发sink
sinks, err := QuerySinkList(source.ID)
if err != nil {
Sugar.Warnf("查询拉流列表发生 err: %s", err.Error())
continue
}
stream, ok := streams[source.ID]
utils.Assert(ok)
stream.SinkCount = int32(len(sinks))
var forwardSinks []string
for _, sink := range sinks {
if "gb_stream_forward" == sink.Protocol {
forwardSinks = append(forwardSinks, sink.ID)
}
}
sourceSinks[source.ID] = forwardSinks
}
var closedStreams []*Stream
var closedSinks []*Sink
for _, stream := range streams {
forwardSinks, ok := sourceSinks[string(stream.ID)]
if !ok {
Sugar.Infof("删除过期的推流会话 stream: %s", stream.ID)
closedStreams = append(closedStreams, stream)
stream.Close(true, false)
continue
}
Sugar.Infof("恢复推流会话 stream: %s", stream.ID)
var invalidDialogs []string
for callId, sink := range stream.ForwardStreamSinks {
var exist bool
for _, id := range forwardSinks {
if id == sink.ID {
exist = true
break
}
}
if !exist {
Sugar.Infof("删除过期的级联转发会话 stream: %s sink: %s callId: %s", stream.ID, sink.ID, callId)
}
invalidDialogs = append(invalidDialogs, callId)
}
for _, id := range invalidDialogs {
sink := stream.RemoveForwardStreamSink(id)
closedSinks = append(closedSinks, sink)
}
StreamManager.Add(stream)
callId, _ := stream.Dialog.CallID()
StreamManager.AddWithCallId(callId.Value(), stream)
}
return closedStreams, closedSinks
}
func updateDevicesStatus() {
onlineDevices, err := DB.LoadOnlineDevices()
if err != nil {
panic(err)
@@ -84,6 +205,25 @@ func main() {
DeviceManager.Add(device)
}
}
}
func main() {
config, err := ParseConfig("./config.json")
if err != nil {
panic(err)
}
Config = config
indent, _ := json.MarshalIndent(Config, "", "\t")
Sugar.Infof("server config:\r\n%s", indent)
DB = NewRedisDB(Config.Redis.Addr, Config.Redis.Password)
// 查询在线设备, 更新设备在线状态
updateDevicesStatus()
// 恢复国标推流会话
streams, sinks := recoverStreams()
// 设置语音广播端口
TransportManager = transport.NewTransportManager(uint16(Config.Port[0]), uint16(Config.Port[1]))
@@ -98,20 +238,17 @@ func main() {
Config.SipContactAddr = net.JoinHostPort(config.PublicIP, strconv.Itoa(config.SipPort))
SipUA = server
// 在sip启动后, 关闭无效的流
for _, stream := range streams {
stream.Close(true, false)
}
for _, sink := range sinks {
sink.Close(true, false)
}
// 启动级联设备
platforms, err := DB.LoadPlatforms()
for _, record := range platforms {
platform, err := NewGBPlatform(record, SipUA)
// 都入库了不允许失败, 程序有BUG, 及时修复
utils.Assert(err == nil)
utils.Assert(PlatformManager.AddPlatform(platform))
if err := DB.UpdatePlatformStatus(record.SeverID, OFF); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), record.SeverID)
}
platform.Start()
}
startPlatformDevices()
httpAddr := net.JoinHostPort(config.ListenIP, strconv.Itoa(config.HttpPort))
Sugar.Infof("启动http server. addr: %s", httpAddr)

View File

@@ -8,10 +8,28 @@ import (
"time"
)
type SourceDetails struct {
ID string `json:"id"`
Protocol string `json:"protocol"` // 推流协议
Time time.Time `json:"time"` // 推流时间
SinkCount int `json:"sink_count"` // 播放端计数
Bitrate string `json:"bitrate"` // 码率统计
Tracks []string `json:"tracks"` // 每路流编码器ID
Urls []string `json:"urls"` // 拉流地址
}
type SinkDetails struct {
ID string `json:"id"`
Protocol string `json:"protocol"` // 拉流协议
Time time.Time `json:"time"` // 拉流时间
Bitrate string `json:"bitrate"` // 码率统计
Tracks []string `json:"tracks"` // 每路流编码器ID
}
func Send(path string, body interface{}) (*http.Response, error) {
url := fmt.Sprintf("http://%s/%s", Config.MediaServer, path)
marshal, err := json.Marshal(body)
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
@@ -20,7 +38,7 @@ func Send(path string, body interface{}) (*http.Response, error) {
Timeout: 10 * time.Second,
}
request, err := http.NewRequest("post", url, bytes.NewBuffer(marshal))
request, err := http.NewRequest("post", url, bytes.NewBuffer(data))
if err != nil {
return nil, err
}
@@ -73,14 +91,14 @@ func ConnectGBSource(id, addr string) error {
return err
}
func CloseGBSource(id string) error {
func CloseSource(id string) error {
v := &struct {
Source string `json:"source"`
}{
Source: id,
}
_, err := Send("api/v1/gb28181/source/close", v)
_, err := Send("api/v1/source/close", v)
return err
}
@@ -127,3 +145,35 @@ func CloseSink(sourceId string, sinkId string) {
_, _ = Send("api/v1/sink/close", v)
}
func QuerySourceList() ([]*SourceDetails, error) {
response, err := Send("api/v1/source/list", nil)
if err != nil {
return nil, err
}
data := &Response[[]*SourceDetails]{}
if err = DecodeJSONBody(response.Body, data); err != nil {
return nil, err
}
return data.Data, err
}
func QuerySinkList(source string) ([]*SinkDetails, error) {
id := struct {
Source string `json:"source"`
}{source}
response, err := Send("api/v1/sink/list", id)
if err != nil {
return nil, err
}
data := &Response[[]*SinkDetails]{}
if err = DecodeJSONBody(response.Body, data); err != nil {
return nil, err
}
return data.Data, err
}

View File

@@ -8,6 +8,7 @@ import (
"net/netip"
"strconv"
"strings"
"sync"
)
// GBPlatformRecord 国标级联设备信息持久化结构体
@@ -25,7 +26,22 @@ type GBPlatformRecord struct {
type GBPlatform struct {
*Client
streams *streamManager // 保存与上级的所有级联会话
lock sync.Mutex
streams map[string]string // 上级会话的callId关联到实际推流通道的callId
}
func (g *GBPlatform) AddStream(callId string, channelCallId string) {
g.lock.Lock()
defer g.lock.Unlock()
g.streams[callId] = channelCallId
}
func (g *GBPlatform) removeStream(callId string) string {
g.lock.Lock()
defer g.lock.Unlock()
channelCallId := g.streams[callId]
delete(g.streams, callId)
return channelCallId
}
// OnBye 被上级挂断
@@ -35,39 +51,50 @@ func (g *GBPlatform) OnBye(request sip.Request) {
}
// CloseStream 关闭级联会话
func (g *GBPlatform) CloseStream(id string, bye, ms bool) {
// 删除会话
stream := g.streams.RemoveWithCallId(id)
func (g *GBPlatform) CloseStream(callId string, bye, ms bool) {
channelCallId := g.removeStream(callId)
stream := StreamManager.FindWithCallId(channelCallId)
if stream == nil {
Sugar.Errorf("关闭级联转发sink失败, 找不到stream. callid: %s", callId)
return
}
// 从国标源中删除当前转发流
sink := stream.RemoveForwardSink(id)
if ms {
// 通知媒体服务
go CloseSink(string(stream.ID), sink.id)
sink := stream.RemoveForwardStreamSink(callId)
if sink == nil {
Sugar.Errorf("关闭级联转发sink失败, 找不到sink. callid: %s", callId)
return
}
// SIP挂断
if bye {
byeRequest := CreateRequestFromDialog(sink.dialog, sip.BYE)
SipUA.SendRequest(byeRequest)
sink.Close(bye, ms)
}
// CloseStreams 关闭所有级联会话
func (g *GBPlatform) CloseStreams(bye, ms bool) {
var callIds []string
g.lock.Lock()
for k := range g.streams {
callIds = append(callIds, k)
}
g.lock.Unlock()
for _, id := range callIds {
g.CloseStream(id, bye, ms)
}
}
// OnInvite 被上级呼叫
func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response {
Sugar.Infof("收到级联Invite请求 platform: %s channel: %s sdp: %s", g.SeverId, user, request.Body())
Sugar.Infof("收到级联Invite请求 platform: %s channel: %s sdp: %s", g.SeverID, user, request.Body())
source := request.Source()
platform := PlatformManager.FindPlatformWithServerAddr(source)
utils.Assert(platform != nil)
deviceId, channel, err := DB.QueryPlatformChannel(g.SeverId, user)
deviceId, channel, err := DB.QueryPlatformChannel(g.SeverID, user)
if err != nil {
Sugar.Errorf("级联转发失败, 查询数据库失败 err: %s platform: %s channel: %s", err.Error(), g.SeverId, user)
Sugar.Errorf("级联转发失败, 查询数据库失败 err: %s platform: %s channel: %s", err.Error(), g.SeverID, user)
return CreateResponseWithStatusCode(request, http.StatusInternalServerError)
}
@@ -123,7 +150,7 @@ func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response {
Sugar.Errorf("级联转发失败,向流媒体服务添加转发Sink失败 err: %s", err.Error())
if "play" != parse.Session {
CloseStream(streamId)
CloseStream(streamId, true)
}
return CreateResponseWithStatusCode(request, http.StatusInternalServerError)
@@ -141,15 +168,18 @@ func (g *GBPlatform) OnInvite(request sip.Request, user string) sip.Response {
// 添加级联转发流
callID, _ := request.CallID()
stream.AddForwardSink(callID.Value(), &Sink{sinkID, g.ID, g.CreateDialogRequestFromAnswer(response, true), g.Username})
stream.AddForwardStreamSink(callID.Value(), &Sink{
ID: sinkID,
Stream: streamId,
ServerID: g.SeverID,
Dialog: g.CreateDialogRequestFromAnswer(response, true)},
)
// 保存与上级的会话
g.streams.AddWithCallId(callID.Value(), stream)
return response
}
func (g *GBPlatform) Start() {
Sugar.Infof("启动级联设备, deivce: %s transport: %s addr: %s", g.SeverId, g.sipClient.Transport, g.sipClient.Domain)
Sugar.Infof("启动级联设备, deivce: %s transport: %s addr: %s", g.SeverID, g.sipClient.Transport, g.sipClient.Domain)
g.sipClient.Start()
g.sipClient.SetOnRegisterHandler(g.onlineCB, g.offlineCB)
}
@@ -157,22 +187,28 @@ func (g *GBPlatform) Start() {
func (g *GBPlatform) Stop() {
g.sipClient.Stop()
g.sipClient.SetOnRegisterHandler(nil, nil)
// 释放所有推流
g.CloseStreams(true, true)
}
func (g *GBPlatform) Online() {
Sugar.Infof("级联设备上线 device: %s", g.SeverId)
Sugar.Infof("级联设备上线 device: %s", g.SeverID)
if err := DB.UpdatePlatformStatus(g.SeverId, ON); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverId)
if err := DB.UpdatePlatformStatus(g.SeverID, ON); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverID)
}
}
func (g *GBPlatform) Offline() {
Sugar.Infof("级联设备离线 device: %s", g.SeverId)
Sugar.Infof("级联设备离线 device: %s", g.SeverID)
if err := DB.UpdatePlatformStatus(g.SeverId, OFF); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverId)
if err := DB.UpdatePlatformStatus(g.SeverID, OFF); err != nil {
Sugar.Infof("更新级联设备状态失败 err: %s device: %s", err.Error(), g.SeverID)
}
// 释放所有推流
g.CloseStreams(true, true)
}
func NewGBPlatform(record *GBPlatformRecord, ua SipServer) (*GBPlatform, error) {
@@ -184,6 +220,6 @@ func NewGBPlatform(record *GBPlatformRecord, ua SipServer) (*GBPlatform, error)
return nil, err
}
client := NewGBClient(record.Username, record.SeverID, record.ServerAddr, record.Transport, record.Password, record.RegisterExpires, record.KeepAliveInterval, ua)
return &GBPlatform{client.(*Client), NewStreamManager()}, nil
gbClient := NewGBClient(record.Username, record.SeverID, record.ServerAddr, record.Transport, record.Password, record.RegisterExpires, record.KeepAliveInterval, ua)
return &GBPlatform{Client: gbClient.(*Client), streams: make(map[string]string, 8)}, nil
}

View File

@@ -24,11 +24,11 @@ func (p *platformManager) AddPlatform(platform *GBPlatform) bool {
defer p.lock.Unlock()
// 以上级平台ID作为主键
if _, ok := p.addrMap[platform.sipClient.SeverId]; ok {
if _, ok := p.addrMap[platform.sipClient.SeverID]; ok {
return false
}
p.platforms[platform.sipClient.SeverId] = platform
p.platforms[platform.sipClient.SeverID] = platform
p.addrMap[platform.sipClient.Domain] = platform
return true
}

75
sink.go Normal file
View File

@@ -0,0 +1,75 @@
package main
import (
"encoding/json"
"github.com/ghettovoice/gosip/sip"
"github.com/ghettovoice/gosip/sip/parser"
)
// Sink 国标级联转发流
type Sink struct {
ID string `json:"id"` // 流媒体服务器中的SinkID
Stream StreamID `json:"stream"` // 所属的stream id
Protocol string `json:"protocol,omitempty"` // 拉流协议, 目前只保存"gb_stream_forward"
Dialog sip.Request `json:"dialog,omitempty"` // 级联时, 与上级的Invite会话
ServerID string `json:"server_id"` // 级联设备的上级ID
CreateTime int64 `json:"create_time"`
}
// Close 关闭级联会话. 是否向上级发送bye请求, 是否通知流媒体服务器发送删除sink
func (s *Sink) Close(bye, ms bool) {
// 挂断与上级的sip会话
if bye && s.Dialog != nil {
byeRequest := CreateRequestFromDialog(s.Dialog, sip.BYE)
go SipUA.SendRequest(byeRequest)
}
if ms {
go CloseSink(string(s.Stream), s.ID)
}
}
func (s *Sink) MarshalJSON() ([]byte, error) {
type Alias Sink // 定义别名以避免递归调用
v := &struct {
*Alias
Dialog string `json:"dialog,omitempty"` // 将 Dialog 转换为字符串
}{
Alias: (*Alias)(s),
}
if s.Dialog != nil {
v.Dialog = s.Dialog.String()
}
return json.Marshal(v)
}
func (s *Sink) UnmarshalJSON(data []byte) error {
type Alias Sink // 定义别名以避免递归调用
v := &struct {
*Alias
Dialog string `json:"dialog,omitempty"` // 将 Dialog 转换为字符串
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(data, v); err != nil {
return err
}
*s = *(*Sink)(v.Alias)
if len(v.Dialog) > 1 {
packetParser := parser.NewPacketParser(logger)
message, err := packetParser.ParseMessage([]byte(v.Dialog))
if err != nil {
Sugar.Errorf("json解析dialog失败, err: %s value: %s", err.Error(), v.Dialog)
} else {
request := message.(sip.Request)
s.Dialog = request
}
}
return nil
}

View File

@@ -46,7 +46,7 @@ type sipClient struct {
Password string //密码
RegisterExpires int //注册有效期
KeeAliveInterval int //心跳间隔
SeverId string //上级ID
SeverID string //上级ID
ListenAddr string //UA的监听地址
NatAddr string //Nat地址
@@ -109,7 +109,7 @@ func (g *sipClient) doRegister(request sip.Request) bool {
}
func (g *sipClient) startNewRegister() bool {
builder := NewRequestBuilder(sip.REGISTER, g.Username, g.ListenAddr, g.SeverId, g.Domain, g.Transport)
builder := NewRequestBuilder(sip.REGISTER, g.Username, g.ListenAddr, g.SeverID, g.Domain, g.Transport)
expires := sip.Expires(g.RegisterExpires)
builder.SetExpires(&expires)
@@ -167,7 +167,7 @@ func (g *sipClient) doUnregister() {
func (g *sipClient) doKeepalive() bool {
body := fmt.Sprintf(KeepAliveBody, time.Now().UnixMilli()/1000, g.Username)
request, err := BuildMessageRequest(g.Username, g.ListenAddr, g.SeverId, g.Domain, g.Transport, body)
request, err := BuildMessageRequest(g.Username, g.ListenAddr, g.SeverID, g.Domain, g.Transport, body)
if err != nil {
panic(err)
}

View File

@@ -1,6 +1,9 @@
package main
import "strings"
import (
"strings"
"time"
)
// Handler 处理下级设备的消息
type Handler interface {
@@ -60,6 +63,7 @@ func (e *EventHandler) OnRegister(id, transport, addr string) (int, GBDevice, bo
}
device.Status = ON
device.RegisterTime = time.Now().UnixMilli()
if DB != nil {
if err := DB.SaveDevice(device); err != nil {
Sugar.Errorf("保存设备信息到数据库失败 device: %s err: %s", id, err.Error())

View File

@@ -90,7 +90,7 @@ func (s *sipServer) OnRegister(req sip.Request, tx sip.ServerTransaction, parent
var expires int
expires, device, queryCatalog = s.handler.OnRegister(id, req.Transport(), req.Source())
if device != nil {
Sugar.Infof("注册成功 Device: %s", id)
Sugar.Infof("注册成功 Device: %s addr: %s", id, req.Source())
expiresHeader := sip.Expires(expires)
response.AppendHeader(&expiresHeader)
} else {
@@ -163,7 +163,7 @@ func (s *sipServer) OnBye(req sip.Request, tx sip.ServerTransaction, parent bool
if stream := StreamManager.RemoveWithCallId(id.Value()); stream != nil {
// 下级设备挂断, 关闭流
deviceId = stream.ID.DeviceID()
stream.Close(false)
stream.Close(false, true)
} else if session := BroadcastManager.RemoveWithCallId(id.Value()); session != nil {
// 广播挂断
deviceId = session.DeviceID

140
stream.go
View File

@@ -2,59 +2,101 @@ package main
import (
"context"
"encoding/json"
"github.com/ghettovoice/gosip/sip"
"github.com/ghettovoice/gosip/sip/parser"
"sync"
"sync/atomic"
"time"
)
// Sink 级联转发
type Sink struct {
id string
deviceID string
dialog sip.Request
platformID string // 级联上级ID
}
// Stream 国标推流源
// Stream 国标推流
type Stream struct {
ID StreamID // 推流ID
DialogRequest sip.Request
ID StreamID `json:"id"` // 推流ID
Protocol string `json:"protocol,omitempty"` // 推流协议
Dialog sip.Request `json:"dialog,omitempty"` // 国标推流时, 与推流通道的Invite会话
CreateTime int64 `json:"create_time"` // 推流时间
SinkCount int32 `json:"sink_count"` // 拉流端计数(包含级联转发)
sinkCount int32 // 拉流数量+级联转发数量
publishEvent chan byte
cancelFunc func()
forwardSinks map[string]*Sink // 级联转发Sink, Key为与上级的CallID
lock sync.RWMutex
urls []string // 拉流地址
ForwardStreamSinks map[string]*Sink // 级联转发Sink, Key为与上级的CallID. 不保存所有的拉流端,查询拉流端列表,从流媒体服务器查询或新建数据库查询。 json序列化, 线程安全?
urls []string // 从流媒体服务器返回的拉流地址
publishEvent chan byte // 等待推流hook的管道
cancelFunc func() // 取消等待推流hook的ctx
}
func (s *Stream) AddForwardSink(id string, sink *Sink) {
func (s *Stream) MarshalJSON() ([]byte, error) {
type Alias Stream // 定义别名以避免递归调用
v := &struct {
*Alias
Dialog string `json:"dialog,omitempty"` // 将 Dialog 转换为字符串
}{
Alias: (*Alias)(s),
}
if s.Dialog != nil {
v.Dialog = s.Dialog.String()
}
return json.Marshal(v)
}
func (s *Stream) UnmarshalJSON(data []byte) error {
type Alias Stream // 定义别名以避免递归调用
v := &struct {
*Alias
Dialog string `json:"dialog,omitempty"` // 将 Dialog 转换为字符串
}{
Alias: (*Alias)(s),
}
if err := json.Unmarshal(data, v); err != nil {
return err
}
*s = *(*Stream)(v.Alias)
if len(v.Dialog) > 1 {
packetParser := parser.NewPacketParser(logger)
message, err := packetParser.ParseMessage([]byte(v.Dialog))
if err != nil {
Sugar.Errorf("json解析dialog失败, err: %s value: %s", err.Error(), v.Dialog)
} else {
request := message.(sip.Request)
s.Dialog = request
}
}
return nil
}
func (s *Stream) AddForwardStreamSink(id string, sink *Sink) {
s.lock.Lock()
defer s.lock.Unlock()
s.forwardSinks[id] = sink
s.ForwardStreamSinks[id] = sink
go DB.SaveStream(s)
}
func (s *Stream) RemoveForwardSink(id string) *Sink {
func (s *Stream) RemoveForwardStreamSink(id string) *Sink {
s.lock.Lock()
defer s.lock.Unlock()
sink, ok := s.forwardSinks[id]
sink, ok := s.ForwardStreamSinks[id]
if ok {
delete(s.forwardSinks, id)
delete(s.ForwardStreamSinks, id)
}
go DB.SaveStream(s)
return sink
}
func (s *Stream) ForwardSinks() []*Sink {
func (s *Stream) GetForwardStreamSinks() []*Sink {
s.lock.Lock()
defer s.lock.Unlock()
var sinks []*Sink
for _, sink := range s.forwardSinks {
for _, sink := range s.ForwardStreamSinks {
sinks = append(sinks, sink)
}
@@ -75,38 +117,58 @@ func (s *Stream) WaitForPublishEvent(seconds int) bool {
}
}
func (s *Stream) SinkCount() int32 {
return atomic.LoadInt32(&s.sinkCount)
func (s *Stream) GetSinkCount() int32 {
return atomic.LoadInt32(&s.SinkCount)
}
func (s *Stream) IncreaseSinkCount() int32 {
return atomic.AddInt32(&s.sinkCount, 1)
value := atomic.AddInt32(&s.SinkCount, 1)
Sugar.Infof("拉流计数: %d stream: %s ", value, s.ID)
// 启动协程去更新拉流计数, 可能会不一致
go DB.SaveStream(s)
return value
}
func (s *Stream) DecreaseSinkCount() int32 {
return atomic.AddInt32(&s.sinkCount, -1)
value := atomic.AddInt32(&s.SinkCount, -1)
Sugar.Infof("拉流计数: %d stream: %s ", value, s.ID)
go DB.SaveStream(s)
return value
}
func (s *Stream) Close(sendBye bool) {
func (s *Stream) Close(bye, ms bool) {
if s.cancelFunc != nil {
s.cancelFunc()
}
// 断开与下级的会话
if sendBye && s.DialogRequest != nil {
SipUA.SendRequest(s.CreateRequestFromDialog(sip.BYE))
s.DialogRequest = nil
// 断开与推流通道的sip会话
if bye && s.Dialog != nil {
go SipUA.SendRequest(s.CreateRequestFromDialog(sip.BYE))
s.Dialog = nil
}
go CloseGBSource(string(s.ID))
if ms {
// 告知媒体服务释放source
go CloseSource(string(s.ID))
}
// 关闭所有级联会话
sinks := s.ForwardSinks()
sinks := s.GetForwardStreamSinks()
for _, sink := range sinks {
platform := PlatformManager.FindPlatform(sink.deviceID)
id, _ := sink.dialog.CallID()
id, _ := sink.Dialog.CallID()
// 如果级联设备存在, 通过级联设备中删除会话
platform := PlatformManager.FindPlatform(sink.ServerID)
if platform == nil {
continue
}
platform.CloseStream(id.Value(), true, true)
}
s.ForwardStreamSinks = map[string]*Sink{}
// 从数据库中删除流记录
DB.DeleteStream(s.CreateTime)
}
func CreateRequestFromDialog(dialog sip.Request, method sip.RequestMethod) sip.Request {
@@ -123,5 +185,5 @@ func CreateRequestFromDialog(dialog sip.Request, method sip.RequestMethod) sip.R
}
func (s *Stream) CreateRequestFromDialog(method sip.RequestMethod) sip.Request {
return CreateRequestFromDialog(s.DialogRequest, method)
return CreateRequestFromDialog(s.Dialog, method)
}

View File

@@ -12,7 +12,7 @@ func init() {
type streamManager struct {
streams map[StreamID]*Stream
callIds map[string]*Stream // 本SipUA的CallIDStream的关系
callIds map[string]*Stream // CallID关联Stream, 实际推流通道的会话callid和级联转发的callid都会指向Stream
lock sync.RWMutex
}
@@ -69,8 +69,8 @@ func (s *streamManager) Remove(id StreamID) *Stream {
stream, ok := s.streams[id]
delete(s.streams, id)
if ok && stream.DialogRequest != nil {
callID, _ := stream.DialogRequest.CallID()
if ok && stream.Dialog != nil {
callID, _ := stream.Dialog.CallID()
delete(s.callIds, callID.Value())
return stream
}
@@ -92,6 +92,18 @@ func (s *streamManager) RemoveWithCallId(id string) *Stream {
return nil
}
func (s *streamManager) All() []*Stream {
s.lock.Lock()
defer s.lock.Unlock()
var streams []*Stream
for _, stream := range s.streams {
streams = append(streams, stream)
}
return streams
}
func (s *streamManager) PopAll() []*Stream {
s.lock.Lock()
defer s.lock.Unlock()